muverqqw commited on
Commit
59dea23
·
1 Parent(s): cc355cb

Update modeling_alinlight.py

Browse files
Files changed (1) hide show
  1. modeling_alinlight.py +10 -42
modeling_alinlight.py CHANGED
@@ -13,6 +13,9 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
 
 
16
  import math
17
  import torch
18
  import torch.nn as nn
@@ -41,7 +44,6 @@ class AlinlightPreTrainedModel(PreTrainedModel):
41
  def _init_weights(self, module):
42
  std = self.config.initializer_range
43
  if isinstance(module, nn.Linear):
44
- # Scale down residual projections to improve training stability at depth
45
  if getattr(module, '_is_residual_projection', False):
46
  module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
47
  else:
@@ -72,23 +74,6 @@ class AlinlightRMSNorm(nn.Module):
72
  return self.weight * x.to(input_dtype)
73
 
74
 
75
- class GatedNorm(nn.Module):
76
- """
77
- Gated Normalization wrapper.
78
- Allows the model to learn to skip normalization via a learnable gate.
79
- """
80
- def __init__(self, original_norm, initial_gate_value=-1.0):
81
- super().__init__()
82
- self.norm = original_norm
83
- # Initialize gate to -1.0 (sigmoid(-1) ≈ 0.27) to start conservatively
84
- self.gate = nn.Parameter(torch.tensor(initial_gate_value))
85
-
86
- def forward(self, x, *args, **kwargs):
87
- normed = self.norm(x, *args, **kwargs)
88
- g = torch.sigmoid(self.gate)
89
- return (1.0 - g) * x + g * normed
90
-
91
-
92
  class AlinlightRotaryEmbedding(nn.Module):
93
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
94
  super().__init__()
@@ -152,22 +137,14 @@ class AlinlightMLP(nn.Module):
152
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
153
  self.act_fn = nn.SiLU()
154
 
155
- # Use GatedNorm for the inner normalization
156
- self.pre_down_norm = GatedNorm(
157
- AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
158
- )
159
-
160
- # Tag for specialized initialization
161
  self.down_proj._is_residual_projection = True
162
 
163
  def forward(self, x):
164
- intermediate = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
165
- intermediate = self.pre_down_norm(intermediate)
166
- return self.down_proj(intermediate)
167
 
168
 
169
  # ==========================================
170
- # 3. ATTENTION
171
  # ==========================================
172
 
173
  class AlinlightAttention(nn.Module):
@@ -192,9 +169,8 @@ class AlinlightAttention(nn.Module):
192
 
193
  self.use_qk_norm = getattr(config, "use_qk_norm", True)
194
  if self.use_qk_norm:
195
- # Use GatedNorm for QK Normalization
196
- self.q_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
197
- self.k_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
198
 
199
  self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
200
 
@@ -233,8 +209,7 @@ class AlinlightAttention(nn.Module):
233
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
234
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
235
 
236
- # 3. Sliding Window (Slicing)
237
- kv_seq_len = key_states.shape[2] # NOTE: This is the length BEFORE slicing
238
 
239
  if self.sliding_window is not None and kv_seq_len > self.sliding_window:
240
  slicing_tokens = kv_seq_len - self.sliding_window
@@ -246,12 +221,12 @@ class AlinlightAttention(nn.Module):
246
 
247
  past_key_value = (key_states, value_states) if use_cache else None
248
 
249
- # 4. GQA Repeat
250
  if self.num_key_value_groups > 1:
251
  key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
252
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
253
 
254
- # 5. Attention Mechanism
255
  attn_weights = None
256
 
257
  if output_attentions or self.attn_logit_softcapping is not None:
@@ -264,9 +239,7 @@ class AlinlightAttention(nn.Module):
264
  attn_weights = attn_weights + attention_mask
265
 
266
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
267
-
268
  attn_weights_for_output = attn_weights if output_attentions else None
269
-
270
  attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
271
  attn_output = torch.matmul(attn_weights_dropped, value_states)
272
  else:
@@ -339,7 +312,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
339
  self.vocab_size = config.vocab_size
340
 
341
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
342
-
343
  self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
344
 
345
  embed_pdrop = getattr(config, 'embed_pdrop', 0.0)
