| import gc |
| import os |
| import re |
| from typing import Union, List |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModelForCausalLM |
| from accelerate import init_empty_weights, load_checkpoint_and_dispatch |
| from awq.quantize.quantizer import real_quantize_model_weight |
| from awq.quantize.qmodule import WQLinear |
| from tqdm import tqdm |
|
|
| import tinychat.utils.constants |
|
|
# Warning printed by ckpt_version_check when a checkpoint path does not end in
# "v2.pt".
version_message = """
[Warning] The awq quantized checkpoint seems to be in v1 format.
If the model cannot be loaded successfully, please use the latest awq library to re-quantized the model, or repack the current checkpoint with tinychat/offline-weight-repacker.py
"""


def ckpt_version_check(quant_path):
    """Print ``version_message`` unless ``quant_path`` ends with ``"v2.pt"``.

    The check looks only at the path's suffix; the file itself is never opened.

    Args:
        quant_path: Checkpoint path, as a ``str`` or any ``os.PathLike``.
    """
    # os.fspath turns pathlib.Path (and other PathLike objects) into a plain
    # string, so the suffix test below works for them as well as for str.
    if not os.fspath(quant_path).endswith("v2.pt"):
        print(version_message)
|
|
|
|
def mem_efficient_load_checkpoint(
    model: nn.Module,
    ckpts_folder: Union[str, os.PathLike],
):
    """Load a sharded checkpoint into ``model`` one ``.pt`` file at a time.

    The folder must contain exactly one ``<key>.pt`` file for every key of
    ``model.state_dict()``. Each file is loaded on CPU, applied with
    ``load_state_dict(strict=False)`` and then dropped (``del`` +
    ``gc.collect()``) before the next file is read.

    Args:
        model: Module whose state dict is filled in place.
        ckpts_folder: Directory containing the ``.pt`` shard files.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AssertionError: If the shard file names (without ``.pt``) do not match
            the model's state-dict keys one-to-one.
    """
    # List the folder once so the validation and the load loop work on the
    # same set of files.
    shard_names = [f for f in os.listdir(ckpts_folder) if f.endswith(".pt")]
    # os.path.join accepts os.PathLike as well as str, unlike "+" concatenation.
    checkpoint_files = [os.path.join(ckpts_folder, f) for f in shard_names]

    model_keys = sorted(model.state_dict())
    # Every name is known to end with ".pt", so slicing strips the suffix.
    ckpt_keys = sorted(f[: -len(".pt")] for f in shard_names)
    assert len(model_keys) == len(
        ckpt_keys
    ), f"The number of checkpoint files do not match the model. \n Model has {len(model_keys)} keys, while finding {len(ckpt_keys)} checkpoint files in the folder."
    for key1, key2 in zip(model_keys, ckpt_keys):
        assert (
            key1 == key2
        ), f"The checkpoint files do not match the model. \nmodel key {key1} != checkpoint key {key2}"

    with tqdm(total=len(checkpoint_files)) as pbar:
        pbar.set_description("Loading checkpoint shards")
        for checkpoint_file in checkpoint_files:
            # NOTE: torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
            # strict=False because each shard carries only part of the state dict.
            model.load_state_dict(checkpoint, strict=False)
            # Release the shard before reading the next one.
            del checkpoint
            gc.collect()
            pbar.update(1)
    return model
|
|
|
|
def load_awq_model(model, checkpoint, w_bit, group_size, device):
    """Prepare ``model`` for AWQ weights, load ``checkpoint`` and move to ``device``.

    Steps, in order:
      1. Call ``real_quantize_model_weight(..., init_only=True)`` with a
         zero-point config using ``group_size`` (presumably this installs the
         quantized layer placeholders without quantizing - confirm in awq).
      2. Set ``tie_encoder_decoder`` / ``tie_word_embeddings`` to ``False`` on
         ``model.config`` when those attributes exist.
      3. Load the weights, either shard by shard from a folder (when
         ``tinychat.utils.constants.mem_efficient_load`` is truthy) or through
         ``load_checkpoint_and_dispatch``.

    Args:
        model: Model to populate.
        checkpoint: Checkpoint file path, or a folder of shards in
            memory-efficient mode.
        w_bit: Weight bit width forwarded to ``real_quantize_model_weight``.
        group_size: Stored as ``q_group_size`` in the quantization config.
        device: Target passed to ``.to(device)`` on the loaded model.

    Returns:
        The loaded model after ``.to(device)``.

    Raises:
        AssertionError: In memory-efficient mode when ``checkpoint`` is not a
            directory.
    """
    real_quantize_model_weight(
        model,
        w_bit,
        {"zero_point": True, "q_group_size": group_size},
        init_only=True,
    )

    # Switch off weight tying on the config wherever the flag is present.
    for tie_flag in ("tie_encoder_decoder", "tie_word_embeddings"):
        if hasattr(model.config, tie_flag):
            setattr(model.config, tie_flag, False)

    if tinychat.utils.constants.mem_efficient_load:
        assert os.path.isdir(
            checkpoint
        ), "You are in mem_efficient_load mode. \n Please set --load_quant the path to the folder containing all checkpoint files."
        loaded = mem_efficient_load_checkpoint(model, checkpoint)
        return loaded.to(device)

    ckpt_version_check(checkpoint)
    # Module classes handed to accelerate as no_split_module_classes.
    unsplittable_blocks = [
        "OPTDecoderLayer",
        "LlamaDecoderLayer",
        "BloomBlock",
        "MPTBlock",
        "DecoderLayer",
        "CLIPEncoderLayer",
    ]
    # A one-step progress bar: it only frames the single load call.
    progress = tqdm(range(1))
    progress.set_description("Loading checkpoint")
    for _ in progress:
        dispatched = load_checkpoint_and_dispatch(
            model,
            checkpoint,
            no_split_module_classes=unsplittable_blocks,
        )
        model = dispatched.to(device)
    return model
