|
|
|
|
| """# shared_subspace_encoder.py"""
|
|
|
| from typing import Optional
|
|
|
| import torch
|
| from torch import nn
|
|
|
| from transformers.configuration_utils import PretrainedConfig
|
| from transformers.modeling_utils import PreTrainedModel
|
| from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
|
|
| from ..layers.mla import MultiheadLatentAttention, RotaryEmbedding
|
| from ..layers.feedforward import SubspaceFeedForward
|
| from ..models.shared_space_config import SharedSpaceDecoderConfig
|
|
|
| """
|
| RMSNorm
|
| From: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
|
| """
|
|
|
class DeepseekV3RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by its root-mean-square
    (plus ``eps`` inside the square root) and then scaled by a learned
    per-channel ``weight``. No mean is subtracted and there is no bias.
    Equivalent to T5LayerNorm.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Length of the normalized (last) dimension; also the
                length of the learned ``weight`` vector, initialized to ones.
            eps: Constant added to the mean square before the reciprocal
                square root, guarding against division by zero.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Return ``weight * normalize(hidden_states)``.

        The mean square and reciprocal root are computed in float32. The
        normalized tensor is cast back to the input dtype before the
        multiplication by ``weight``, so the result dtype follows PyTorch
        type promotion between ``weight`` and that tensor.
        """
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized
|
|
|
def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
    """Build the normalization module selected by ``config.norm_type``.

    Args:
        hidden_size: Size of the feature dimension to normalize over.
        config: Supplies ``norm_type`` plus the epsilon for the chosen kind:
            ``layer_norm_eps`` when ``norm_type == "layernorm"``, and
            ``rms_norm_eps`` when ``norm_type == "rmsnorm"``.

    Returns:
        An ``nn.LayerNorm`` for ``"layernorm"`` or a ``DeepseekV3RMSNorm``
        for ``"rmsnorm"``.

    Raises:
        ValueError: If ``config.norm_type`` is any other value.
    """
    kind = config.norm_type

    if kind == "rmsnorm":
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)

    if kind == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

    raise ValueError(f"Unknown norm_type: {kind}")
|
|
|
| """#### *PreTrainedModel"""
|
|
|
class SharedSpaceDecoderPreTrainedModel(PreTrainedModel):
    """
    Common base class for the SharedSpace decoder models.

    - Binds the Hugging Face ``PreTrainedModel`` machinery to
      :class:`SharedSpaceDecoderConfig` via ``config_class``, and declares
      ``"model"`` as ``base_model_prefix``.
    - Provides :meth:`_init_weights`, the per-module initialization hook.
    - Defines no ``forward`` of its own; concrete subclasses
      (e.g. :class:`SharedSpaceDecoderModel`) implement the computation.
    """

    config_class = SharedSpaceDecoderConfig
    base_model_prefix = "model"

    def _init_weights(self, module: nn.Module) -> None:
        """Weight initialization hook used by :class:`PreTrainedModel`.

        ``PreTrainedModel.post_init`` applies this function to every
        submodule after construction, so a model created from scratch is
        initialized consistently.

        Rules, by module type:
        - ``nn.Linear``: weight ~ N(0, ``config.initializer_range``); bias
          (if present) zeroed. This covers every linear layer alike; there
          is no separate rule for output heads.
        - ``nn.Embedding``: weight ~ N(0, ``config.initializer_range``); the
          ``padding_idx`` row (if set) zeroed.
        - ``DeepseekV3RMSNorm``: weight set to 1.
        - ``nn.LayerNorm``: weight set to 1 and bias zeroed, each only if
          present.

        Args:
            module: The submodule currently being initialized. Modules of
                other types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        elif isinstance(module, DeepseekV3RMSNorm):
            module.weight.data.fill_(1.0)

        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False has weight and
            # bias set to None, and bias=False leaves bias as None; guard both
            # so those variants don't raise AttributeError here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
|
|
|
|
|
class SharedSpaceDecoderLayer(nn.Module):
    """
    A single pre-norm Transformer decoder block.

    - Instantiated by :class:`SharedSpaceDecoderModel`, once for each of the
      ``config.num_hidden_layers`` blocks, with ``layer_idx`` giving the
      block's position in the stack.
    - Submodules (the attribute names double as checkpoint key prefixes):
      - ``attn_input_norm``: normalization before attention, built by
        ``create_norm_layer`` (LayerNorm or RMSNorm per ``config.norm_type``).
      - ``self_attn``: :class:`MultiheadLatentAttention`.
      - ``ffn_input_norm``: normalization before the feed-forward sub-block,
        built the same way.
      - ``ffn``: :class:`SubspaceFeedForward`.
    - :meth:`forward` applies ``x + Attn(Norm(x))`` followed by
      ``x + FFN(Norm(x))``.
    """

    def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:
        super().__init__()
        norm_width = config.hidden_size

        # Registration order (attn norm, attn, ffn norm, ffn) is kept stable
        # so parameter/state-dict ordering does not change.
        self.attn_input_norm = create_norm_layer(norm_width, config)
        self.self_attn = MultiheadLatentAttention(config, layer_idx)
        self.ffn_input_norm = create_norm_layer(norm_width, config)
        self.ffn = SubspaceFeedForward(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the attention and feed-forward sub-blocks with residual adds.

        Args:
            hidden_states: Block input. NOTE(review): presumably
                [batch_size, seq_len, hidden_size]; confirm with callers.
            position_embeddings: ``(cos, sin)`` rotary tables, forwarded
                unchanged to ``self_attn``.
            attention_mask: Mask forwarded unchanged to ``self_attn``; may be
                ``None``.

        Returns:
            The block output, i.e. the residual stream after both sub-blocks.
        """
        # Attention sub-block: residual + Attn(Norm(residual)).
        normed = self.attn_input_norm(hidden_states)
        attn_delta = self.self_attn(normed, position_embeddings, attention_mask)
        after_attn = hidden_states + attn_delta

        # Feed-forward sub-block: residual + FFN(Norm(residual)).
        ffn_delta = self.ffn(self.ffn_input_norm(after_attn))
        return after_attn + ffn_delta
