"""
Load and use compressed models saved by compress_model.py.
"""

import os
import json
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import tensorly as tl
from tltorch.factorized_layers import FactorizedLinear, FactorizedEmbedding

# tltorch factorized layers require TensorLy's PyTorch backend.
tl.set_backend("pytorch")
|
def reconstruct_factorized_layer(layer_info, state_dict_prefix):
    """Reconstruct a factorized layer (with placeholder weights) from saved metadata.

    The actual factors are filled in later when the state dict is loaded;
    ``state_dict_prefix`` is accepted for symmetry with the caller but is not
    used here.
    """
    layer_type = layer_info["type"]

    factorization = layer_info.get("factorization", "cp")
    rank = layer_info.get("rank", 4)

    if layer_type == "FactorizedLinear":
        in_features = layer_info.get("in_features")
        out_features = layer_info.get("out_features")

        if in_features is None or out_features is None:
            raise ValueError("Missing in_features or out_features for FactorizedLinear layer")

        # Build a dense linear layer with the original shape, then factorize it
        # to obtain a layer with the same structure as the compressed one.
        linear = nn.Linear(in_features, out_features, bias=layer_info.get("bias", True))

        layer = FactorizedLinear.from_linear(
            linear,
            rank=rank,
            factorization=factorization.upper(),
            implementation='reconstructed'
        )

    elif layer_type == "FactorizedEmbedding":
        num_embeddings = layer_info.get("num_embeddings")
        embedding_dim = layer_info.get("embedding_dim")

        if num_embeddings is None or embedding_dim is None:
            raise ValueError("Missing num_embeddings or embedding_dim for FactorizedEmbedding layer")

        # Same idea for embeddings: create a dense table of the original size,
        # then wrap it in a factorized embedding of the requested rank.
        embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=layer_info.get("padding_idx", None),
            max_norm=layer_info.get("max_norm", None),
            norm_type=layer_info.get("norm_type", 2.0),
            scale_grad_by_freq=layer_info.get("scale_grad_by_freq", False),
            sparse=layer_info.get("sparse", False)
        )

        layer = FactorizedEmbedding.from_embedding(
            embedding,
            rank=rank,
            factorization=factorization
        )

    else:
        raise ValueError(f"Unknown factorized layer type: {layer_type}")

    return layer
|
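
# Minimal sketch of reconstruct_factorized_layer() on a hand-written metadata
# entry. The exact schema of factorization_info.json is defined by
# compress_model.py; the keys below simply mirror the ones read above, and the
# concrete values (768 features, rank 4) are assumptions for illustration only.
def _demo_reconstruct_linear():
    """Rebuild a single factorized linear layer from an example metadata entry."""
    layer_info = {
        "type": "FactorizedLinear",
        "factorization": "cp",
        "rank": 4,
        "in_features": 768,
        "out_features": 768,
        "bias": True,
    }
    layer = reconstruct_factorized_layer(layer_info, "encoder.layer.0.output.dense")
    print(layer)
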
def set_module_by_path(model, path, new_module):
    """Set a module in the model by its dotted path (e.g. "encoder.layer.0.output.dense")."""
    parts = path.split('.')
    parent = model

    # Walk down to the parent of the target module...
    for part in parts[:-1]:
        parent = getattr(parent, part)

    # ...and swap in the replacement.
    setattr(parent, parts[-1], new_module)
|
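
# Small self-contained illustration of set_module_by_path(). The toy model and
# the path "1.1" are made up for the demo and have nothing to do with the
# compressed checkpoints.
def _demo_set_module_by_path():
    """Replace a nested submodule in a toy model by its dotted path."""
    toy = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.ReLU(), nn.Linear(8, 2)))
    set_module_by_path(toy, "1.1", nn.Linear(8, 4))  # swaps the inner Linear(8, 2)
    print(toy)
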
def load_compressed_model(load_dir: str, device="cpu"):
    """Load a compressed model's saved artifacts.

    Returns a dict with the state dict, the factorization metadata, the
    sentence-encoder flag, and the original model name.
    """
    # The factorization metadata describes which layers were replaced and how.
    factorization_info_path = os.path.join(load_dir, "factorization_info.json")
    if not os.path.exists(factorization_info_path):
        raise FileNotFoundError(f"No factorization_info.json found in {load_dir}")

    with open(factorization_info_path, "r") as f:
        factorized_info = json.load(f)

    # The weights may be stored under either filename, depending on how they were saved.
    checkpoint_path = os.path.join(load_dir, "pytorch_model.bin")
    if not os.path.exists(checkpoint_path):
        checkpoint_path = os.path.join(load_dir, "model_state.pt")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"No model checkpoint found in {load_dir}")

    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Checkpoints may wrap the state dict together with some metadata;
    # otherwise treat the file as a bare state dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        is_sentence_encoder = checkpoint.get("is_sentence_encoder", False)
        model_name = checkpoint.get("model_name", "unknown")
    else:
        state_dict = checkpoint
        is_sentence_encoder = False
        model_name = "unknown"

    print(f"Loading compressed model (sentence_encoder={is_sentence_encoder})")

    if is_sentence_encoder:
        print("Note: Loading sentence encoders requires the original model architecture.")
        print("The compressed weights will be loaded, but the model structure needs to be reconstructed manually.")
    else:
        print("Note: Loading compressed models requires knowing the original model architecture.")

    return {
        "state_dict": state_dict,
        "factorized_info": factorized_info,
        "is_sentence_encoder": is_sentence_encoder,
        "model_name": model_name,
    }
|
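
# Minimal usage sketch for load_compressed_model(): inspect the artifacts
# without rebuilding the model. The default directory name mirrors
# example_usage() below and should be changed to wherever compress_model.py
# wrote its output.
def inspect_compressed_artifacts(load_dir="coderank_compressed"):
    """Print a summary of the compressed layers described in the saved metadata."""
    artifacts = load_compressed_model(load_dir)
    for layer_path, layer_info in artifacts["factorized_info"].items():
        print(f"{layer_path}: {layer_info.get('type')} "
              f"(factorization={layer_info.get('factorization', 'cp')}, rank={layer_info.get('rank', 4)})")
    print(f"State dict tensors: {len(artifacts['state_dict'])}")
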
def load_compressed_sentence_transformer(original_model_name: str, compressed_dir: str, device="cpu"):
    """
    Load a compressed SentenceTransformer model.

    Args:
        original_model_name: Name of the original model (e.g., "nomic-ai/CodeRankEmbed")
        compressed_dir: Directory containing the compressed model
        device: Device to load the model on

    Returns:
        Compressed SentenceTransformer model
    """
    # Start from the original (uncompressed) architecture.
    model = SentenceTransformer(original_model_name, device=device, trust_remote_code=True)

    artifacts = load_compressed_model(compressed_dir, device)

    if not artifacts.get("is_sentence_encoder"):
        raise ValueError("The compressed model is not a sentence encoder")

    state_dict = artifacts["state_dict"]
    factorized_info = artifacts["factorized_info"]

    # Swap each compressed layer into the model so its structure matches the
    # saved (factorized) state dict.
    for layer_path, layer_info in factorized_info.items():
        factorized_layer = reconstruct_factorized_layer(layer_info, layer_path)
        set_module_by_path(model, layer_path, factorized_layer)

    # strict=False: keys present in only one of the two (e.g. leftovers from the
    # original dense layers) are ignored rather than raising an error.
    model.load_state_dict(state_dict, strict=False)

    return model
|
def example_usage():
    """Example of how to use the compressed model loader."""
    compressed_dir = "coderank_compressed"
    original_model = "nomic-ai/CodeRankEmbed"

    print(f"Loading compressed model from {compressed_dir}")

    try:
        model = load_compressed_sentence_transformer(
            original_model_name=original_model,
            compressed_dir=compressed_dir,
            device="cpu"
        )

        # Quick sanity check: encode a couple of code snippets.
        sentences = ["def hello_world():\n    print('Hello, World!')", "System.out.println(\"Hello, World!\");"]
        embeddings = model.encode(sentences)

        print("✔ Successfully loaded compressed model")
        print(f"  Embedding shape: {embeddings.shape}")

    except Exception as e:
        # manual_load_sketch() below scripts these three steps.
        print(f"⚠ Error loading compressed model: {e}")
        print("\nTo manually load the compressed model:")
        print("1. Load factorization_info.json to see the compressed layer structure")
        print("2. Reconstruct the model with factorized layers based on the metadata")
        print("3. Load the state dict from pytorch_model.bin")
|
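
# A scripted version of the manual-loading steps printed above. This is a
# minimal sketch: it assumes the artifact names used elsewhere in this file
# (factorization_info.json and pytorch_model.bin) and only inspects the files,
# leaving full model reconstruction to load_compressed_sentence_transformer().
def manual_load_sketch(compressed_dir="coderank_compressed"):
    """Walk through the three manual-loading steps without building a model."""
    # 1. Load the factorization metadata to see the compressed layer structure.
    with open(os.path.join(compressed_dir, "factorization_info.json")) as f:
        factorized_info = json.load(f)
    print(f"Compressed layers: {list(factorized_info.keys())}")

    # 2. Reconstruct factorized layers from the metadata (one per entry).
    layers = {path: reconstruct_factorized_layer(info, path) for path, info in factorized_info.items()}
    print(f"Reconstructed {len(layers)} factorized layers")

    # 3. Load the state dict; its keys should line up with the factorized layers.
    state_dict = torch.load(os.path.join(compressed_dir, "pytorch_model.bin"), map_location="cpu")
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    print(f"State dict contains {len(state_dict)} tensors")
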
if __name__ == "__main__":
    example_usage()