|
|
|
|
def make_quant_linear(module, names, w_bit, groupsize, device, name=""):
    """Recursively replace the sub-modules listed in ``names`` with ``WQLinear``.

    ``module`` is modified in place. A child whose dotted path (relative to the
    root call) is in ``names`` is swapped for a ``WQLinear`` built from its
    ``in_features``, ``out_features``, bias presence and weight dtype. Every
    other child is descended into.

    Children are enumerated with ``named_children()`` rather than ``dir()``.
    ``nn.Module.__dir__`` hides entries whose name starts with a digit, so a
    layer sitting directly inside an ``nn.Sequential`` or ``nn.ModuleList``
    would never be visited. This also avoids evaluating every property of the
    module through ``getattr``.

    Args:
        module: Root module to rewrite.
        names: Container of dotted sub-module names to replace; only
            membership tests (``in``) are performed on it.
        w_bit: Forwarded to ``WQLinear`` as its first argument.
        groupsize: Forwarded to ``WQLinear`` as its second argument.
        device: Forwarded to ``WQLinear``.
        name: Dotted path of ``module`` itself; empty for the root call.
    """
    if isinstance(module, WQLinear):
        return
    # Snapshot the children first: the loop reassigns entries on ``module``.
    for child_name, child in list(module.named_children()):
        full_name = name + "." + child_name if name != "" else child_name
        if full_name in names:
            # Plain setattr overwrites the registered child where it stands, so
            # the order of children (which nn.Sequential depends on) is kept.
            setattr(
                module,
                child_name,
                WQLinear(
                    w_bit,
                    groupsize,
                    child.in_features,
                    child.out_features,
                    child.bias is not None,
                    device,
                    dtype=child.weight.dtype,
                ),
            )
        else:
            make_quant_linear(
                child,
                names,
                w_bit,
                groupsize,
                device,
                full_name,
            )
|
|
|
|
def find_layers(module, layers=(nn.Linear,), name=""):
    """Collect sub-modules whose exact type is listed in ``layers``.

    The search is depth-first and stops at the first match on each branch: a
    matching module is returned as-is and its own children are not inspected.
    Matching uses ``type(module) in layers``, so subclasses of a listed type
    are not matched.

    Args:
        module: Root module to search.
        layers: Collection of module classes to look for. The default is an
            immutable tuple so it cannot be altered between calls.
        name: Dotted path of ``module`` itself; empty for the root call.

    Returns:
        A dict mapping each match's dotted path (relative to the root call) to
        the module object. When the root itself matches, the result is
        ``{name: module}``.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = name + "." + child_name if name != "" else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
|
|
|
|
def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
    """Swap the model's linear layers for ``WQLinear`` and load AWQ weights.

    Every ``nn.Linear`` found by ``find_layers`` except ``lm_head`` is replaced
    through ``make_quant_linear``. The weights are then loaded either shard by
    shard from a folder (when ``tinychat.utils.constants.mem_efficient_load``
    is truthy) or from a single ``.safetensors`` / ``torch.save`` file.

    Args:
        model: Model to rewrite and populate in place.
        checkpoint: Checkpoint file path, or a folder of shards in
            memory-efficient mode.
        w_bit: Forwarded to ``make_quant_linear``.
        group_size: Forwarded to ``make_quant_linear``.
        device: Device for the new layers and for the final ``model.to``.

    Returns:
        The loaded model after ``.to(device)``.

    Raises:
        AssertionError: In memory-efficient mode when ``checkpoint`` is not a
            directory.
    """
    layers = find_layers(model)
    # The output head is left as a regular linear layer.
    layers.pop("lm_head", None)
    make_quant_linear(model, layers, w_bit, group_size, device)
    # Drop the remaining references to the replaced nn.Linear modules.
    del layers

    if tinychat.utils.constants.mem_efficient_load:
        assert os.path.isdir(
            checkpoint
        ), "You are in mem_efficient_load mode. \n Please set --load_quant the path to the folder containing all checkpoint files."
        model = mem_efficient_load_checkpoint(
            model,
            checkpoint,
        )
    else:
        ckpt_version_check(checkpoint)
        # A one-step progress bar: it only frames the single load call.
        pbar = tqdm(range(1))
        pbar.set_description("Loading checkpoint")
        for _ in pbar:
            if checkpoint.endswith(".safetensors"):
                from safetensors.torch import load_file as safe_load

                state_dict = safe_load(checkpoint)
            else:
                # Load on CPU rather than on whatever device the tensors were
                # saved from: that device may not exist on this machine, and
                # the model is moved to ``device`` below anyway.
                # NOTE: torch.load unpickles the file; only load checkpoints
                # from a trusted source.
                state_dict = torch.load(
                    checkpoint, map_location=torch.device("cpu")
                )
            model.load_state_dict(state_dict)

    return model.to(device)
|
|