quantization / llm-awq /tinychat /utils /load_quant.py
chen459664's picture
Add files using upload-large-folder tool
8c6338f verified
import gc
import os
import re
from typing import Union, List
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from awq.quantize.qmodule import WQLinear
from tqdm import tqdm
import tinychat.utils.constants
# Warning printed when a checkpoint path does not carry the "v2.pt" suffix.
version_message = """
[Warning] The awq quantized checkpoint seems to be in v1 format.
If the model cannot be loaded successfully, please use the latest awq library to re-quantized the model, or repack the current checkpoint with tinychat/offline-weight-repacker.py
"""


def ckpt_version_check(quant_path):
    """Print a format warning unless ``quant_path`` ends with ``"v2.pt"``.

    The check is purely name based: the file itself is never opened.

    Args:
        quant_path: Checkpoint location as a ``str`` or any ``os.PathLike``
            object (e.g. ``pathlib.Path``).
    """
    # os.fspath lets path objects through; a bare ``.endswith`` call would
    # raise AttributeError on them.
    if not os.fspath(quant_path).endswith("v2.pt"):
        print(version_message)
def mem_efficient_load_checkpoint(
    model: nn.Module,
    ckpts_folder: Union[str, os.PathLike],
):
    """Load a checkpoint that is split into one ``.pt`` file per state-dict key.

    Every ``*.pt`` file in ``ckpts_folder`` must be named ``<key>.pt`` where
    ``<key>`` is a key of ``model.state_dict()``, and every key must have a
    file. Shards are loaded one at a time on the CPU and released before the
    next one is read.

    Args:
        model: Module whose state dict is filled in place.
        ckpts_folder: Directory holding the shard files (``str`` or
            ``os.PathLike``).

    Returns:
        The same ``model`` instance, after all shards were loaded.

    Raises:
        AssertionError: If the set of shard file names does not match the
            model's state-dict keys.
    """
    suffix = ".pt"
    folder = os.fspath(ckpts_folder)
    # List the folder once so validation and loading see the same files;
    # sorting makes the load order deterministic.
    shard_names = sorted(f for f in os.listdir(folder) if f.endswith(suffix))
    checkpoint_files = [os.path.join(folder, f) for f in shard_names]

    # Check if the ckpts match the model
    model_keys = sorted(model.state_dict().keys())
    # Sort after stripping the suffix: "a.pt" and "a.b.pt" order differently
    # from "a" and "a.b".
    ckpt_keys = sorted(f[: -len(suffix)] for f in shard_names)
    assert len(model_keys) == len(
        ckpt_keys
    ), f"The number of checkpoint files do not match the model. \n Model has {len(model_keys)} keys, while finding {len(ckpt_keys)} checkpoint files in the folder."
    for key1, key2 in zip(model_keys, ckpt_keys):
        assert (
            key1 == key2
        ), f"The checkpoint files do not match the model. \nmodel key {key1} != checkpoint key {key2}"

    with tqdm(total=len(checkpoint_files)) as pbar:
        pbar.set_description("Loading checkpoint shards")
        for checkpoint_file in checkpoint_files:
            # NOTE(review): torch.load relies on pickle-based deserialization;
            # only point this at checkpoint folders you trust.
            checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
            # strict=False because each shard is passed to load_state_dict
            # on its own rather than as one complete state dict.
            model.load_state_dict(checkpoint, strict=False)
            # Force Python to clean up.
            del checkpoint
            gc.collect()
            pbar.update(1)
    return model
def load_awq_model(model, checkpoint, w_bit, group_size, device):
    """Prepare ``model`` for AWQ weights and load them from ``checkpoint``.

    Steps, in order:
      1. Call ``real_quantize_model_weight`` with ``init_only=True`` and a
         zero-point config using ``group_size``.
      2. Switch off ``tie_encoder_decoder`` / ``tie_word_embeddings`` on
         ``model.config`` when those attributes exist.
      3. Load the weights, either shard by shard (when
         ``tinychat.utils.constants.mem_efficient_load`` is set, in which
         case ``checkpoint`` must be a directory) or through
         ``load_checkpoint_and_dispatch``.
      4. Move the result to ``device``.

    Args:
        model: Model to quantize and populate.
        checkpoint: Checkpoint file, or folder in mem-efficient mode.
        w_bit: Weight bit-width forwarded to ``real_quantize_model_weight``.
        group_size: Quantization group size (``q_group_size``).
        device: Target device for the returned model.

    Returns:
        The loaded model on ``device``.
    """
    real_quantize_model_weight(
        model,
        w_bit,
        {"zero_point": True, "q_group_size": group_size},
        init_only=True,
    )

    # Weight tying is disabled wherever the config exposes the switch.
    for tie_flag in ("tie_encoder_decoder", "tie_word_embeddings"):
        if hasattr(model.config, tie_flag):
            setattr(model.config, tie_flag, False)

    if tinychat.utils.constants.mem_efficient_load:
        assert os.path.isdir(
            checkpoint
        ), "You are in mem_efficient_load mode. \n Please set --load_quant the path to the folder containing all checkpoint files."
        loaded = mem_efficient_load_checkpoint(model, checkpoint)
        return loaded.to(device)

    ckpt_version_check(checkpoint)
    # Layer classes handed to accelerate as ``no_split_module_classes``.
    unsplittable = [
        "OPTDecoderLayer",
        "LlamaDecoderLayer",
        "BloomBlock",
        "MPTBlock",
        "DecoderLayer",
        "CLIPEncoderLayer",
    ]
    # Single-step progress bar, used only for the console description.
    progress = tqdm(range(1))
    progress.set_description("Loading checkpoint")
    for _ in progress:
        dispatched = load_checkpoint_and_dispatch(
            model,
            checkpoint,
            no_split_module_classes=unsplittable,
        )
        model = dispatched.to(device)
    return model
