Yuchan committed on
Commit
8290c73
·
verified ·
1 Parent(s): 76d2b30

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +12 -4
Mo.py CHANGED
@@ -124,6 +124,17 @@ class SwiGLU(layers.Layer):
124
  out = self.W1(tf.nn.silu(a) * b)
125
  return tf.cast(out, x.dtype)
126
 
 
 
 
 
 
 
 
 
 
 
 
127
  class LoU(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
@@ -137,9 +148,7 @@ class LoU(layers.Layer):
137
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
138
 
139
  self.glu = SwiGLU(d_model, 320)
140
- self.cross = CrossBlock()
141
-
142
- def call(self, x, z):
143
  x_f32 = tf.cast(x, tf.float32)
144
  residual = x_f32
145
  x_f32 = self.norm1(x)
@@ -171,7 +180,6 @@ class LoU(layers.Layer):
171
  x_comb = score_clipped * V
172
 
173
  out = self.norm(x_comb + residual)
174
- out = self.cross(out, z)
175
  out = self.glu(out)
176
  return tf.cast(out, x.dtype)
177
 
 
124
  out = self.W1(tf.nn.silu(a) * b)
125
  return tf.cast(out, x.dtype)
126
 
127
class SwiGLU(layers.Layer):
    """SwiGLU feed-forward block: project, split, SiLU-gate, project back.

    The input's last axis is projected to ``d_ff`` features, split into a
    value half and a gate half, combined as ``value * silu(gate)``, and
    projected back to ``d_model`` features.

    NOTE(review): unlike the variant defined earlier in this file (which
    appears to use separate value/gate projections), this one uses a single
    fused projection and splits it — presumably intentional; confirm.
    """

    def __init__(self, d_model, d_ff):
        """Create the block.

        Args:
            d_model: width of the output (last-axis size restored by `call`).
            d_ff: width of the fused hidden projection; must be even because
                it is split into two equal halves (value and gate).

        Raises:
            ValueError: if ``d_ff`` is odd — previously this surfaced only at
                first `call` as an opaque `tf.split` shape error.
        """
        super().__init__()
        if d_ff % 2 != 0:
            # Fail fast with a clear message instead of a call-time shape error.
            raise ValueError(f"d_ff must be even to split into halves, got {d_ff}")
        self.proj = layers.Dense(d_ff)
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Apply the gated feed-forward transform along the last axis of x."""
        x_proj = self.proj(x)
        # Two equal halves along the feature axis: value, then gate.
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.silu(x_gate))
136
+
137
+
138
  class LoU(layers.Layer):
139
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
140
  super().__init__()
 
148
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
149
 
150
  self.glu = SwiGLU(d_model, 320)
151
+ def call(self, x):
 
 
152
  x_f32 = tf.cast(x, tf.float32)
153
  residual = x_f32
154
  x_f32 = self.norm1(x)
 
180
  x_comb = score_clipped * V
181
 
182
  out = self.norm(x_comb + residual)
 
183
  out = self.glu(out)
184
  return tf.cast(out, x.dtype)
185