Update modeling_alinlight.py
Browse files- modeling_alinlight.py +10 -42
modeling_alinlight.py
CHANGED
|
@@ -13,6 +13,9 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
import math
|
| 17 |
import torch
|
| 18 |
import torch.nn as nn
|
|
@@ -41,7 +44,6 @@ class AlinlightPreTrainedModel(PreTrainedModel):
|
|
| 41 |
def _init_weights(self, module):
|
| 42 |
std = self.config.initializer_range
|
| 43 |
if isinstance(module, nn.Linear):
|
| 44 |
-
# Scale down residual projections to improve training stability at depth
|
| 45 |
if getattr(module, '_is_residual_projection', False):
|
| 46 |
module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
|
| 47 |
else:
|
|
@@ -72,23 +74,6 @@ class AlinlightRMSNorm(nn.Module):
|
|
| 72 |
return self.weight * x.to(input_dtype)
|
| 73 |
|
| 74 |
|
| 75 |
-
class GatedNorm(nn.Module):
|
| 76 |
-
"""
|
| 77 |
-
Gated Normalization wrapper.
|
| 78 |
-
Allows the model to learn to skip normalization via a learnable gate.
|
| 79 |
-
"""
|
| 80 |
-
def __init__(self, original_norm, initial_gate_value=-1.0):
|
| 81 |
-
super().__init__()
|
| 82 |
-
self.norm = original_norm
|
| 83 |
-
# Initialize gate to -1.0 (sigmoid(-1) ≈ 0.27) to start conservatively
|
| 84 |
-
self.gate = nn.Parameter(torch.tensor(initial_gate_value))
|
| 85 |
-
|
| 86 |
-
def forward(self, x, *args, **kwargs):
|
| 87 |
-
normed = self.norm(x, *args, **kwargs)
|
| 88 |
-
g = torch.sigmoid(self.gate)
|
| 89 |
-
return (1.0 - g) * x + g * normed
|
| 90 |
-
|
| 91 |
-
|
| 92 |
class AlinlightRotaryEmbedding(nn.Module):
|
| 93 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 94 |
super().__init__()
|
|
@@ -152,22 +137,14 @@ class AlinlightMLP(nn.Module):
|
|
| 152 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 153 |
self.act_fn = nn.SiLU()
|
| 154 |
|
| 155 |
-
# Use GatedNorm for the inner normalization
|
| 156 |
-
self.pre_down_norm = GatedNorm(
|
| 157 |
-
AlinlightRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
# Tag for specialized initialization
|
| 161 |
self.down_proj._is_residual_projection = True
|
| 162 |
|
| 163 |
def forward(self, x):
|
| 164 |
-
|
| 165 |
-
intermediate = self.pre_down_norm(intermediate)
|
| 166 |
-
return self.down_proj(intermediate)
|
| 167 |
|
| 168 |
|
| 169 |
# ==========================================
|
| 170 |
-
# 3. ATTENTION
|
| 171 |
# ==========================================
|
| 172 |
|
| 173 |
class AlinlightAttention(nn.Module):
|
|
@@ -192,9 +169,8 @@ class AlinlightAttention(nn.Module):
|
|
| 192 |
|
| 193 |
self.use_qk_norm = getattr(config, "use_qk_norm", True)
|
| 194 |
if self.use_qk_norm:
|
| 195 |
-
|
| 196 |
-
self.
|
| 197 |
-
self.k_norm = GatedNorm(AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps))
|
| 198 |
|
| 199 |
self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
|
| 200 |
|
|
@@ -233,8 +209,7 @@ class AlinlightAttention(nn.Module):
|
|
| 233 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 234 |
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 235 |
|
| 236 |
-
|
| 237 |
-
kv_seq_len = key_states.shape[2] # NOTE: This is the length BEFORE slicing
|
| 238 |
|
| 239 |
if self.sliding_window is not None and kv_seq_len > self.sliding_window:
|
| 240 |
slicing_tokens = kv_seq_len - self.sliding_window
|
|
@@ -246,12 +221,12 @@ class AlinlightAttention(nn.Module):
|
|
| 246 |
|
| 247 |
past_key_value = (key_states, value_states) if use_cache else None
|
| 248 |
|
| 249 |
-
#
|
| 250 |
if self.num_key_value_groups > 1:
|
| 251 |
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 252 |
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 253 |
|
| 254 |
-
#
|
| 255 |
attn_weights = None
|
| 256 |
|
| 257 |
if output_attentions or self.attn_logit_softcapping is not None:
|
|
@@ -264,9 +239,7 @@ class AlinlightAttention(nn.Module):
|
|
| 264 |
attn_weights = attn_weights + attention_mask
|
| 265 |
|
| 266 |
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 267 |
-
|
| 268 |
attn_weights_for_output = attn_weights if output_attentions else None
|
| 269 |
-
|
| 270 |
attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 271 |
attn_output = torch.matmul(attn_weights_dropped, value_states)
|
| 272 |
else:
|
|
@@ -339,7 +312,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
|
|
| 339 |
self.vocab_size = config.vocab_size
|
| 340 |
|
| 341 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 342 |
-
|
| 343 |
self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
|
| 344 |
|
| 345 |
embed_pdrop = getattr(config, 'embed_pdrop', 0.0)
|
|
@@ -414,7 +386,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
|
|
| 414 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 415 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 416 |
|
| 417 |
-
# --- SAFETY CHECK FOR GRADIENT CHECKPOINTING ---
|
| 418 |
if self.gradient_checkpointing and self.training:
|
| 419 |
if use_cache:
|
| 420 |
logger.warning_once(
|
|
@@ -459,7 +430,6 @@ class AlinlightModel(AlinlightPreTrainedModel):
|
|
| 459 |
if self.gradient_checkpointing and self.training:
|
| 460 |
def create_custom_forward(module):
|
| 461 |
def custom_forward(*inputs):
|
| 462 |
-
# Force use_cache=False inside checkpoint to be safe
|
| 463 |
return module(*inputs, output_attentions=output_attentions, use_cache=False, rotary_pos_emb=(cos, sin))
|
| 464 |
return custom_forward
|
| 465 |
|
|
@@ -520,8 +490,6 @@ class AlinlightForCausalLM(AlinlightPreTrainedModel, GenerationMixin):
|
|
| 520 |
if config.tie_word_embeddings:
|
| 521 |
self.lm_head.weight = self.model.embed_tokens.weight
|
| 522 |
|
| 523 |
-
# Note: self.post_init() is called here, and inside AlinlightModel.
|
| 524 |
-
# This re-initialization is consistent with standard HF models (e.g. Llama).
|
| 525 |
self.post_init()
|
| 526 |
|
| 527 |
def get_input_embeddings(self): return self.model.embed_tokens
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
|
| 16 |
+
# -*- coding: utf-8 -*-
|
| 17 |
+
# Copyright 2026 EngineerGL Research.
|
| 18 |
+
|
| 19 |
import math
|
| 20 |
import torch
|
| 21 |
import torch.nn as nn
|
|
|
|
| 44 |
def _init_weights(self, module):
|
| 45 |
std = self.config.initializer_range
|
| 46 |
if isinstance(module, nn.Linear):
|
|
|
|
| 47 |
if getattr(module, '_is_residual_projection', False):
|
| 48 |
module.weight.data.normal_(mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))
|
| 49 |
else:
|
|
|
|
| 74 |
return self.weight * x.to(input_dtype)
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class AlinlightRotaryEmbedding(nn.Module):
|
| 78 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
| 79 |
super().__init__()
|
|
|
|
| 137 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 138 |
self.act_fn = nn.SiLU()
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
self.down_proj._is_residual_projection = True
|
| 141 |
|
| 142 |
def forward(self, x):
|
| 143 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
|
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
# ==========================================
|
| 147 |
+
# 3. ATTENTION (Стабильная QK-Norm без гейтов)
|
| 148 |
# ==========================================
|
| 149 |
|
| 150 |
class AlinlightAttention(nn.Module):
|
|
|
|
| 169 |
|
| 170 |
self.use_qk_norm = getattr(config, "use_qk_norm", True)
|
| 171 |
if self.use_qk_norm:
|
| 172 |
+
self.q_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 173 |
+
self.k_norm = AlinlightRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
|
|
| 174 |
|
| 175 |
self.attn_logit_softcapping = getattr(config, 'attn_logit_softcapping', None)
|
| 176 |
|
|
|
|
| 209 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 210 |
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 211 |
|
| 212 |
+
kv_seq_len = key_states.shape[2]
|
|
|
|
| 213 |
|
| 214 |
if self.sliding_window is not None and kv_seq_len > self.sliding_window:
|
| 215 |
slicing_tokens = kv_seq_len - self.sliding_window
|
|
|
|
| 221 |
|
| 222 |
past_key_value = (key_states, value_states) if use_cache else None
|
| 223 |
|
| 224 |
+
# 3. GQA Repeat
|
| 225 |
if self.num_key_value_groups > 1:
|
| 226 |
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 227 |
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 228 |
|
| 229 |
+
# 4. Attention
|
| 230 |
attn_weights = None
|
| 231 |
|
| 232 |
if output_attentions or self.attn_logit_softcapping is not None:
|
|
|
|
| 239 |
attn_weights = attn_weights + attention_mask
|
| 240 |
|
| 241 |
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
|
|
| 242 |
attn_weights_for_output = attn_weights if output_attentions else None
|
|
|
|
| 243 |
attn_weights_dropped = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 244 |
attn_output = torch.matmul(attn_weights_dropped, value_states)
|
| 245 |
else:
|
|
|
|
| 312 |
self.vocab_size = config.vocab_size
|
| 313 |
|
| 314 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
|
|
|
| 315 |
self.embed_scale = math.sqrt(config.hidden_size) if getattr(config, 'embed_scale', False) else 1.0
|
| 316 |
|
| 317 |
embed_pdrop = getattr(config, 'embed_pdrop', 0.0)
|
|
|
|
| 386 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 387 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 388 |
|
|
|
|
| 389 |
if self.gradient_checkpointing and self.training:
|
| 390 |
if use_cache:
|
| 391 |
logger.warning_once(
|
|
|
|
| 430 |
if self.gradient_checkpointing and self.training:
|
| 431 |
def create_custom_forward(module):
|
| 432 |
def custom_forward(*inputs):
|
|
|
|
| 433 |
return module(*inputs, output_attentions=output_attentions, use_cache=False, rotary_pos_emb=(cos, sin))
|
| 434 |
return custom_forward
|
| 435 |
|
|
|
|
| 490 |
if config.tie_word_embeddings:
|
| 491 |
self.lm_head.weight = self.model.embed_tokens.weight
|
| 492 |
|
|
|
|
|
|
|
| 493 |
self.post_init()
|
| 494 |
|
| 495 |
def get_input_embeddings(self): return self.model.embed_tokens
|