@@ -414,7 +386,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
414
  use_cache = use_cache if use_cache is not None else self.config.use_cache
415
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
416
 
417
- # --- SAFETY CHECK FOR GRADIENT CHECKPOINTING ---
418
  if self.gradient_checkpointing and self.training:
419
  if use_cache:
420
  logger.warning_once(
@@ -459,7 +430,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
459
  if self.gradient_checkpointing and self.training:
460
  def create_custom_forward(module):
461
  def custom_forward(*inputs):
462
- # Force use_cache=False inside checkpoint to be safe
463
  return module(*inputs, output_attentions=output_attentions, use_cache=False, rotary_pos_emb=(cos, sin))
464
  return custom_forward
465
 
@@ -520,8 +490,6 @@ class AlinlightForCausalLM(AlinlightPreTrainedModel, GenerationMixin):
520
  if config.tie_word_embeddings:
521
  self.lm_head.weight = self.model.embed_tokens.weight
522
 
523
- # Note: self.post_init() is called here, and inside AlinlightModel.
524
- # This re-initialization is consistent with standard HF models (e.g. Llama).
525
  self.post_init()
526
 
527
  def get_input_embeddings(self): return self.model.embed_tokens
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
+ # -*- coding: utf-8 -*-
17
+ # Copyright 2026 EngineerGL Research.
18
+
19
  import math
20
  import torch
21
  import torch.nn as nn
 
44
  def _init_weights(self, module):
45
  std = self.config.initializer_range
46
  if isinstance(module, nn.Linear):
 
47
  if getattr(module, '_is_residual_projection', False):
48
  module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
49
  else:
 
74
  return self.weight * x.to(input_dtype)
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class AlinlightRotaryEmbedding(nn.Module):
78
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
79
  super().__init__()
 
137
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
138
  self.act_fn = nn.SiLU()
139
 
 
 
 
 
 
 
140
  self.down_proj._is_residual_projection = True
141
 
142
  def forward(self, x):
143
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
144
 
145
 
146
  # ==========================================
147
+ # 3. ATTENTION (Стабильная QK-Norm без гейтов)
148
  # ==========================================
149
 
150
  class AlinlightAttention(nn.Module):
 
169
 
170
  self.use_qk_norm = getattr(config, "use_qk_norm", True)
171
  if self.use_qk_norm:
172
+ self.q_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
173
+ self.k_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
174
 
175
  self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
176
 
 
209
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
210
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
211
 
212
+ kv_seq_len = key_states.shape[2]
 
213
 
214
  if self.sliding_window is not None and kv_seq_len > self.sliding_window:
215
  slicing_tokens = kv_seq_len - self.sliding_window
 
221
 
222
  past_key_value = (key_states, value_states) if use_cache else None
223
 
224
+ # 3. GQA Repeat
225
  if self.num_key_value_groups > 1:
226
  key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
227
  value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
228
 
229
+ # 4. Attention
230
  attn_weights = None
231
 
232
  if output_attentions or self.attn_logit_softcapping is not None:
 
239
  attn_weights = attn_weights + attention_mask
240
 
241
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
242
  attn_weights_for_output = attn_weights if output_attentions else None
 
243
  attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
244
  attn_output = torch.matmul(attn_weights_dropped, value_states)
245
  else:
 
312
  self.vocab_size = config.vocab_size
313
 
314
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
315
  self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
316
 
317
  embed_pdrop = getattr(config, 'embed_pdrop', 0.0)
 
386
  use_cache = use_cache if use_cache is not None else self.config.use_cache
387
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
388
 
 
389
  if self.gradient_checkpointing and self.training:
390
  if use_cache:
391
  logger.warning_once(
 
430
  if self.gradient_checkpointing and self.training:
431
  def create_custom_forward(module):
432
  def custom_forward(*inputs):
 
433
  return module(*inputs, output_attentions=output_attentions, use_cache=False, rotary_pos_emb=(cos, sin))
434
  return custom_forward
435
 
 
490
  if config.tie_word_embeddings:
491
  self.lm_head.weight = self.model.embed_tokens.weight
492
 
 
 
493
  self.post_init()
494
 
495
  def get_input_embeddings(self): return self.model.embed_tokens