import torch

from model_neo import NeoMiniConfig, NeoMini


def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=16384):
    """Extend the model's context window from 2048 to `new_max_len` tokens."""
    print(f"Extending context window to {new_max_len} tokens...")
    # Build a fresh model whose config allows the longer sequence length.
    config = NeoMiniConfig()
    config.max_seq_len = new_max_len
    extended_model = NeoMini(config)

    # Load the original weights and grab the new model's (randomly initialised) state dict.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']
    extended_state = extended_model.state_dict()
| | for key in original_state:
|
| | if key in extended_state:
|
| | if 'pos' in key and extended_state[key].shape != original_state[key].shape:
|
| |
|
| | print(f"Interpolating position embeddings: {key}")
|
| | old_pos_emb = original_state[key]
|
| | new_pos_emb = torch.nn.functional.interpolate(
|
| | old_pos_emb.unsqueeze(0).unsqueeze(0),
|
| | size=(new_max_len, old_pos_emb.shape[-1]),
|
| | mode='linear'
|
| | ).squeeze(0).squeeze(0)
|
| | extended_state[key] = new_pos_emb
|
| | else:
|
| | extended_state[key] = original_state[key]
|
| |
|
    extended_model.load_state_dict(extended_state)

    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict()
    }

    output_path = "checkpoints/extended_context_model.pt"
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")

    return extended_model, config

if __name__ == "__main__":
    extend_model_context()