|
|
|
|
|
class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel):
    """
    The decoder stack without a language-modeling head.

    - Initializes:
      - ``vocab_embed``: token embedding table. When ``config.vocab_subspace``
        is truthy it is ``config.vocab_rank`` wide, and ``vocab_proj`` (a
        bias-free linear layer) maps it up to ``config.hidden_size``.
        Otherwise it is ``config.hidden_size`` wide and ``vocab_proj`` is
        ``None``.
      - ``rope``: a :class:`RotaryEmbedding`, whose ``cos``/``sin`` tables are
        sliced to the sequence length on every forward call.
      - ``layers``: ``config.num_hidden_layers`` :class:`SharedSpaceDecoderLayer`
        blocks.
    - Provides :meth:`embed` to map token ids into model space.
    - Executes the full stack in :meth:`forward`.

    For language modeling, use the causal-LM wrapper class. NOTE(review): the
    original docstring calls it ``SubspaceDecoderForCausalLM``; confirm the
    actual class name.
    """

    def __init__(self, config: SharedSpaceDecoderConfig) -> None:
        super().__init__(config)

        if config.vocab_subspace:
            # Low-rank vocabulary: a [vocab_size, vocab_rank] table followed
            # by a bias-free up-projection to hidden_size.
            self.vocab_embed = nn.Embedding(config.vocab_size, config.vocab_rank)
            self.vocab_proj = nn.Linear(config.vocab_rank, config.hidden_size, bias=False)
        else:
            self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.vocab_proj = None

        self.rope = RotaryEmbedding(config)

        self.layers = nn.ModuleList(
            [
                SharedSpaceDecoderLayer(config, layer_idx=i)
                for i in range(config.num_hidden_layers)
            ]
        )

        # Runs the PreTrainedModel post-construction hooks, including
        # _init_weights on every submodule.
        self.post_init()

    def embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids`` in model space.

        If the vocabulary is decomposed (``vocab_proj`` is not ``None``), the
        low-rank lookup is up-projected to ``hidden_size``.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].

        Returns:
            Embeddings of shape [batch_size, seq_len, hidden_size].
        """
        token_vectors = self.vocab_embed(input_ids)
        if self.vocab_proj is not None:
            token_vectors = self.vocab_proj(token_vectors)
        return token_vectors

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the full decoder stack.

        Args:
            input_ids: Token ids of shape [batch_size, seq_len].
            attention_mask: Optional [batch_size, seq_len] padding mask, 1 for
                real tokens and 0 for padding. ``None`` means no padding.
            **kwargs: Accepted for Hugging Face API compatibility and ignored.

        Returns:
            Final decoder layer output, [batch_size, seq_len, hidden_size].

        NOTE(review): no causal (lower-triangular) mask is built here. Only
        the padding mask is expanded, so causality must be enforced inside
        ``MultiheadLatentAttention``; confirm.
        """
        hidden_states = self.embed(input_ids)
        seq_len = hidden_states.size(1)

        # Rotary tables for positions [0, seq_len). NOTE(review): assumes
        # rope.cos / rope.sin are indexed by position along dim 0 and cover
        # at least seq_len positions; confirm in RotaryEmbedding.
        position_embeddings = (self.rope.cos[:seq_len], self.rope.sin[:seq_len])

        # The helper reads ``mask.shape``, so calling it with None (the
        # declared default) raised AttributeError. Only expand a real mask.
        # Per the transformers implementation (verify against the installed
        # version), it returns a 4-D additive mask [batch, 1, seq_len, seq_len],
        # or None when every entry is 1 outside tracing. Attention layers
        # therefore already had to accept a None mask.
        if attention_mask is not None:
            attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask,
                hidden_states.dtype,
                tgt_len=seq_len,
            )

        for layer in self.layers:
            hidden_states = layer(hidden_states, position_embeddings, attention_mask)

        return hidden_states
|
|
|
|
|