def make_quant_linear(module, names, w_bit, groupsize, device, name=""):
    """Recursively replace the attributes listed in ``names`` with ``WQLinear``.

    For every attribute of ``module`` whose dotted path (``name`` + attribute
    name) is contained in ``names``, a new ``WQLinear`` is built from the
    original attribute's ``in_features``, ``out_features``, bias presence and
    weight dtype, and installed in its place. The function then descends
    into all child modules. Modules that already are ``WQLinear`` are left
    untouched.

    Args:
        module: Module to rewrite in place.
        names: Container of dotted attribute paths to replace (only ``in``
            is used on it).
        w_bit: Weight bit-width passed to ``WQLinear``.
        groupsize: Quantization group size passed to ``WQLinear``.
        device: Device passed to ``WQLinear``.
        name: Dotted path of ``module`` relative to the root; empty for the
            root itself.
    """
    if isinstance(module, WQLinear):
        return

    prefix = name + "." if name != "" else ""
    for attr in dir(module):
        if prefix + attr not in names:
            continue
        # Resolve the attribute only for actual targets: calling getattr on
        # every dir() entry would execute unrelated properties/descriptors,
        # which may be slow or raise.
        original = getattr(module, attr)
        delattr(module, attr)
        setattr(
            module,
            attr,
            WQLinear(
                w_bit,
                groupsize,
                original.in_features,
                original.out_features,
                original.bias is not None,
                device,
                dtype=original.weight.dtype,
            ),
        )

    for child_name, child in module.named_children():
        make_quant_linear(
            child,
            names,
            w_bit,
            groupsize,
            device,
            prefix + child_name,
        )
def find_layers(module, layers=(nn.Linear,), name=""):
    """Collect sub-modules whose exact type is listed in ``layers``.

    The match uses ``type(module) in layers``, so subclasses of a listed
    type are deliberately NOT matched. A matching module is returned as a
    leaf; its own children are not searched.

    Args:
        module: Root module to search.
        layers: Collection of module classes to look for. The default is an
            immutable tuple so it cannot be mutated across calls.
        name: Dotted path of ``module`` relative to the search root; empty
            for the root itself.

    Returns:
        Dict mapping dotted module paths to the matching modules.
    """
    if type(module) in layers:
        return {name: module}

    prefix = name + "." if name != "" else ""
    res = {}
    for child_name, child in module.named_children():
        res.update(find_layers(child, layers=layers, name=prefix + child_name))
    return res
def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
    """Swap the model's linear layers for ``WQLinear`` and load AWQ weights.

    All layers returned by ``find_layers(model)`` except ``"lm_head"`` are
    replaced through ``make_quant_linear``. Weights are then loaded either
    shard by shard (when ``tinychat.utils.constants.mem_efficient_load`` is
    set, in which case ``checkpoint`` must be a directory) or from a single
    ``.safetensors`` / ``torch.save`` file.

    Args:
        model: Model to rewrite and populate in place.
        checkpoint: Path string of the checkpoint file, or of the shard
            folder in mem-efficient mode.
        w_bit: Weight bit-width passed to ``WQLinear``.
        group_size: Quantization group size passed to ``WQLinear``.
        device: Device for the new layers and for the returned model.

    Returns:
        The loaded model moved to ``device``.
    """
    layers = find_layers(model)
    # "lm_head" is dropped from the target set so it is not replaced by a
    # WQLinear.
    layers.pop("lm_head", None)
    make_quant_linear(model, layers, w_bit, group_size, device)
    del layers

    if tinychat.utils.constants.mem_efficient_load:
        # TODO: mem-efficient load for llama
        assert os.path.isdir(
            checkpoint
        ), "You are in mem_efficient_load mode. \n Please set --load_quant the path to the folder containing all checkpoint files."
        model = mem_efficient_load_checkpoint(
            model,
            checkpoint,
        )
    else:
        ckpt_version_check(checkpoint)
        # Single-step progress bar, used only for the console description.
        pbar = tqdm(range(1))
        pbar.set_description("Loading checkpoint")
        for _ in pbar:
            if checkpoint.endswith(".safetensors"):
                from safetensors.torch import load_file as safe_load

                state_dict = safe_load(checkpoint)
            else:
                # Stage tensors on the CPU (as the sharded loader does) so
                # loading does not depend on the device the checkpoint was
                # saved from; load_state_dict copies into the model's own
                # parameters afterwards.
                # NOTE(review): torch.load relies on pickle-based
                # deserialization; only load checkpoints you trust.
                state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
            model.load_state_dict(state_dict)
    return model.to(device)