| |
|
|
| import argparse |
| import os |
| import torch |
| from transformers import AutoConfig, AutoModelForCausalLM |
|
|
def main():
    """Inspect a fine-tuned checkpoint and report its small-expert parameters.

    Parses ``--model_path`` (required) and ``--custom_model_path`` (optional,
    currently only accepted, not used in this block), loads the model config
    and weights from the checkpoint, then prints every parameter whose name
    contains ``'small_experts.'`` together with its shape.

    Raises:
        ValueError: if the loaded config has no ``num_small_experts`` attribute,
            i.e. the checkpoint is not a small-expert model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)

    # Guard clause: fail fast when the checkpoint is not a small-expert model.
    if not hasattr(config, "num_small_experts"):
        raise ValueError("The model config does not contain 'num_small_experts'.")
    num_small_experts = config.num_small_experts

    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bfloat16 halves memory vs. float32 — sufficient for inspection-only use.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    total_params = 0
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Plain string literal: the original used an f-string with no placeholder.
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")

    print("Done.")
|
|
| if __name__ == "__main__": |
| main() |
|
|