File size: 5,258 Bytes
2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 42458aa 2e0bc42 ebf013f 42458aa 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 ebf013f 2e0bc42 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """Shared model architecture for multilingual 3B GPT — must match training exactly."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
VOCAB_SIZE = 32000
DIM = 3072
DEPTH = 26
N_HEADS = 24
HEAD_DIM = DIM // N_HEADS # 128
MAX_SEQ_LEN = 2048
ROPE_THETA = 10000.0
HIDDEN_DIM = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64 # SwiGLU hidden
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return (x.float() * norm).type_as(x) * self.weight
def precompute_freqs_cis(dim, max_seq_len, theta=ROPE_THETA):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, freqs)
return torch.polar(torch.ones_like(freqs), freqs)
def apply_rotary_emb(x, freqs_cis):
# x: (B, n_heads, S, head_dim)
B, H, S, D = x.shape
x_complex = torch.view_as_complex(x.float().reshape(B, H, S, D // 2, 2))
freqs = freqs_cis[:S].unsqueeze(0).unsqueeze(1) # (1, 1, S, D//2)
x_rot = torch.view_as_real(x_complex * freqs).reshape(B, H, S, D)
return x_rot.type_as(x)
class FusedAttention(nn.Module):
def __init__(self, dim, n_heads):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.qkv = nn.Linear(dim, 3 * dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
def forward(self, x, freqs_cis, mask=None):
B, S, D = x.shape
qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, self.head_dim)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
q = q.transpose(1, 2) # (B, H, S, D)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
# Scaled dot-product attention
scale = math.sqrt(self.head_dim)
attn = (q @ k.transpose(-2, -1)) / scale
if mask is not None:
attn = attn + mask
attn = F.softmax(attn, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, S, D)
return self.out_proj(out)
class SwiGLUFFN(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, dim, n_heads, hidden_dim):
super().__init__()
self.attn_norm = RMSNorm(dim)
self.attn = FusedAttention(dim, n_heads)
self.ffn_norm = RMSNorm(dim)
self.ffn = SwiGLUFFN(dim, hidden_dim)
def forward(self, x, freqs_cis, mask=None):
x = x + self.attn(self.attn_norm(x), freqs_cis, mask)
x = x + self.ffn(self.ffn_norm(x))
return x
class MultilingualGPT(nn.Module):
def __init__(self):
super().__init__()
self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
self.layers = nn.ModuleList([
TransformerBlock(DIM, N_HEADS, HIDDEN_DIM) for _ in range(DEPTH)
])
self.norm = RMSNorm(DIM)
self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
# Tied embeddings
self.head.weight = self.tok_emb.weight
# Precompute RoPE
self.register_buffer('freqs_cis', precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN))
def forward(self, tokens, targets=None):
B, S = tokens.shape
x = self.tok_emb(tokens)
mask = torch.triu(torch.full((S, S), float('-inf'), device=tokens.device), diagonal=1)
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, S, S)
for layer in self.layers:
x = layer(x, self.freqs_cis, mask)
x = self.norm(x)
logits = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
return logits, loss
def load_model(path, device='cuda'):
"""Load model from checkpoint, stripping prefixes."""
model = MultilingualGPT()
ckpt = torch.load(path, map_location='cpu', weights_only=False)
state = ckpt.get('model_state_dict', ckpt)
# Strip prefixes
cleaned = {}
for k, v in state.items():
new_k = k
for prefix in ['_orig_mod.', 'module.']:
if new_k.startswith(prefix):
new_k = new_k[len(prefix):]
cleaned[new_k] = v
# Handle tied weights - remove head.weight if present (will be tied)
if 'head.weight' in cleaned and 'tok_emb.weight' in cleaned:
if torch.equal(cleaned['head.weight'], cleaned['tok_emb.weight']):
del cleaned['head.weight']
model.load_state_dict(cleaned, strict=False)
model = model.to(device).eval()
return model
def load_tokenizer(path):
"""Load SentencePiece tokenizer."""
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.Load(path)
return sp
|