| import gc |
| import shutil |
| from pathlib import Path |
| from typing import Dict |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| """ |
| Sample usage: |
| |
| ```bash |
| python -m scripts.convert_checkpoint -h |
| |
| python -m scripts.convert_checkpoint converted |
| ``` |
| """ |
|
|
|
|
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
    """Map a Meta/LLaMA checkpoint ``state_dict`` onto lit-llama parameter names.

    Args:
        state_dict: the original parameter mapping (``tok_embeddings.weight``,
            ``output.weight``, ``norm.weight`` and per-layer ``layers.N.*`` keys).
        dtype: target dtype every tensor is cast to.

    Returns:
        A new dict keyed by lit-llama names (``transformer.wte.weight``,
        ``lm_head.weight``, ``transformer.h.N.*`` ...).
    """
    converted: Dict[str, torch.Tensor] = {}
    converted["transformer.wte.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
    converted["lm_head.weight"] = state_dict["output.weight"].to(dtype)
    converted["transformer.ln_f.scale"] = state_dict["norm.weight"].to(dtype)

    # Layer indices appear as strings in the keys ("layers.<idx>...."); sort them
    # numerically so insertion order is 0, 1, 2, ... even past layer 9
    # (lexicographic sort would give "0", "1", "10", "11", ..., "2").
    layer_indices = sorted({int(k.split(".")[1]) for k in state_dict if k.startswith("layers")})
    for layer_idx in layer_indices:
        # The separate q/k/v projections are fused into a single c_attn matrix,
        # stacked along dim 0 in q, k, v order.
        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
            (
                state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
                state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
                state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
            )
        )
        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = state_dict[
            f"layers.{layer_idx}.attention.wo.weight"
        ].to(dtype)

        # SwiGLU feed-forward: w1 -> c_fc1, w3 -> c_fc2, w2 -> c_proj.
        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w1.weight"
        ].to(dtype)
        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w2.weight"
        ].to(dtype)
        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = state_dict[
            f"layers.{layer_idx}.feed_forward.w3.weight"
        ].to(dtype)

        # RMSNorm scales.
        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = state_dict[f"layers.{layer_idx}.attention_norm.weight"].to(dtype)
        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = state_dict[f"layers.{layer_idx}.ffn_norm.weight"].to(dtype)
    return converted
|
|
|
|
# Dimension along which each parameter family is concatenated when the
# per-file shards are merged; names not matching any key here are assumed
# replicated in every shard and are kept from the first file only.
shard_dims = {
    # embeddings / output head
    "lm_head.weight": 0,
    "wte.weight": 1,
    # attention
    "attn.c_attn.weight": 0,
    "attn.c_proj.weight": 1,
    # feed-forward
    "mlp.c_fc1.weight": 0,
    "mlp.c_fc2.weight": 0,
    "mlp.c_proj.weight": 1,
}
|
|
|
|
def meta_weights_for_nano_model(
    *,
    output_dir: Path = Path("checkpoints/lit-llama"),
    ckpt_dir: Path = Path("checkpoints/llama/"),
    model_size: str = "7B",
    dtype: str = "float32",
) -> None:
    """Merge Meta's sharded LLaMA checkpoint files into one lit-llama checkpoint.

    Loads every ``*.pth`` shard under ``ckpt_dir/model_size`` one at a time,
    renames the parameters via :func:`convert_state_dict`, concatenates sharded
    parameters along the dims declared in ``shard_dims``, reorders the fused
    qkv matrices, and saves the result to ``output_dir/model_size/lit-llama.pth``.

    Args:
        output_dir: root directory the converted checkpoint is written under.
        ckpt_dir: root directory holding ``<model_size>/consolidated.*.pth``.
        model_size: model sub-directory name, e.g. ``"7B"``.
        dtype: name of a ``torch`` dtype to cast the weights to.

    Raises:
        ValueError: if ``dtype`` does not name a valid ``torch.dtype``.
        RuntimeError: if no ``*.pth`` files are found in the checkpoint dir.
    """
    output_dir = output_dir / model_size
    ckpt_dir = ckpt_dir / model_size
    output_dir.mkdir(parents=True, exist_ok=True)

    # The tokenizer is shared across model sizes, so it lives one level up.
    shutil.copy(ckpt_dir.parent / "tokenizer.model", output_dir.parent)

    # Resolve the dtype *name* into a real torch.dtype without rebinding the
    # string parameter to a different type.
    torch_dtype = getattr(torch, dtype, None)
    if not isinstance(torch_dtype, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")

    # Path.glob yields files in arbitrary order; one sorted() fixes that
    # (the original sorted twice).
    checkpoint_files = sorted(ckpt_dir.glob("*.pth"))
    n_checkpoints = len(checkpoint_files)

    if n_checkpoints == 0:
        raise RuntimeError(f"No checkpoints were found at ckpt_dir {ckpt_dir}. `consolidated.0*.pth` files expected at that location.")

    # To bound peak memory, load one shard at a time and fold it into the
    # running result.  NOTE(review): torch.load unpickles the file — only run
    # this on checkpoints from a trusted source.
    combined = None
    for file in tqdm(checkpoint_files, total=n_checkpoints):
        checkpoint = torch.load(file, map_location="cpu")
        converted = convert_state_dict(checkpoint, dtype=torch_dtype)
        if combined is None:
            combined = converted
            continue
        for name, param in converted.items():
            # Parameters without a shard_dims entry are replicated per shard,
            # so the copy already in `combined` is kept.
            dim = next((d for key, d in shard_dims.items() if key in name), None)
            if dim is None:
                continue
            combined[name] = torch.cat((combined[name], param), dim=dim)

        del checkpoint
        del converted
        gc.collect()

    # The fused qkv matrix was built shard-by-shard as
    # [q0 k0 v0 | q1 k1 v1 | ...]; rearrange it in place to
    # [q0 q1 ... | k0 k1 ... | v0 v1 ...] so it splits into whole q/k/v blocks.
    for name, param in combined.items():
        if "c_attn" not in name:
            continue

        src_chunk_len = param.shape[0] // n_checkpoints
        mat_len = src_chunk_len // 3
        dst_chunk_len = mat_len * n_checkpoints
        attn = torch.clone(param)
        for i in range(n_checkpoints):
            for j in range(3):
                param[j * dst_chunk_len + i * mat_len: j * dst_chunk_len + (i + 1) * mat_len] = \
                    attn[i * src_chunk_len + j * mat_len: i * src_chunk_len + (j + 1) * mat_len]

        del attn
        gc.collect()

    torch.save(combined, output_dir / "lit-llama.pth")
|
|
|
|
if __name__ == "__main__":
    # jsonargparse builds a command-line interface from the keyword-only
    # parameters (and their defaults) of `meta_weights_for_nano_model`.
    from jsonargparse import CLI


    CLI(meta_weights_for_nano_model)
|
|