Yuchan commited on
Commit
06bcf92
ยท
verified ยท
1 Parent(s): bd22708

Update Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +3 -24
Inference.py CHANGED
@@ -148,29 +148,10 @@ class CrossBlock(layers.Layer):
148
  super().__init__()
149
  self.clip_value = clip_value
150
  self.eps = eps
 
151
  # ๐Ÿ’ก ์ˆ˜์ •: ์ถœ๋ ฅ ์ฐจ์›์„ 1์—์„œ d_model๋กœ ๋ณ€๊ฒฝ
152
  def call(self, x, z):
153
- # a์˜ shape: (Batch, Seq_len, D_model)
154
- g_q = (tf.nn.tanh(x) + 1.0) / 2.0
155
- g_k = (tf.nn.tanh(z) + 1.0) / 2.0
156
- score = (g_q * g_k)
157
- score = tf.cumsum(score, axis=1)
158
-
159
- seq_len = tf.shape(score)[1]
160
- # [1, 2, 3, ..., L]์„ D_model ์ฐจ์›์œผ๋กœ ํ™•์žฅ
161
- count_for_mean = tf.cast(tf.range(seq_len) + 1, score.dtype)
162
- count_for_mean = tf.reshape(count_for_mean, (1, seq_len, 1))
163
-
164
- # ๋ˆ„์ ํ•ฉ์„ ํ˜„์žฌ๊นŒ์ง€์˜ ํ† ํฐ ๊ฐœ์ˆ˜๋กœ ๋‚˜๋ˆ„์–ด ํ‰๊ท  ๋ˆ„์ ํ•ฉ ๊ณ„์‚ฐ (B, L, D)
165
- score_mean = score / count_for_mean
166
-
167
- # ์ •๊ทœํ™” ๋ถ„๋ชจ ์„ค์ •
168
- denom = tf.maximum(score_mean, self.eps)
169
- score_norm = score / denom
170
- # -----------------------------------------------
171
-
172
- score_clipped = tf.clip_by_value(score_norm, -self.clip_value, self.clip_value)
173
- y = score_clipped * z
174
  return y
175
 
176
  class LoU(layers.Layer):
@@ -182,7 +163,7 @@ class LoU(layers.Layer):
182
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
183
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
184
 
185
- self.glu = SwiGLU(d_model, 320)
186
  self.cross = CrossBlock()
187
 
188
  def call(self, x, z):
@@ -196,8 +177,6 @@ class LoU(layers.Layer):
196
  out = self.cross(out, z)
197
  out = self.glu(out)
198
  return tf.cast(out, x.dtype)
199
-
200
-
201
  # =======================
202
  # 4) AlphaS2S ๋ชจ๋ธ (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
203
  # =======================
 
148
  super().__init__()
149
  self.clip_value = clip_value
150
  self.eps = eps
151
+ self.attn = layers.MultiHeadAttention(8, 20)
152
  # ๐Ÿ’ก ์ˆ˜์ •: ์ถœ๋ ฅ ์ฐจ์›์„ 1์—์„œ d_model๋กœ ๋ณ€๊ฒฝ
153
  def call(self, x, z):
154
+ y = self.attn(x, z, z)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return y
156
 
157
  class LoU(layers.Layer):
 
163
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
164
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
165
 
166
+ self.glu = SwiGLU(d_model, 350)
167
  self.cross = CrossBlock()
168
 
169
  def call(self, x, z):
 
177
  out = self.cross(out, z)
178
  out = self.glu(out)
179
  return tf.cast(out, x.dtype)
 
 
180
  # =======================
181
  # 4) AlphaS2S ๋ชจ๋ธ (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
182
  # =======================