Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- llm-awq/awq.egg-info/SOURCES.txt +81 -0
- llm-awq/awq.egg-info/dependency_links.txt +1 -0
- llm-awq/awq.egg-info/requires.txt +16 -0
- llm-awq/awq.egg-info/top_level.txt +3 -0
- llm-awq/awq/quantize/qmodule.py +235 -0
- llm-awq/awq/quantize/w8a8_linear.py +276 -0
- llm-awq/awq/utils/lm_eval_adaptor.py +116 -0
- llm-awq/awq/utils/utils.py +51 -0
- llm-awq/examples/convert_to_hf.py +69 -0
- llm-awq/examples/llava_demo.ipynb +0 -0
- llm-awq/figures/vila-logo.jpg +0 -0
- llm-awq/scripts/codellama_example.sh +25 -0
- llm-awq/scripts/llama2_example.sh +25 -0
- llm-awq/scripts/llama3_example.sh +25 -0
- llm-awq/scripts/llama_example.sh +25 -0
- llm-awq/scripts/opt_example.sh +25 -0
- llm-awq/scripts/qwen_example.sh +25 -0
- llm-awq/scripts/starcoder_example.sh +25 -0
- llm-awq/scripts/vicuna_example.sh +25 -0
- llm-awq/tinychat/benchmark.py +379 -0
- llm-awq/tinychat/demo.py +283 -0
- llm-awq/tinychat/internvl_benchmark.py +167 -0
- llm-awq/tinychat/split_ckpt.py +51 -0
- llm-awq/tinychat/vila15_demo.py +264 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_sot.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_yor.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_ibo.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_kin.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sna.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sot.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_xho.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yaml +34 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yor.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_zul.yaml +4 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_ibo.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lin.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lug.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_orm.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sna.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sot.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_swa.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_twi.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_vai.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_wol.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_xho.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_yor.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_amh.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_eng.yaml +7 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_ewe.yaml +6 -0
- lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_fra.yaml +6 -0
llm-awq/awq.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
awq/entry.py
|
| 5 |
+
awq.egg-info/PKG-INFO
|
| 6 |
+
awq.egg-info/SOURCES.txt
|
| 7 |
+
awq.egg-info/dependency_links.txt
|
| 8 |
+
awq.egg-info/requires.txt
|
| 9 |
+
awq.egg-info/top_level.txt
|
| 10 |
+
awq/kernels/setup.py
|
| 11 |
+
awq/kernels/csrc/attention/setup.py
|
| 12 |
+
awq/quantize/__init__.py
|
| 13 |
+
awq/quantize/auto_clip.py
|
| 14 |
+
awq/quantize/auto_scale.py
|
| 15 |
+
awq/quantize/pre_quant.py
|
| 16 |
+
awq/quantize/qmodule.py
|
| 17 |
+
awq/quantize/quantizer.py
|
| 18 |
+
awq/quantize/smooth.py
|
| 19 |
+
awq/quantize/w8a8_linear.py
|
| 20 |
+
awq/utils/__init__.py
|
| 21 |
+
awq/utils/calib_data.py
|
| 22 |
+
awq/utils/lm_eval_adaptor.py
|
| 23 |
+
awq/utils/module.py
|
| 24 |
+
awq/utils/parallel.py
|
| 25 |
+
awq/utils/utils.py
|
| 26 |
+
tinychat/benchmark.py
|
| 27 |
+
tinychat/demo.py
|
| 28 |
+
tinychat/internvl_benchmark.py
|
| 29 |
+
tinychat/internvl_demo.py
|
| 30 |
+
tinychat/nvila_benchmark.py
|
| 31 |
+
tinychat/nvila_demo.py
|
| 32 |
+
tinychat/offline-weight-repacker.py
|
| 33 |
+
tinychat/split_ckpt.py
|
| 34 |
+
tinychat/vila10_demo.py
|
| 35 |
+
tinychat/vila15_demo.py
|
| 36 |
+
tinychat/models/__init__.py
|
| 37 |
+
tinychat/models/falcon.py
|
| 38 |
+
tinychat/models/internvl3.py
|
| 39 |
+
tinychat/models/llama.py
|
| 40 |
+
tinychat/models/llava_llama.py
|
| 41 |
+
tinychat/models/mpt.py
|
| 42 |
+
tinychat/models/nvila_qwen2.py
|
| 43 |
+
tinychat/models/qwen2.py
|
| 44 |
+
tinychat/models/vila_llama.py
|
| 45 |
+
tinychat/models/internvl/configuration_internvl.py
|
| 46 |
+
tinychat/models/internvl/conversation.py
|
| 47 |
+
tinychat/models/internvl/internvit.py
|
| 48 |
+
tinychat/models/internvl/media.py
|
| 49 |
+
tinychat/models/llava_base/llava_arch.py
|
| 50 |
+
tinychat/models/llava_base/multimodal_encoder/builder.py
|
| 51 |
+
tinychat/models/llava_base/multimodal_encoder/clip_encoder.py
|
| 52 |
+
tinychat/models/llava_base/multimodal_projector/builder.py
|
| 53 |
+
tinychat/models/nvila/builder.py
|
| 54 |
+
tinychat/models/nvila/configuration_llava.py
|
| 55 |
+
tinychat/models/nvila/llava_arch.py
|
| 56 |
+
tinychat/modules/__init__.py
|
| 57 |
+
tinychat/modules/fused_attn.py
|
| 58 |
+
tinychat/modules/fused_internencoder.py
|
| 59 |
+
tinychat/modules/fused_mlp.py
|
| 60 |
+
tinychat/modules/fused_norm.py
|
| 61 |
+
tinychat/modules/fused_siglipdecoder.py
|
| 62 |
+
tinychat/modules/fused_vision_attn.py
|
| 63 |
+
tinychat/serve/controller.py
|
| 64 |
+
tinychat/serve/gradio_web_server.py
|
| 65 |
+
tinychat/serve/llava_conv.py
|
| 66 |
+
tinychat/serve/model_worker.py
|
| 67 |
+
tinychat/serve/model_worker_new.py
|
| 68 |
+
tinychat/stream_generators/NVILA_stream_gen.py
|
| 69 |
+
tinychat/stream_generators/__init__.py
|
| 70 |
+
tinychat/stream_generators/internvl_stream_gen.py
|
| 71 |
+
tinychat/stream_generators/llava_stream_gen.py
|
| 72 |
+
tinychat/stream_generators/stream_gen.py
|
| 73 |
+
tinychat/utils/__init__.py
|
| 74 |
+
tinychat/utils/constants.py
|
| 75 |
+
tinychat/utils/conversation_utils.py
|
| 76 |
+
tinychat/utils/input_metadata.py
|
| 77 |
+
tinychat/utils/llava_image_processing.py
|
| 78 |
+
tinychat/utils/load_quant.py
|
| 79 |
+
tinychat/utils/log_utils.py
|
| 80 |
+
tinychat/utils/prompt_templates.py
|
| 81 |
+
tinychat/utils/tune.py
|
llm-awq/awq.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
llm-awq/awq.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.34.2
|
| 2 |
+
sentencepiece
|
| 3 |
+
tokenizers>=0.12.1
|
| 4 |
+
torch==2.3.0
|
| 5 |
+
torchvision==0.18.0
|
| 6 |
+
transformers==4.46.0
|
| 7 |
+
lm_eval==0.3.0
|
| 8 |
+
texttable
|
| 9 |
+
toml
|
| 10 |
+
attributedict
|
| 11 |
+
protobuf
|
| 12 |
+
gradio==3.35.2
|
| 13 |
+
gradio_client==0.2.9
|
| 14 |
+
fastapi
|
| 15 |
+
uvicorn
|
| 16 |
+
pydantic==1.10.19
|
llm-awq/awq.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
awq
|
| 2 |
+
figures
|
| 3 |
+
tinychat
|
llm-awq/awq/quantize/qmodule.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import awq_inference_engine # with CUDA kernels
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def make_divisible(c, divisor):
|
| 8 |
+
return (c + divisor - 1) // divisor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
|
| 12 |
+
if group_size >= 128:
|
| 13 |
+
size_multiplier = 1
|
| 14 |
+
elif group_size == 64:
|
| 15 |
+
size_multiplier = 2
|
| 16 |
+
elif group_size == 32:
|
| 17 |
+
size_multiplier = 4
|
| 18 |
+
else:
|
| 19 |
+
raise NotImplementedError
|
| 20 |
+
|
| 21 |
+
base_width = make_divisible(in_features // group_size, pack_num)
|
| 22 |
+
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
|
| 23 |
+
return base_width
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def pack_intweight(unpacked_qweight, interleave, kstride):
|
| 27 |
+
# unpacked_qweight: [N, K]
|
| 28 |
+
N = unpacked_qweight.shape[0]
|
| 29 |
+
K = unpacked_qweight.shape[1]
|
| 30 |
+
|
| 31 |
+
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
|
| 32 |
+
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
|
| 33 |
+
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
|
| 34 |
+
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
|
| 35 |
+
|
| 36 |
+
# reorder each 8 weights for fast dequantization
|
| 37 |
+
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
|
| 38 |
+
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
|
| 39 |
+
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
|
| 40 |
+
Packed_Kernel = Packed_Kernel.reshape(N, K)
|
| 41 |
+
|
| 42 |
+
# interleaving every four rows
|
| 43 |
+
Packed_Kernel = Packed_Kernel.reshape(
|
| 44 |
+
N // interleave, interleave, K // kstride, kstride
|
| 45 |
+
)
|
| 46 |
+
# N // 4, K // 64, 4, 64
|
| 47 |
+
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
|
| 48 |
+
Packed_Kernel = Packed_Kernel.reshape(
|
| 49 |
+
N // interleave, K // kstride, kstride, interleave
|
| 50 |
+
)
|
| 51 |
+
# Packing -> (N // 4, K // 64, 64)
|
| 52 |
+
Packed_Kernel = (
|
| 53 |
+
Packed_Kernel[..., 0]
|
| 54 |
+
| (Packed_Kernel[..., 1] << 4)
|
| 55 |
+
| (Packed_Kernel[..., 2] << 8)
|
| 56 |
+
| (Packed_Kernel[..., 3] << 12)
|
| 57 |
+
)
|
| 58 |
+
# reshape to (N // 4, K), FP16 format
|
| 59 |
+
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
|
| 60 |
+
qweight = (
|
| 61 |
+
torch.tensor(Packed_Kernel.astype("int16"))
|
| 62 |
+
.to(unpacked_qweight.device)
|
| 63 |
+
.contiguous()
|
| 64 |
+
)
|
| 65 |
+
return qweight
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ScaledActivation(nn.Module):
|
| 69 |
+
def __init__(self, module, scales):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.act = module
|
| 72 |
+
self.scales = nn.Parameter(scales.data)
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class WQLinear(nn.Module):
|
| 79 |
+
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, dtype=torch.float16):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
if w_bit not in [4]:
|
| 83 |
+
raise NotImplementedError("Only 4-bit are supported for now.")
|
| 84 |
+
|
| 85 |
+
self.in_features = in_features
|
| 86 |
+
self.out_features = out_features
|
| 87 |
+
self.w_bit = w_bit
|
| 88 |
+
self.group_size = group_size if group_size != -1 else in_features
|
| 89 |
+
self.split_k_iters = 8
|
| 90 |
+
self.interleave = 4
|
| 91 |
+
# quick sanity check (make sure aligment)
|
| 92 |
+
assert self.in_features % self.group_size == 0
|
| 93 |
+
assert out_features % (32 // self.w_bit) == 0
|
| 94 |
+
pack_num = 32 // self.w_bit
|
| 95 |
+
int16_pack_num = 16 // self.w_bit
|
| 96 |
+
|
| 97 |
+
assert out_features % (self.interleave) == 0
|
| 98 |
+
self.register_buffer(
|
| 99 |
+
"qweight",
|
| 100 |
+
torch.zeros(
|
| 101 |
+
(
|
| 102 |
+
out_features // self.interleave,
|
| 103 |
+
in_features // int16_pack_num * self.interleave,
|
| 104 |
+
),
|
| 105 |
+
dtype=torch.int16,
|
| 106 |
+
device=dev,
|
| 107 |
+
),
|
| 108 |
+
)
|
| 109 |
+
self.register_buffer(
|
| 110 |
+
"scales",
|
| 111 |
+
torch.zeros(
|
| 112 |
+
(
|
| 113 |
+
calculate_zeros_width(in_features, self.group_size) * pack_num,
|
| 114 |
+
out_features,
|
| 115 |
+
),
|
| 116 |
+
dtype=dtype,
|
| 117 |
+
device=dev,
|
| 118 |
+
),
|
| 119 |
+
)
|
| 120 |
+
self.register_buffer(
|
| 121 |
+
"scaled_zeros",
|
| 122 |
+
torch.zeros(
|
| 123 |
+
(
|
| 124 |
+
calculate_zeros_width(in_features, self.group_size) * pack_num,
|
| 125 |
+
out_features,
|
| 126 |
+
),
|
| 127 |
+
dtype=dtype,
|
| 128 |
+
device=dev,
|
| 129 |
+
),
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if bias:
|
| 133 |
+
self.register_buffer(
|
| 134 |
+
"bias", torch.zeros((out_features), dtype=dtype, device=dev)
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
self.bias = None
|
| 138 |
+
|
| 139 |
+
@classmethod
|
| 140 |
+
def from_linear(
|
| 141 |
+
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
|
| 142 |
+
):
|
| 143 |
+
awq_linear = cls(
|
| 144 |
+
w_bit,
|
| 145 |
+
group_size,
|
| 146 |
+
linear.in_features,
|
| 147 |
+
linear.out_features,
|
| 148 |
+
linear.bias is not None,
|
| 149 |
+
linear.weight.device,
|
| 150 |
+
dtype=linear.weight.data.dtype
|
| 151 |
+
)
|
| 152 |
+
if init_only: # just prepare for loading sd
|
| 153 |
+
return awq_linear
|
| 154 |
+
|
| 155 |
+
# need scales and zeros info for real quantization
|
| 156 |
+
assert scales is not None and zeros is not None
|
| 157 |
+
scale_zeros = zeros * scales
|
| 158 |
+
|
| 159 |
+
dtype = scales.dtype
|
| 160 |
+
|
| 161 |
+
pack_num = 32 // awq_linear.w_bit
|
| 162 |
+
qscales = torch.zeros(
|
| 163 |
+
(
|
| 164 |
+
scales.shape[0],
|
| 165 |
+
calculate_zeros_width(linear.in_features, group_size) * pack_num,
|
| 166 |
+
),
|
| 167 |
+
dtype=dtype,
|
| 168 |
+
device=scales.device,
|
| 169 |
+
)
|
| 170 |
+
qscales[:, : scales.shape[1]] = scales
|
| 171 |
+
# awq_linear.scales = scales.clone().half()
|
| 172 |
+
awq_linear.scales = qscales.transpose(1, 0).contiguous()
|
| 173 |
+
if linear.bias is not None:
|
| 174 |
+
awq_linear.bias = linear.bias.clone().to(dtype)
|
| 175 |
+
|
| 176 |
+
intweight = []
|
| 177 |
+
for idx in range(awq_linear.in_features):
|
| 178 |
+
intweight.append(
|
| 179 |
+
torch.round(
|
| 180 |
+
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
|
| 181 |
+
/ qscales[:, idx // group_size]
|
| 182 |
+
).to(torch.int)[:, None]
|
| 183 |
+
)
|
| 184 |
+
intweight = torch.cat(intweight, dim=1)
|
| 185 |
+
# intweight = intweight.t().contiguous()
|
| 186 |
+
intweight = intweight.to(dtype=torch.int32)
|
| 187 |
+
awq_linear.qweight = pack_intweight(
|
| 188 |
+
intweight.contiguous(), interleave=4, kstride=64
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
zeros = zeros.to(dtype=torch.int32)
|
| 192 |
+
scaled_zeros = torch.zeros_like(qscales)
|
| 193 |
+
# scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
|
| 194 |
+
scaled_zeros[:, : scales.shape[1]] = -(
|
| 195 |
+
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
|
| 196 |
+
).to(dtype)
|
| 197 |
+
awq_linear.scaled_zeros = scaled_zeros.transpose(1, 0).contiguous()
|
| 198 |
+
|
| 199 |
+
return awq_linear
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def forward(self, x):
|
| 203 |
+
# out_shape = x.shape[:-1] + (self.out_features,)
|
| 204 |
+
# inputs = x.reshape(-1, x.shape[-1])
|
| 205 |
+
inputs = x
|
| 206 |
+
if inputs.numel() / inputs.shape[-1] < 8:
|
| 207 |
+
out = awq_inference_engine.gemv_forward_cuda_new(
|
| 208 |
+
inputs,
|
| 209 |
+
self.qweight,
|
| 210 |
+
self.scales,
|
| 211 |
+
self.scaled_zeros,
|
| 212 |
+
inputs.numel() // inputs.shape[-1],
|
| 213 |
+
self.out_features,
|
| 214 |
+
self.in_features,
|
| 215 |
+
self.group_size,
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
out = awq_inference_engine.gemm_forward_cuda_new(
|
| 219 |
+
inputs, self.qweight, self.scales, self.scaled_zeros
|
| 220 |
+
) # - 8.0 * self.scales)
|
| 221 |
+
out = out + self.bias if self.bias is not None else out
|
| 222 |
+
# print(out)
|
| 223 |
+
# assert 0
|
| 224 |
+
return out
|
| 225 |
+
|
| 226 |
+
def extra_repr(self) -> str:
|
| 227 |
+
return (
|
| 228 |
+
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
|
| 229 |
+
self.in_features,
|
| 230 |
+
self.out_features,
|
| 231 |
+
self.bias is not None,
|
| 232 |
+
self.w_bit,
|
| 233 |
+
self.group_size,
|
| 234 |
+
)
|
| 235 |
+
)
|
llm-awq/awq/quantize/w8a8_linear.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from qserve (https://github.com/mit-han-lab/qserve/tree/main) and modified by Yuming Lou
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
from torch.nn import Parameter
|
| 6 |
+
import awq_inference_engine
|
| 7 |
+
import torch
|
| 8 |
+
import gc
|
| 9 |
+
from awq.utils.module import set_op_by_name
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class W8A8OF16LinearStaticScale(torch.nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_features: int,
|
| 17 |
+
out_features: int,
|
| 18 |
+
bias: bool = True,
|
| 19 |
+
scale: Union[torch.tensor, float] = 1.0,
|
| 20 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
# Keep input parameters
|
| 25 |
+
self.in_features = in_features
|
| 26 |
+
self.out_features = out_features
|
| 27 |
+
# size [1] or size [oc]
|
| 28 |
+
self.register_buffer(
|
| 29 |
+
"dequant_scale", torch.ones(out_features, dtype=torch.half)
|
| 30 |
+
)
|
| 31 |
+
# Parameters.
|
| 32 |
+
# NOTE: torch.nn.functional.linear performs XA^T + b and as a result
|
| 33 |
+
# we allocate the transpose.
|
| 34 |
+
self.create_weights()
|
| 35 |
+
|
| 36 |
+
if bias:
|
| 37 |
+
self.bias = torch.empty(
|
| 38 |
+
self.out_features,
|
| 39 |
+
device=torch.cuda.current_device(),
|
| 40 |
+
dtype=torch.float16,
|
| 41 |
+
)
|
| 42 |
+
else:
|
| 43 |
+
self.register_parameter("bias", None)
|
| 44 |
+
|
| 45 |
+
def create_weights(self) -> None:
|
| 46 |
+
self.register_buffer(
|
| 47 |
+
"weight",
|
| 48 |
+
torch.empty(
|
| 49 |
+
self.out_features,
|
| 50 |
+
self.in_features,
|
| 51 |
+
dtype=torch.int8,
|
| 52 |
+
requires_grad=False,
|
| 53 |
+
),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def apply_weights(
|
| 57 |
+
self,
|
| 58 |
+
x: torch.Tensor,
|
| 59 |
+
bias: Optional[torch.Tensor],
|
| 60 |
+
) -> torch.Tensor:
|
| 61 |
+
raise NotImplementedError
|
| 62 |
+
|
| 63 |
+
def forward(self, input_):
|
| 64 |
+
# Matrix multiply.
|
| 65 |
+
output = self.apply_weights(input_, self.bias)
|
| 66 |
+
output_bias = self.bias
|
| 67 |
+
return output, output_bias
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class W8A8OF16LinearDynamicInputScale(W8A8OF16LinearStaticScale):
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
in_features: int,
|
| 74 |
+
out_features: int,
|
| 75 |
+
bias: bool = True,
|
| 76 |
+
scale: Union[torch.tensor, float] = 1.0,
|
| 77 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 78 |
+
):
|
| 79 |
+
super().__init__(
|
| 80 |
+
in_features=in_features,
|
| 81 |
+
out_features=out_features,
|
| 82 |
+
bias=bias,
|
| 83 |
+
scale=scale,
|
| 84 |
+
params_dtype=params_dtype,
|
| 85 |
+
)
|
| 86 |
+
if bias:
|
| 87 |
+
self.apply_weights = self.apply_weights_bias
|
| 88 |
+
else:
|
| 89 |
+
self.apply_weights = self.apply_weights_no_bias
|
| 90 |
+
|
| 91 |
+
#W bias. Fused bias and W8A8 GEMM
|
| 92 |
+
def apply_weights_bias(
|
| 93 |
+
self,
|
| 94 |
+
# [batch, tokens, channels]
|
| 95 |
+
x: torch.Tensor,
|
| 96 |
+
# [batch * tokens]
|
| 97 |
+
input_scale: torch.Tensor,
|
| 98 |
+
output_buffer: torch.Tensor,
|
| 99 |
+
bias: torch.Tensor = None,
|
| 100 |
+
):
|
| 101 |
+
x_shape = x.shape
|
| 102 |
+
if len(x.shape) > 2:
|
| 103 |
+
assert 0, "Not implemented"
|
| 104 |
+
x = x.view(-1, x_shape[-1])
|
| 105 |
+
# If use awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda
|
| 106 |
+
awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda(
|
| 107 |
+
x, self.weight, self.dequant_scale, input_scale, output_buffer, bias
|
| 108 |
+
)
|
| 109 |
+
if len(x.shape) > 2:
|
| 110 |
+
assert 0, "Not implemented 2"
|
| 111 |
+
output_buffer = output_buffer.view(*x_shape[:-1], -1)
|
| 112 |
+
|
| 113 |
+
#W/H bias. W8A8 GEMM
|
| 114 |
+
def apply_weights_no_bias(
|
| 115 |
+
self,
|
| 116 |
+
# [batch, tokens, channels]
|
| 117 |
+
x: torch.Tensor,
|
| 118 |
+
# [batch * tokens]
|
| 119 |
+
input_scale: torch.Tensor,
|
| 120 |
+
output_buffer: torch.Tensor,
|
| 121 |
+
bias: torch.Tensor = None,
|
| 122 |
+
):
|
| 123 |
+
x_shape = x.shape
|
| 124 |
+
if len(x.shape) > 2:
|
| 125 |
+
assert 0, "Not implemented"
|
| 126 |
+
x = x.view(-1, x_shape[-1])
|
| 127 |
+
# If use awq_inference_engine.w8a8_gemm_forward_cuda
|
| 128 |
+
awq_inference_engine.w8a8_gemm_forward_cuda(
|
| 129 |
+
x, self.weight, self.dequant_scale, input_scale, output_buffer
|
| 130 |
+
)
|
| 131 |
+
if len(x.shape) > 2:
|
| 132 |
+
assert 0, "Not implemented 2"
|
| 133 |
+
output_buffer = output_buffer.view(*x_shape[:-1], -1)
|
| 134 |
+
|
| 135 |
+
def forward(self, input_, input_scale, output_buffer):
|
| 136 |
+
# Matrix multiply.
|
| 137 |
+
self.apply_weights(input_, input_scale, output_buffer, self.bias)
|
| 138 |
+
|
| 139 |
+
@classmethod
|
| 140 |
+
def from_linear(
|
| 141 |
+
cls,
|
| 142 |
+
linear,
|
| 143 |
+
init_only=False,
|
| 144 |
+
s1_scale=None,
|
| 145 |
+
fc1=False,
|
| 146 |
+
):
|
| 147 |
+
q_linear = cls(
|
| 148 |
+
linear.in_features,
|
| 149 |
+
linear.out_features,
|
| 150 |
+
linear.bias is not None,
|
| 151 |
+
)
|
| 152 |
+
if init_only: # just prepare for loading sd
|
| 153 |
+
return q_linear
|
| 154 |
+
if s1_scale is None:
|
| 155 |
+
s1_scale, _ = torch.max(abs(linear.weight.data), dim=-1, keepdim=True)
|
| 156 |
+
s1_scale = s1_scale.clamp_(min=1e-5).div_(127)
|
| 157 |
+
|
| 158 |
+
if linear.bias is not None:
|
| 159 |
+
q_linear.bias = linear.bias.clone().half().contiguous().cuda()
|
| 160 |
+
## Quantize the weights
|
| 161 |
+
# ---- Quantize the weights to int8 ---- #
|
| 162 |
+
linear_weight = linear.weight.data # OC, IC
|
| 163 |
+
linear_weight = linear_weight.div_(s1_scale.to(linear_weight.device))
|
| 164 |
+
linear_weight = linear_weight.round_().to(torch.int8)
|
| 165 |
+
|
| 166 |
+
q_linear.weight.data[:, :] = linear_weight.half().contiguous().cuda()
|
| 167 |
+
|
| 168 |
+
# ---- Pack the scales ---- #
|
| 169 |
+
q_linear.dequant_scale.data[:] = (
|
| 170 |
+
s1_scale.reshape(-1).half().contiguous().cuda()
|
| 171 |
+
)
|
| 172 |
+
return q_linear.cuda()
|
| 173 |
+
|
| 174 |
+
@classmethod
|
| 175 |
+
def from_qkv(
|
| 176 |
+
cls,
|
| 177 |
+
q,
|
| 178 |
+
k,
|
| 179 |
+
v,
|
| 180 |
+
init_only=False,
|
| 181 |
+
s1_scale=None,
|
| 182 |
+
):
|
| 183 |
+
q_linear = cls(
|
| 184 |
+
q.in_features,
|
| 185 |
+
q.out_features + k.out_features + v.out_features,
|
| 186 |
+
q.bias is not None,
|
| 187 |
+
)
|
| 188 |
+
if init_only: # just prepare for loading sd
|
| 189 |
+
return q_linear
|
| 190 |
+
weight = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
|
| 191 |
+
|
| 192 |
+
if s1_scale is None:
|
| 193 |
+
s1_scale, _ = torch.max(abs(weight), dim=-1, keepdim=True)
|
| 194 |
+
s1_scale = s1_scale.clamp_(min=1e-5).div_(127)
|
| 195 |
+
|
| 196 |
+
if q.bias is not None:
|
| 197 |
+
bias = torch.cat([q.bias, k.bias, v.bias], dim=0)
|
| 198 |
+
q_linear.bias = bias.clone().half().contiguous().cuda()
|
| 199 |
+
# ---- Quantize the weights to int8 ---- #
|
| 200 |
+
weight = weight.div_(s1_scale.to(weight.device))
|
| 201 |
+
weight = weight.round_().to(torch.int8)
|
| 202 |
+
|
| 203 |
+
q_linear.weight.data[:, :] = weight.contiguous().cuda()
|
| 204 |
+
|
| 205 |
+
# ---- Pack the scales ---- #
|
| 206 |
+
q_linear.dequant_scale.data[:] = (
|
| 207 |
+
s1_scale.reshape(q.out_features + k.out_features + v.out_features)
|
| 208 |
+
.half()
|
| 209 |
+
.contiguous().cuda()
|
| 210 |
+
)
|
| 211 |
+
return q_linear.cuda()
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class FakeW8A8Linear(torch.nn.Module):
|
| 215 |
+
def __init__(
|
| 216 |
+
self, in_features: int, out_features: int, bias: bool = True, wbit: int = 8
|
| 217 |
+
):
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.weight = torch.nn.Parameter(
|
| 220 |
+
torch.empty(out_features, in_features, dtype=torch.half)
|
| 221 |
+
)
|
| 222 |
+
if bias:
|
| 223 |
+
self.bias = torch.nn.Parameter(
|
| 224 |
+
torch.empty(1, out_features, dtype=torch.half)
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
self.bias = None
|
| 228 |
+
self.wbit = wbit
|
| 229 |
+
self.maxv = 2 ** (wbit - 1) - 1
|
| 230 |
+
|
| 231 |
+
def forward(self, input):
|
| 232 |
+
t_shape = input.shape
|
| 233 |
+
input.view(-1, t_shape[-1])
|
| 234 |
+
scales = input.abs().max(dim=-1, keepdim=True)[0]
|
| 235 |
+
scales.clamp_(min=1e-5).div_(self.maxv)
|
| 236 |
+
input.div_(scales).round_().mul_(scales)
|
| 237 |
+
output = torch.functional.F.linear(input, self.weight, self.bias)
|
| 238 |
+
return output
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def from_linear(cls, linear: torch.nn.Linear, wbit=8):
|
| 242 |
+
fake_linear = cls(
|
| 243 |
+
linear.in_features, linear.out_features, linear.bias is not None, wbit
|
| 244 |
+
)
|
| 245 |
+
maxv = 2 ** (wbit - 1) - 1
|
| 246 |
+
scale = (
|
| 247 |
+
torch.max(abs(linear.weight.data.detach()), -1, keepdim=True)[0]
|
| 248 |
+
.clamp_(min=1e-5)
|
| 249 |
+
.div_(maxv)
|
| 250 |
+
)
|
| 251 |
+
weight = linear.weight.data / scale
|
| 252 |
+
weight = weight.round_()
|
| 253 |
+
weight = weight * scale
|
| 254 |
+
fake_linear.weight.copy_(weight.contiguous())
|
| 255 |
+
if linear.bias is not None:
|
| 256 |
+
fake_linear.bias.copy_(
|
| 257 |
+
linear.bias.detach().half().reshape(1, linear.out_features).contiguous()
|
| 258 |
+
)
|
| 259 |
+
else:
|
| 260 |
+
linear.bias = None
|
| 261 |
+
del linear, scale, weight
|
| 262 |
+
torch.cuda.empty_cache()
|
| 263 |
+
return fake_linear
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def fake_quant(model, wbit=8):
|
| 267 |
+
for name, m in tqdm(
|
| 268 |
+
model.named_modules(),
|
| 269 |
+
desc="Fake quantizing",
|
| 270 |
+
total=len(list(model.named_modules())),
|
| 271 |
+
):
|
| 272 |
+
if isinstance(m, torch.nn.Linear):
|
| 273 |
+
FQlinear = FakeW8A8Linear.from_linear(m, wbit)
|
| 274 |
+
del m
|
| 275 |
+
torch.cuda.empty_cache()
|
| 276 |
+
set_op_by_name(model, name, FQlinear)
|
llm-awq/awq/utils/lm_eval_adaptor.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
import torch
|
| 3 |
+
from lm_eval.base import BaseLM
|
| 4 |
+
import fnmatch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LMEvalAdaptor(BaseLM):
|
| 8 |
+
def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
assert isinstance(batch_size, int)
|
| 12 |
+
|
| 13 |
+
self.model_name = model_name
|
| 14 |
+
self.model = model
|
| 15 |
+
self.model.eval()
|
| 16 |
+
|
| 17 |
+
self.tokenizer = tokenizer
|
| 18 |
+
|
| 19 |
+
# assert isinstance(self.tokenizer, (
|
| 20 |
+
# transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
|
| 21 |
+
# transformers.T5Tokenizer, transformers.T5TokenizerFast,
|
| 22 |
+
# )), "this tokenizer has not been checked for compatibility yet!"
|
| 23 |
+
|
| 24 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 25 |
+
|
| 26 |
+
self._batch_size = batch_size
|
| 27 |
+
|
| 28 |
+
self._max_length = max_length
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def eot_token_id(self):
|
| 32 |
+
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
|
| 33 |
+
return self.tokenizer.eos_token_id
|
| 34 |
+
|
| 35 |
+
@property
|
| 36 |
+
def max_length(self):
|
| 37 |
+
if self._max_length != -1:
|
| 38 |
+
return self._max_length
|
| 39 |
+
if hasattr(self.model.config, "n_ctx"):
|
| 40 |
+
return self.model.config.n_ctx
|
| 41 |
+
elif hasattr(self.model.config, "max_position_embeddings"):
|
| 42 |
+
return self.model.config.max_position_embeddings
|
| 43 |
+
elif hasattr(self.model.config, "n_positions"):
|
| 44 |
+
return self.model.config.n_positions
|
| 45 |
+
elif "bloom" in self.model_name:
|
| 46 |
+
return 2048
|
| 47 |
+
elif "llama" in self.model_name:
|
| 48 |
+
return 2048 # TODO: did not check this
|
| 49 |
+
elif "mpt" in self.model_name:
|
| 50 |
+
return 2048
|
| 51 |
+
elif "falcon" in self.model_name:
|
| 52 |
+
return 2048
|
| 53 |
+
else:
|
| 54 |
+
print(self.model.config)
|
| 55 |
+
raise NotImplementedError
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def max_gen_toks(self):
|
| 59 |
+
return 256
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def batch_size(self):
|
| 63 |
+
return self._batch_size
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def device(self):
|
| 67 |
+
return "cuda"
|
| 68 |
+
|
| 69 |
+
def tok_encode(self, string: str):
|
| 70 |
+
return self.tokenizer.encode(string, add_special_tokens=False)
|
| 71 |
+
|
| 72 |
+
def tok_decode(self, tokens):
|
| 73 |
+
return self.tokenizer.decode(tokens)
|
| 74 |
+
|
| 75 |
+
def _model_call(self, inps):
|
| 76 |
+
"""
|
| 77 |
+
inps: a torch tensor of shape [batch, sequence]
|
| 78 |
+
the size of sequence may vary from call to call
|
| 79 |
+
|
| 80 |
+
returns: a torch tensor of shape [batch, sequence, vocab] with the
|
| 81 |
+
logits returned from the model
|
| 82 |
+
"""
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
if isinstance(
|
| 85 |
+
self.model,
|
| 86 |
+
transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
|
| 87 |
+
):
|
| 88 |
+
dec_inps = torch.cat(
|
| 89 |
+
[
|
| 90 |
+
torch.tensor(
|
| 91 |
+
self.model.generation_config.decoder_start_token_id,
|
| 92 |
+
)
|
| 93 |
+
.tile(len(inps), 1)
|
| 94 |
+
.to(inps),
|
| 95 |
+
inps,
|
| 96 |
+
],
|
| 97 |
+
dim=1,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
kwargs = {
|
| 101 |
+
"decoder_input_ids": dec_inps,
|
| 102 |
+
}
|
| 103 |
+
else:
|
| 104 |
+
kwargs = {}
|
| 105 |
+
out = self.model(inps, **kwargs)[0]
|
| 106 |
+
if (
|
| 107 |
+
"opt" in self.model_name
|
| 108 |
+
): # there are a few extra tokens in opt, which we should omit
|
| 109 |
+
return out[:, :, :50257]
|
| 110 |
+
else:
|
| 111 |
+
return out # [:, :, :self.tokenizer.vocab_size]
|
| 112 |
+
|
| 113 |
+
def _model_generate(self, context, max_length, eos_token_id):
|
| 114 |
+
return self.model.generate(
|
| 115 |
+
context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
|
| 116 |
+
)
|
llm-awq/awq/utils/utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import accelerate
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_module_by_name_suffix(model, module_name: str):
|
| 6 |
+
for name, module in model.named_modules():
|
| 7 |
+
if name.endswith(module_name):
|
| 8 |
+
return module
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def simple_dispatch_model(model, device_map):
|
| 12 |
+
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
|
| 13 |
+
|
| 14 |
+
if "" in device_map:
|
| 15 |
+
d = device_map[""]
|
| 16 |
+
model = model.to(torch.device(d))
|
| 17 |
+
model.hf_device_map = device_map
|
| 18 |
+
return model
|
| 19 |
+
|
| 20 |
+
tied_params = accelerate.utils.modeling.find_tied_parameters(model)
|
| 21 |
+
if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
|
| 22 |
+
"cpu",
|
| 23 |
+
"disk",
|
| 24 |
+
}:
|
| 25 |
+
main_device = "cpu"
|
| 26 |
+
else:
|
| 27 |
+
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
|
| 28 |
+
|
| 29 |
+
cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
|
| 30 |
+
prev_hook = None
|
| 31 |
+
for idx, (n, d) in enumerate(cpu_offload_group):
|
| 32 |
+
m = get_module_by_name_suffix(model, n)
|
| 33 |
+
_, prev_hook = accelerate.cpu_offload_with_hook(
|
| 34 |
+
m, execution_device=main_device, prev_module_hook=prev_hook
|
| 35 |
+
)
|
| 36 |
+
# set first cpu offload module's prev_module_hook to the last cpu offload module's hook
|
| 37 |
+
if len(cpu_offload_group) > 1:
|
| 38 |
+
get_module_by_name_suffix(
|
| 39 |
+
model, cpu_offload_group[0][0]
|
| 40 |
+
)._hf_hook.prev_module_hook = prev_hook
|
| 41 |
+
|
| 42 |
+
for n, d in device_map.items():
|
| 43 |
+
m = get_module_by_name_suffix(model, n)
|
| 44 |
+
if d != "cpu":
|
| 45 |
+
d = torch.device(d)
|
| 46 |
+
hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
|
| 47 |
+
add_hook_to_module(m, hook)
|
| 48 |
+
accelerate.utils.modeling.retie_parameters(model, tied_params)
|
| 49 |
+
model.hf_device_map = device_map
|
| 50 |
+
|
| 51 |
+
return model
|
llm-awq/examples/convert_to_hf.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This script demonstrates how you can convert your model into HF format
|
| 2 |
+
# easily and push the quantized weights on the Hub using simple tools.
|
| 3 |
+
# Make sure to have transformers > 4.34 and that you have ran
|
| 4 |
+
# `huggingface-cli login` on your terminal before running this
|
| 5 |
+
# script
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
# This demo only support single GPU for now
|
| 10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 11 |
+
|
| 12 |
+
from transformers import AutoConfig, AwqConfig, AutoTokenizer
|
| 13 |
+
from huggingface_hub import HfApi
|
| 14 |
+
|
| 15 |
+
api = HfApi()
|
| 16 |
+
|
| 17 |
+
parser = argparse.ArgumentParser()
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--model_path", type=str, help="path of the original hf model", required=True
|
| 20 |
+
)
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--quantized_model_path",
|
| 23 |
+
type=str,
|
| 24 |
+
help="path of the quantized AWQ model",
|
| 25 |
+
required=True,
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--quantized_model_hub_path",
|
| 29 |
+
type=str,
|
| 30 |
+
help="path of the quantized AWQ model to push on the Hub",
|
| 31 |
+
required=True,
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument("--w_bit", type=int, default=4, help="")
|
| 34 |
+
parser.add_argument("--q_group_size", default=128, type=int)
|
| 35 |
+
parser.add_argument("--no_zero_point", action="store_true")
|
| 36 |
+
|
| 37 |
+
args = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
original_model_path = args.model_path
|
| 40 |
+
quantized_model_path = args.quantized_model_path
|
| 41 |
+
quantized_model_hub_path = args.quantized_model_hub_path
|
| 42 |
+
|
| 43 |
+
# Load the corresponding AWQConfig
|
| 44 |
+
quantization_config = AwqConfig(
|
| 45 |
+
bits=args.w_bit,
|
| 46 |
+
group_size=args.q_group_size,
|
| 47 |
+
zero_point=not args.no_zero_point,
|
| 48 |
+
backend="llm-awq",
|
| 49 |
+
version="gemv",
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Set the attribute `quantization_config` in model's config
|
| 53 |
+
config = AutoConfig.from_pretrained(original_model_path)
|
| 54 |
+
config.quantization_config = quantization_config
|
| 55 |
+
|
| 56 |
+
# Load tokenizer
|
| 57 |
+
tok = AutoTokenizer.from_pretrained(original_model_path)
|
| 58 |
+
|
| 59 |
+
# Push config and tokenizer
|
| 60 |
+
config.push_to_hub(quantized_model_hub_path)
|
| 61 |
+
tok.push_to_hub(quantized_model_hub_path)
|
| 62 |
+
|
| 63 |
+
# Upload model weights
|
| 64 |
+
api.upload_file(
|
| 65 |
+
path_or_fileobj=quantized_model_path,
|
| 66 |
+
path_in_repo="pytorch_model.bin",
|
| 67 |
+
repo_id=quantized_model_hub_path,
|
| 68 |
+
repo_type="model",
|
| 69 |
+
)
|
llm-awq/examples/llava_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llm-awq/figures/vila-logo.jpg
ADDED
|
llm-awq/scripts/codellama_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=CodeLlama-13b-Instruct
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/llama2_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=llama-2-7b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/llama3_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=llama3-8b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/llama_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=llama-7b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/opt_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=opt-6.7b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/opt/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/opt/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/opt/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/opt/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/qwen_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=qwen2.5-7b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/models/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/models/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/models/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/models/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/starcoder_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=starcoder
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/scripts/vicuna_example.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL=vicuna-7b
|
| 2 |
+
|
| 3 |
+
# run AWQ search (optional; we provided the pre-computed results)
|
| 4 |
+
python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
|
| 5 |
+
--w_bit 4 --q_group_size 128 \
|
| 6 |
+
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
|
| 7 |
+
|
| 8 |
+
# evaluate the AWQ quantize model (simulated pseudo quantization)
|
| 9 |
+
python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
|
| 10 |
+
--tasks wikitext \
|
| 11 |
+
--w_bit 4 --q_group_size 128 \
|
| 12 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 13 |
+
--q_backend fake
|
| 14 |
+
|
| 15 |
+
# generate real quantized weights (w4)
|
| 16 |
+
python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
|
| 17 |
+
--w_bit 4 --q_group_size 128 \
|
| 18 |
+
--load_awq awq_cache/$MODEL-w4-g128.pt \
|
| 19 |
+
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
|
| 20 |
+
|
| 21 |
+
# load and evaluate the real quantized model (smaller gpu memory usage)
|
| 22 |
+
python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
|
| 23 |
+
--tasks wikitext \
|
| 24 |
+
--w_bit 4 --q_group_size 128 \
|
| 25 |
+
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
|
llm-awq/tinychat/benchmark.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Usage:
|
| 2 |
+
# Please first install awq/kernels
|
| 3 |
+
# then directly run CUDA_VISIBLE_DEVICES=0 python benchmark.py
|
| 4 |
+
import argparse
|
| 5 |
+
import torch
|
| 6 |
+
import time
|
| 7 |
+
import numpy as np
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
|
| 9 |
+
import tinychat.utils.constants
|
| 10 |
+
from tinychat.utils.load_quant import load_awq_model
|
| 11 |
+
from awq.quantize.quantizer import real_quantize_model_weight
|
| 12 |
+
from tinychat.utils.tune import (
|
| 13 |
+
tune_all_wqlinears,
|
| 14 |
+
device_warmup,
|
| 15 |
+
tune_llava_patch_embedding,
|
| 16 |
+
)
|
| 17 |
+
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def skip(*args, **kwargs):
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main():
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--model_type", type=str, default="LLaMa", help="type of the model"
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--model_path",
|
| 31 |
+
type=str,
|
| 32 |
+
default="/data/llm/checkpoints/vicuna-hf/vicuna-7b",
|
| 33 |
+
help="path to the model",
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument("--q_group_size", type=int, default=128)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--verbose",
|
| 38 |
+
default=False,
|
| 39 |
+
action="store_true",
|
| 40 |
+
help="Wheter to print more information.",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--max_seq_len",
|
| 44 |
+
type=int,
|
| 45 |
+
default=8192,
|
| 46 |
+
help="maximum sequence length for kv cache",
|
| 47 |
+
)
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
|
| 50 |
+
)
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--flash_attn",
|
| 53 |
+
action="store_true",
|
| 54 |
+
help="whether to use flash attention",
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--chunk_prefilling",
|
| 58 |
+
action="store_true",
|
| 59 |
+
help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--context_length",
|
| 63 |
+
type=list,
|
| 64 |
+
nargs="+",
|
| 65 |
+
help="The length of input. And if chunk_prefilling used, this serves as the length of tokens from history rounds.",
|
| 66 |
+
)
|
| 67 |
+
parser.add_argument(
|
| 68 |
+
"--question_length",
|
| 69 |
+
type=list,
|
| 70 |
+
nargs="+",
|
| 71 |
+
help="The length of new input. Only useful and necessary when benchmarking chunk_prefilling method",
|
| 72 |
+
)
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"--precision", type=str, default="W4A16", help="compute precision"
|
| 75 |
+
)
|
| 76 |
+
args = parser.parse_args()
|
| 77 |
+
# some checks
|
| 78 |
+
assert (args.question_length is not None and args.chunk_prefilling) or (
|
| 79 |
+
not args.chunk_prefilling
|
| 80 |
+
), "If you want to benchmark chunk prefilling, you need specify the question length and context length"
|
| 81 |
+
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
|
| 82 |
+
token_num = 256
|
| 83 |
+
# We support fixing a certain kind of length
|
| 84 |
+
if args.chunk_prefilling:
|
| 85 |
+
if len(args.context_length) == 1 and len(args.question_length) > 1:
|
| 86 |
+
args.context_length = [
|
| 87 |
+
args.context_length[0] for _ in range(len(args.question_length))
|
| 88 |
+
]
|
| 89 |
+
elif len(args.question_length) == 1 and len(args.context_length) > 1:
|
| 90 |
+
args.question_length = [
|
| 91 |
+
args.question_length[0] for _ in range(len(args.context_length))
|
| 92 |
+
]
|
| 93 |
+
elif len(args.question_length) != len(args.context_length):
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"The number of items in the question_length and context_length is expected to be either one or equal!"
|
| 96 |
+
)
|
| 97 |
+
tinychat.utils.constants.max_batch_size = args.max_batch_size
|
| 98 |
+
tinychat.utils.constants.max_seq_len = args.max_seq_len
|
| 99 |
+
from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM
|
| 100 |
+
from tinychat.models.vila_llama import VilaLlamaForCausalLM
|
| 101 |
+
|
| 102 |
+
modeling_utils._init_weights = False
|
| 103 |
+
torch.nn.init.kaiming_uniform_ = skip
|
| 104 |
+
torch.nn.init.kaiming_normal_ = skip
|
| 105 |
+
torch.nn.init.uniform_ = skip
|
| 106 |
+
torch.nn.init.normal_ = skip
|
| 107 |
+
|
| 108 |
+
device = "cuda:0"
|
| 109 |
+
model_type_dict = {
|
| 110 |
+
"llama": LlamaForCausalLM,
|
| 111 |
+
"falcon": FalconForCausalLM,
|
| 112 |
+
"mpt": MPTForCausalLM,
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
| 116 |
+
assert args.model_type.lower() in [
|
| 117 |
+
"llama",
|
| 118 |
+
"falcon",
|
| 119 |
+
"mpt",
|
| 120 |
+
"vila",
|
| 121 |
+
], "We only support llama & falcon & mpt & vila now"
|
| 122 |
+
if "vila" in args.model_type.lower():
|
| 123 |
+
model = VilaLlamaForCausalLM(config).half()
|
| 124 |
+
print(model)
|
| 125 |
+
if args.precision in ["W4A16"]:
|
| 126 |
+
real_quantize_model_weight(
|
| 127 |
+
model.llm,
|
| 128 |
+
w_bit=4,
|
| 129 |
+
q_config=dict(q_group_size=args.q_group_size, zero_point=True),
|
| 130 |
+
init_only=True,
|
| 131 |
+
)
|
| 132 |
+
make_quant_attn(model.llm, device, args.flash_attn)
|
| 133 |
+
make_quant_norm(model.llm)
|
| 134 |
+
make_fused_mlp(model.llm)
|
| 135 |
+
model = model.to(device)
|
| 136 |
+
device_warmup(device)
|
| 137 |
+
tune_llava_patch_embedding(model.get_vision_tower(), device=device)
|
| 138 |
+
if not args.chunk_prefilling:
|
| 139 |
+
image_num = [
|
| 140 |
+
int(int("".join(i)) * 1 / 196) for i in args.context_length
|
| 141 |
+
] # consider about three thirds of the history tokens are images
|
| 142 |
+
if sum(image_num) > 0:
|
| 143 |
+
image_tensor = 2 * torch.rand((max(image_num), 3, 384, 384)) - 1
|
| 144 |
+
image_tensor = image_tensor.half().to(device)
|
| 145 |
+
else:
|
| 146 |
+
image_tensor = None
|
| 147 |
+
|
| 148 |
+
print("huggingface ckpt loaded")
|
| 149 |
+
|
| 150 |
+
# warming up
|
| 151 |
+
input_ids = [1 for _ in range(2048)]
|
| 152 |
+
inputs = torch.as_tensor([input_ids], device=device)
|
| 153 |
+
out = model(
|
| 154 |
+
inputs, start_pos=0, chunk_prefilling=args.chunk_prefilling
|
| 155 |
+
) # warmup
|
| 156 |
+
|
| 157 |
+
if not args.chunk_prefilling:
|
| 158 |
+
for i, context_length in enumerate(args.context_length):
|
| 159 |
+
context_length = int("".join(context_length))
|
| 160 |
+
time_lis = []
|
| 161 |
+
if image_num[i]:
|
| 162 |
+
images = image_tensor[0 : image_num[i], :, :, :]
|
| 163 |
+
input_ids = [-200 for _ in range(image_num[i])] + [
|
| 164 |
+
1 for _ in range(context_length - 196 * image_num[i])
|
| 165 |
+
]
|
| 166 |
+
else:
|
| 167 |
+
images = None
|
| 168 |
+
input_ids = [1 for _ in range(context_length)]
|
| 169 |
+
print("-" * 80)
|
| 170 |
+
print(
|
| 171 |
+
"Context length: {} with {} pictures".format(
|
| 172 |
+
context_length, image_num[i]
|
| 173 |
+
)
|
| 174 |
+
)
|
| 175 |
+
with torch.inference_mode():
|
| 176 |
+
for i in range(10): # Run ten times and get the average value
|
| 177 |
+
start_pos = 0
|
| 178 |
+
torch.cuda.synchronize()
|
| 179 |
+
t_st = time.time()
|
| 180 |
+
inputs = torch.as_tensor([input_ids], device=device)
|
| 181 |
+
out = model(
|
| 182 |
+
inputs,
|
| 183 |
+
start_pos=start_pos,
|
| 184 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 185 |
+
images=images,
|
| 186 |
+
)
|
| 187 |
+
start_pos += inputs.shape[1]
|
| 188 |
+
torch.cuda.synchronize()
|
| 189 |
+
t_ed = time.time()
|
| 190 |
+
token = out[:, -1].max(1)[1].unsqueeze(1)
|
| 191 |
+
time_lis.append(t_ed - t_st)
|
| 192 |
+
if args.verbose:
|
| 193 |
+
print(i, t_ed - t_st)
|
| 194 |
+
print(f"Time To First Token: {np.mean(time_lis):.5f} s.")
|
| 195 |
+
print("-" * 80)
|
| 196 |
+
else:
|
| 197 |
+
for i, (context_length, question_length) in enumerate(
|
| 198 |
+
zip(args.context_length, args.question_length)
|
| 199 |
+
):
|
| 200 |
+
context_length = int("".join(context_length))
|
| 201 |
+
question_length = int("".join(question_length))
|
| 202 |
+
input_ids_old = [1 for _ in range(context_length)]
|
| 203 |
+
images = None
|
| 204 |
+
input_ids_new = [1 for _ in range(question_length)]
|
| 205 |
+
time_lis = []
|
| 206 |
+
print("-" * 80)
|
| 207 |
+
print(
|
| 208 |
+
"History length: {} ; Question length: {}".format(
|
| 209 |
+
context_length, question_length
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
with torch.inference_mode():
|
| 213 |
+
for i in range(10): # Run ten times and get the average value
|
| 214 |
+
# history rounds
|
| 215 |
+
start_pos = 0
|
| 216 |
+
if context_length > question_length:
|
| 217 |
+
inputs = torch.as_tensor([input_ids_old], device=device)
|
| 218 |
+
out = model(
|
| 219 |
+
inputs,
|
| 220 |
+
start_pos=start_pos,
|
| 221 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 222 |
+
images=None,
|
| 223 |
+
)
|
| 224 |
+
start_pos += context_length
|
| 225 |
+
|
| 226 |
+
# the present round
|
| 227 |
+
torch.cuda.synchronize()
|
| 228 |
+
t_st = time.time()
|
| 229 |
+
inputs = torch.as_tensor([input_ids_new], device=device)
|
| 230 |
+
out = model(
|
| 231 |
+
inputs,
|
| 232 |
+
start_pos=start_pos,
|
| 233 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 234 |
+
)
|
| 235 |
+
start_pos += inputs.shape[1]
|
| 236 |
+
torch.cuda.synchronize()
|
| 237 |
+
t_ed = time.time()
|
| 238 |
+
|
| 239 |
+
token = out[:, -1].max(1)[1].unsqueeze(1)
|
| 240 |
+
time_lis.append(t_ed - t_st)
|
| 241 |
+
if args.verbose:
|
| 242 |
+
print(i, t_ed - t_st)
|
| 243 |
+
print(
|
| 244 |
+
f"Time To First Token of this round: {np.mean(time_lis):.5f} s."
|
| 245 |
+
)
|
| 246 |
+
print("-" * 80)
|
| 247 |
+
else:
|
| 248 |
+
model = model_type_dict[args.model_type.lower()](config).half()
|
| 249 |
+
if args.precision in ["W4A16"]:
|
| 250 |
+
real_quantize_model_weight(
|
| 251 |
+
model,
|
| 252 |
+
w_bit=4,
|
| 253 |
+
q_config=dict(q_group_size=args.q_group_size, zero_point=True),
|
| 254 |
+
init_only=True,
|
| 255 |
+
)
|
| 256 |
+
model = model.to(device)
|
| 257 |
+
|
| 258 |
+
if args.precision in ["W4A16"]:
|
| 259 |
+
# tune_all_wqlinears(model)
|
| 260 |
+
make_quant_attn(model, device, args.flash_attn)
|
| 261 |
+
make_quant_norm(model)
|
| 262 |
+
make_fused_mlp(model)
|
| 263 |
+
device_warmup(device)
|
| 264 |
+
|
| 265 |
+
print("huggingface ckpt loaded")
|
| 266 |
+
|
| 267 |
+
# warming up
|
| 268 |
+
input_ids = [1 for _ in range(2048)]
|
| 269 |
+
inputs = torch.as_tensor([input_ids], device=device)
|
| 270 |
+
out = model(
|
| 271 |
+
inputs,
|
| 272 |
+
start_pos=0,
|
| 273 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 274 |
+
quant=args.precision in ["W4A16"],
|
| 275 |
+
) # warmup
|
| 276 |
+
|
| 277 |
+
if not args.chunk_prefilling:
|
| 278 |
+
for context_length in args.context_length:
|
| 279 |
+
context_length = int("".join(context_length))
|
| 280 |
+
input_ids = [1 for _ in range(context_length)]
|
| 281 |
+
time_lis = []
|
| 282 |
+
print("-" * 80)
|
| 283 |
+
print("Context length: {}".format(context_length))
|
| 284 |
+
with torch.inference_mode():
|
| 285 |
+
for i in range(10): # Run ten times and get the average value
|
| 286 |
+
start_pos = 0
|
| 287 |
+
torch.cuda.synchronize()
|
| 288 |
+
t_st = time.time()
|
| 289 |
+
inputs = torch.as_tensor([input_ids], device=device)
|
| 290 |
+
out = model(
|
| 291 |
+
inputs,
|
| 292 |
+
start_pos=start_pos,
|
| 293 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 294 |
+
quant=args.precision in ["W4A16"],
|
| 295 |
+
)
|
| 296 |
+
start_pos += inputs.shape[1]
|
| 297 |
+
torch.cuda.synchronize()
|
| 298 |
+
t_ed = time.time()
|
| 299 |
+
token = torch.argmax(out, keepdim=True)[0]
|
| 300 |
+
time_lis.append(t_ed - t_st)
|
| 301 |
+
if args.verbose:
|
| 302 |
+
print(i, t_ed - t_st)
|
| 303 |
+
print(f"Time To First Token: {np.mean(time_lis):.5f} s.")
|
| 304 |
+
# decoing throughput
|
| 305 |
+
time_lis = []
|
| 306 |
+
start_pos = context_length
|
| 307 |
+
torch.cuda.synchronize()
|
| 308 |
+
t_st = time.time()
|
| 309 |
+
for i in range(token_num):
|
| 310 |
+
token = model(
|
| 311 |
+
token,
|
| 312 |
+
start_pos=start_pos,
|
| 313 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 314 |
+
quant=args.precision in ["W4A16"],
|
| 315 |
+
)
|
| 316 |
+
start_pos += 1
|
| 317 |
+
token = torch.argmax(token, keepdim=True)[0]
|
| 318 |
+
torch.cuda.synchronize()
|
| 319 |
+
t_ed = time.time()
|
| 320 |
+
time_lis.append(t_ed - t_st)
|
| 321 |
+
print(
|
| 322 |
+
f"Decoding throughput: {token_num/sum(time_lis):.5f} token/s."
|
| 323 |
+
)
|
| 324 |
+
print("-" * 80)
|
| 325 |
+
else:
|
| 326 |
+
for context_length, question_length in zip(
|
| 327 |
+
args.context_length, args.question_length
|
| 328 |
+
):
|
| 329 |
+
context_length = int("".join(context_length))
|
| 330 |
+
question_length = int("".join(question_length))
|
| 331 |
+
input_ids_old = [1 for _ in range(context_length)]
|
| 332 |
+
input_ids_new = [1 for _ in range(question_length)]
|
| 333 |
+
time_lis = []
|
| 334 |
+
print("-" * 80)
|
| 335 |
+
print(
|
| 336 |
+
"History length: {} ; Question length: {}".format(
|
| 337 |
+
context_length, question_length
|
| 338 |
+
)
|
| 339 |
+
)
|
| 340 |
+
with torch.inference_mode():
|
| 341 |
+
for i in range(10): # Run ten times and get the average value
|
| 342 |
+
# history rounds
|
| 343 |
+
start_pos = 0
|
| 344 |
+
if context_length > question_length:
|
| 345 |
+
inputs = torch.as_tensor([input_ids_old], device=device)
|
| 346 |
+
out = model(
|
| 347 |
+
inputs,
|
| 348 |
+
start_pos=start_pos,
|
| 349 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 350 |
+
quant=args.precision in ["W4A16"],
|
| 351 |
+
)
|
| 352 |
+
start_pos += inputs.shape[1]
|
| 353 |
+
|
| 354 |
+
# the present round
|
| 355 |
+
torch.cuda.synchronize()
|
| 356 |
+
t_st = time.time()
|
| 357 |
+
inputs = torch.as_tensor([input_ids_new], device=device)
|
| 358 |
+
out = model(
|
| 359 |
+
inputs,
|
| 360 |
+
start_pos=start_pos,
|
| 361 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 362 |
+
quant=args.precision in ["W4A16"],
|
| 363 |
+
)
|
| 364 |
+
start_pos += inputs.shape[1]
|
| 365 |
+
torch.cuda.synchronize()
|
| 366 |
+
t_ed = time.time()
|
| 367 |
+
|
| 368 |
+
token = out[:, -1].max(1)[1].unsqueeze(1)
|
| 369 |
+
time_lis.append(t_ed - t_st)
|
| 370 |
+
if args.verbose:
|
| 371 |
+
print(i, t_ed - t_st)
|
| 372 |
+
print(
|
| 373 |
+
f"Time To First Token of this round: {np.mean(time_lis):.5f} s."
|
| 374 |
+
)
|
| 375 |
+
print("-" * 80)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
if __name__ == "__main__":
|
| 379 |
+
main()
|
llm-awq/tinychat/demo.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
|
| 7 |
+
from attributedict.collections import AttributeDict
|
| 8 |
+
from tinychat.stream_generators import StreamGenerator
|
| 9 |
+
import tinychat.utils.constants
|
| 10 |
+
from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
|
| 11 |
+
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
|
| 12 |
+
from tinychat.utils.tune import device_warmup, tune_all_wqlinears
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 17 |
+
|
| 18 |
+
# opt_params in TinyLLMEngine
|
| 19 |
+
gen_params = AttributeDict(
|
| 20 |
+
[
|
| 21 |
+
("seed", -1), # RNG seed
|
| 22 |
+
("n_threads", 1), # TODO: fix this
|
| 23 |
+
("n_predict", 512), # new tokens to predict
|
| 24 |
+
("n_parts", -1), # amount of model parts (-1: determine from model dimensions)
|
| 25 |
+
("n_ctx", 512), # context size
|
| 26 |
+
("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS)
|
| 27 |
+
("n_keep", 0), # number of tokens to keep from initial prompt
|
| 28 |
+
("n_vocab", 50272), # vocabulary size
|
| 29 |
+
# sampling parameters
|
| 30 |
+
("logit_bias", dict()), # logit bias for specific tokens: <int, float>
|
| 31 |
+
("top_k", 40), # <= 0 to use vocab size
|
| 32 |
+
("top_p", 0.95), # 1.0 = disabled
|
| 33 |
+
("tfs_z", 1.00), # 1.0 = disabled
|
| 34 |
+
("typical_p", 1.00), # 1.0 = disabled
|
| 35 |
+
("temp", 0.70), # 1.0 = disabled
|
| 36 |
+
("repeat_penalty", 1.10), # 1.0 = disabled
|
| 37 |
+
(
|
| 38 |
+
"repeat_last_n",
|
| 39 |
+
64,
|
| 40 |
+
), # last n tokens to penalize (0 = disable penalty, -1 = context size)
|
| 41 |
+
("frequency_penalty", 0.00), # 0.0 = disabled
|
| 42 |
+
("presence_penalty", 0.00), # 0.0 = disabled
|
| 43 |
+
("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
| 44 |
+
("mirostat_tau", 5.00), # target entropy
|
| 45 |
+
("mirostat_eta", 0.10), # learning rate
|
| 46 |
+
]
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def stream_output(output_stream):
|
| 51 |
+
print(f"ASSISTANT: ", end="", flush=True)
|
| 52 |
+
pre = 0
|
| 53 |
+
for outputs in output_stream:
|
| 54 |
+
output_text = outputs["text"]
|
| 55 |
+
output_text = output_text.strip().split(" ")
|
| 56 |
+
now = len(output_text) - 1
|
| 57 |
+
if now > pre:
|
| 58 |
+
print(" ".join(output_text[pre:now]), end=" ", flush=True)
|
| 59 |
+
pre = now
|
| 60 |
+
print(" ".join(output_text[pre:]), flush=True)
|
| 61 |
+
if "timing" in outputs and outputs["timing"] is not None:
|
| 62 |
+
timing = outputs["timing"]
|
| 63 |
+
context_tokens = timing["context_tokens"]
|
| 64 |
+
context_time = timing["context_time"]
|
| 65 |
+
total_tokens = timing["total_tokens"]
|
| 66 |
+
generation_time_list = timing["generation_time_list"]
|
| 67 |
+
generation_tokens = len(generation_time_list)
|
| 68 |
+
average_speed = (context_time + np.sum(generation_time_list)) / (
|
| 69 |
+
context_tokens + generation_tokens
|
| 70 |
+
)
|
| 71 |
+
print("=" * 50)
|
| 72 |
+
print("Speed of Inference")
|
| 73 |
+
print("-" * 50)
|
| 74 |
+
print(f"TTFT : { context_time:.3f} s for {context_tokens} tokens")
|
| 75 |
+
print(
|
| 76 |
+
f"Speed of Generation : {np.average(generation_time_list)*1000:.2f} ms/token"
|
| 77 |
+
)
|
| 78 |
+
print("=" * 50)
|
| 79 |
+
return " ".join(output_text), total_tokens
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
parser = argparse.ArgumentParser()
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--model_type", type=str, default="LLaMa", help="type of the model"
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--dtype", type=str, default="float16", choices=["float16", "bfloat16"]
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--model_path",
|
| 92 |
+
type=str,
|
| 93 |
+
help="path to the model",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--precision", type=str, default="W4A16", help="compute precision"
|
| 97 |
+
)
|
| 98 |
+
parser.add_argument("--device", type=str, default="cuda:0")
|
| 99 |
+
parser.add_argument("--q_group_size", type=int, default=128)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--load_quant",
|
| 102 |
+
type=str,
|
| 103 |
+
help="path to the pre-quanted 4-bit weights",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--max_seq_len",
|
| 107 |
+
type=int,
|
| 108 |
+
default=2048,
|
| 109 |
+
help="maximum sequence length for kv cache",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument(
|
| 112 |
+
"--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--mem_efficient_load",
|
| 116 |
+
action="store_true",
|
| 117 |
+
help="enable mem_efficient_load mod",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--single_round",
|
| 121 |
+
action="store_true",
|
| 122 |
+
help="whether to memorize previous conversations",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--flash_attn",
|
| 126 |
+
action="store_true",
|
| 127 |
+
help="whether to use flash attention",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--chunk_prefilling",
|
| 131 |
+
action="store_true",
|
| 132 |
+
help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
args = parser.parse_args()
|
| 136 |
+
assert args.model_type.lower() in [
|
| 137 |
+
"llama",
|
| 138 |
+
"falcon",
|
| 139 |
+
"mpt",
|
| 140 |
+
"qwen",
|
| 141 |
+
], "We only support llama & falcon & mpt now"
|
| 142 |
+
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
|
| 143 |
+
|
| 144 |
+
gen_params.n_predict = 1024
|
| 145 |
+
gen_params.n_vocab = 32000
|
| 146 |
+
tinychat.utils.constants.max_batch_size = args.max_batch_size
|
| 147 |
+
tinychat.utils.constants.max_seq_len = args.max_seq_len
|
| 148 |
+
tinychat.utils.constants.mem_efficient_load = args.mem_efficient_load
|
| 149 |
+
if tinychat.utils.constants.mem_efficient_load:
|
| 150 |
+
print("=" * 80)
|
| 151 |
+
print(
|
| 152 |
+
"[Info] You have activated mem_efficient_load mode.\n Less on-chip memory will be consumed when loading the model.\n However, the loading process will take more time."
|
| 153 |
+
)
|
| 154 |
+
print("=" * 80)
|
| 155 |
+
# TODO (Haotian): a more elegant implementation here.
|
| 156 |
+
# We need to update these global variables before models use them.
|
| 157 |
+
from tinychat.models import (
|
| 158 |
+
FalconForCausalLM,
|
| 159 |
+
LlamaForCausalLM,
|
| 160 |
+
MPTForCausalLM,
|
| 161 |
+
Qwen2ForCausalLM,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
def skip(*args, **kwargs):
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
torch.nn.init.kaiming_uniform_ = skip
|
| 168 |
+
torch.nn.init.kaiming_normal_ = skip
|
| 169 |
+
torch.nn.init.uniform_ = skip
|
| 170 |
+
torch.nn.init.normal_ = skip
|
| 171 |
+
|
| 172 |
+
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
| 173 |
+
if "mpt" in config.__class__.__name__.lower():
|
| 174 |
+
# config.init_device="meta"
|
| 175 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 176 |
+
config.tokenizer_name, trust_remote_code=True
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 180 |
+
args.model_path, use_fast=False, trust_remote_code=True
|
| 181 |
+
)
|
| 182 |
+
torch_dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
|
| 183 |
+
modeling_utils._init_weights = False
|
| 184 |
+
torch.set_default_dtype(torch_dtype)
|
| 185 |
+
|
| 186 |
+
model_type_dict = {
|
| 187 |
+
"llama": LlamaForCausalLM,
|
| 188 |
+
"falcon": FalconForCausalLM,
|
| 189 |
+
"mpt": MPTForCausalLM,
|
| 190 |
+
"qwen": Qwen2ForCausalLM,
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
if args.precision == "W4A16":
|
| 194 |
+
if args.model_type.lower() == "llama":
|
| 195 |
+
model = model_type_dict["llama"](config).to(torch_dtype)
|
| 196 |
+
model = load_awq_llama_fast(
|
| 197 |
+
model, args.load_quant, 4, args.q_group_size, args.device
|
| 198 |
+
)
|
| 199 |
+
elif args.model_type.lower() == "qwen":
|
| 200 |
+
model = model_type_dict["qwen"](config).to(torch_dtype)
|
| 201 |
+
model = load_awq_llama_fast(
|
| 202 |
+
model, args.load_quant, 4, args.q_group_size, args.device
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
model = model_type_dict[args.model_type.lower()](config).to(torch_dtype)
|
| 206 |
+
model = load_awq_model(
|
| 207 |
+
model, args.load_quant, 4, args.q_group_size, args.device
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
loaded_model = AutoModelForCausalLM.from_pretrained(
|
| 211 |
+
args.model_path,
|
| 212 |
+
config=config,
|
| 213 |
+
torch_dtype=torch_dtype,
|
| 214 |
+
trust_remote_code=True,
|
| 215 |
+
)
|
| 216 |
+
model = (
|
| 217 |
+
model_type_dict[args.model_type.lower()](config)
|
| 218 |
+
.to(torch_dtype)
|
| 219 |
+
.to(args.device)
|
| 220 |
+
)
|
| 221 |
+
model.load_state_dict(loaded_model.state_dict())
|
| 222 |
+
# device warm up
|
| 223 |
+
device_warmup(args.device)
|
| 224 |
+
|
| 225 |
+
# autotune split_k_iters
|
| 226 |
+
# tune_all_wqlinears(model)
|
| 227 |
+
|
| 228 |
+
# TODO (Haotian): Verify if the StreamGenerator still works for the unmodified falcon impl.
|
| 229 |
+
stream_generator = StreamGenerator
|
| 230 |
+
|
| 231 |
+
# Optimize AWQ quantized model
|
| 232 |
+
if args.precision == "W4A16" and (
|
| 233 |
+
args.model_type.lower() == "llama" or args.model_type.lower() == "qwen"
|
| 234 |
+
):
|
| 235 |
+
from tinychat.modules import make_quant_norm, make_quant_attn
|
| 236 |
+
|
| 237 |
+
if args.flash_attn:
|
| 238 |
+
make_quant_attn(model, args.device, args.flash_attn)
|
| 239 |
+
else:
|
| 240 |
+
make_quant_attn(model, args.device)
|
| 241 |
+
make_quant_norm(model)
|
| 242 |
+
model(
|
| 243 |
+
torch.randint(0, 1000, (1, 512), dtype=torch.int, device="cuda:0"),
|
| 244 |
+
0,
|
| 245 |
+
quant=args.precision == "W4A16",
|
| 246 |
+
)
|
| 247 |
+
if args.max_seq_len <= 1024:
|
| 248 |
+
short_prompt = True
|
| 249 |
+
else:
|
| 250 |
+
short_prompt = False
|
| 251 |
+
model_prompter = get_prompter(args.model_type, args.model_path, short_prompt)
|
| 252 |
+
stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
|
| 253 |
+
count = 0
|
| 254 |
+
start_pos = 0
|
| 255 |
+
print("=" * 50)
|
| 256 |
+
while True:
|
| 257 |
+
# Get input from the user
|
| 258 |
+
input_prompt = input("USER: ")
|
| 259 |
+
if input_prompt == "":
|
| 260 |
+
print("EXIT...")
|
| 261 |
+
break
|
| 262 |
+
model_prompter.insert_prompt(input_prompt)
|
| 263 |
+
output_stream = stream_generator(
|
| 264 |
+
model,
|
| 265 |
+
tokenizer,
|
| 266 |
+
model_prompter.model_input,
|
| 267 |
+
start_pos,
|
| 268 |
+
gen_params,
|
| 269 |
+
device=args.device,
|
| 270 |
+
stop_token_ids=stop_token_ids,
|
| 271 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 272 |
+
quant_llm=args.precision == "W4A16",
|
| 273 |
+
)
|
| 274 |
+
outputs, total_tokens = stream_output(output_stream)
|
| 275 |
+
if args.chunk_prefilling:
|
| 276 |
+
start_pos += total_tokens
|
| 277 |
+
else:
|
| 278 |
+
start_pos = 0
|
| 279 |
+
if (
|
| 280 |
+
args.single_round is not True and args.max_seq_len > 512
|
| 281 |
+
): # Only memorize previous conversations when kv_cache_size > 512
|
| 282 |
+
model_prompter.update_template(outputs, args.chunk_prefilling)
|
| 283 |
+
count += 1
|
llm-awq/tinychat/internvl_benchmark.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
from termcolor import colored
|
| 4 |
+
|
| 5 |
+
import llava
|
| 6 |
+
from llava import conversation as clib
|
| 7 |
+
from llava.media import Image, Video
|
| 8 |
+
import torch
|
| 9 |
+
from awq.quantize import fake_quant
|
| 10 |
+
from awq.quantize.quantizer import real_quantize_model_weight
|
| 11 |
+
from transformers import AutoConfig
|
| 12 |
+
import tinychat
|
| 13 |
+
|
| 14 |
+
from torchao.quantization import quantize_, Int4WeightOnlyConfig
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 19 |
+
|
| 20 |
+
def skip(*args, **kwargs):
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main() -> None:
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--model-path",
|
| 28 |
+
"-m",
|
| 29 |
+
type=str,
|
| 30 |
+
default="/home/yuming/workspace/qwen/models/nvila-internal-8b-v1",
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--quant_path",
|
| 34 |
+
type=str,
|
| 35 |
+
default="/PATH/TO/QUANT",
|
| 36 |
+
)
|
| 37 |
+
# parser.add_argument("--model-path", "-m", type=str, default="Efficient-Large-Model/J65")
|
| 38 |
+
# parser.add_argument("--quant_path", type=str, default="/home/yuming/workspace/qwen/models/J65/llm/vila2-J65-w4-g128-awq-v2.pt")
|
| 39 |
+
parser.add_argument("--conv-mode", "-c", type=str, default="auto")
|
| 40 |
+
# parser.add_argument("--media", type=str, default="/home/yuming/workspace/space_woaudio.mp4")
|
| 41 |
+
parser.add_argument("--device", type=str, default="cuda:0")
|
| 42 |
+
parser.add_argument(
|
| 43 |
+
"--act_scale_path",
|
| 44 |
+
type=str,
|
| 45 |
+
default="/PATH/TO/SCALE",
|
| 46 |
+
)
|
| 47 |
+
# quantization options
|
| 48 |
+
parser.add_argument("--quant_llm", action="store_true")
|
| 49 |
+
parser.add_argument("--quant_VT", action="store_true")
|
| 50 |
+
# Four basic tasks
|
| 51 |
+
parser.add_argument("--video_caption", action="store_true")
|
| 52 |
+
parser.add_argument("--video_QA", action="store_true")
|
| 53 |
+
parser.add_argument("--image_caption", action="store_true")
|
| 54 |
+
parser.add_argument("--image_QA", action="store_true")
|
| 55 |
+
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--all",
|
| 58 |
+
action="store_true",
|
| 59 |
+
help="Whether to quantize visiontower and llm, and test all 4 tasks",
|
| 60 |
+
)
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--fakequant_VT",
|
| 63 |
+
action="store_true",
|
| 64 |
+
help="Use fake quant or real quant for VisionTower",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--all_task", action="store_true", help="Whether to test all 4 tasks"
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--video_path", type=str, default="../figures/nvila_demo_video.mp4"
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument("--image_path", type=str, default="../figures/vila-logo.jpg")
|
| 73 |
+
parser.add_argument("--max_seq_len", type=int, default=8192)
|
| 74 |
+
args = parser.parse_args()
|
| 75 |
+
|
| 76 |
+
torch.nn.init.kaiming_uniform_ = skip
|
| 77 |
+
torch.nn.init.kaiming_normal_ = skip
|
| 78 |
+
torch.nn.init.uniform_ = skip
|
| 79 |
+
torch.nn.init.normal_ = skip
|
| 80 |
+
import tinychat.utils.constants
|
| 81 |
+
|
| 82 |
+
tinychat.utils.constants.max_seq_len = args.max_seq_len
|
| 83 |
+
from transformers import modeling_utils
|
| 84 |
+
|
| 85 |
+
modeling_utils._init_weights = False
|
| 86 |
+
|
| 87 |
+
# Load model
|
| 88 |
+
from tinychat.models import InternVL3
|
| 89 |
+
|
| 90 |
+
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
| 91 |
+
config.resume_path = args.model_path
|
| 92 |
+
model = InternVL3(config).half()
|
| 93 |
+
model.language_model = model.language_model.eval()
|
| 94 |
+
if args.quant_llm or args.all:
|
| 95 |
+
from tinychat.modules import (
|
| 96 |
+
make_quant_norm,
|
| 97 |
+
make_quant_attn,
|
| 98 |
+
make_fused_mlp,
|
| 99 |
+
make_fused_vision_attn,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
real_quantize_model_weight(
|
| 103 |
+
model.language_model,
|
| 104 |
+
w_bit=4,
|
| 105 |
+
q_config=dict(q_group_size=128, zero_point=True),
|
| 106 |
+
init_only=True,
|
| 107 |
+
)
|
| 108 |
+
make_quant_attn(model.language_model, "cuda", True)
|
| 109 |
+
make_quant_norm(model.language_model)
|
| 110 |
+
make_fused_mlp(model.language_model)
|
| 111 |
+
model = model.to("cuda")
|
| 112 |
+
model = model.to(args.device)
|
| 113 |
+
if args.quant_VT or args.all:
|
| 114 |
+
from tinychat.modules import QuantInternVisionEncoder
|
| 115 |
+
model.vision_model.encoder = QuantInternVisionEncoder(model.vision_model.encoder)
|
| 116 |
+
model.vision_model.encoder = torch.compile(model.vision_model.encoder)
|
| 117 |
+
|
| 118 |
+
model = model.cuda().eval()
|
| 119 |
+
|
| 120 |
+
if args.video_caption or args.all or args.all_task:
|
| 121 |
+
print("-" * 80)
|
| 122 |
+
print("Video_Caption")
|
| 123 |
+
# Set conversation mode
|
| 124 |
+
clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
|
| 125 |
+
media = Video(args.video_path)
|
| 126 |
+
text = "Elaborate on the visual and narrative elements of the video in detail." # + "1"+" 1"*3069
|
| 127 |
+
prompt = [media, text]
|
| 128 |
+
# Generate response
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
response = model.benchmark(prompt, args.quant_llm)
|
| 131 |
+
if args.video_QA or args.all or args.all_task:
|
| 132 |
+
print("-" * 80)
|
| 133 |
+
print("Video_QA")
|
| 134 |
+
# Set conversation mode
|
| 135 |
+
clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
|
| 136 |
+
media = Video(args.video_path)
|
| 137 |
+
text = "What is the person in the video doing? Select the option that best describes their action: A. Folding paper B. Playing computer games C. Sleeping." # + "1"+" 1"*3069
|
| 138 |
+
prompt = [media, text]
|
| 139 |
+
# Generate response
|
| 140 |
+
with torch.no_grad():
|
| 141 |
+
response = model.benchmark(prompt, args.quant_llm)
|
| 142 |
+
if args.image_caption or args.all or args.all_task:
|
| 143 |
+
print("-" * 80)
|
| 144 |
+
print("Image_Caption")
|
| 145 |
+
# Set conversation mode
|
| 146 |
+
clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
|
| 147 |
+
media = Image(args.image_path)
|
| 148 |
+
text = "Describe the image in detail."
|
| 149 |
+
prompt = [media, text]
|
| 150 |
+
# Generate response
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
response = model.benchmark(prompt, args.quant_llm)
|
| 153 |
+
if args.image_QA or args.all or args.all_task:
|
| 154 |
+
print("-" * 80)
|
| 155 |
+
print("Image_QA")
|
| 156 |
+
# Set conversation mode
|
| 157 |
+
clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
|
| 158 |
+
media = Image(args.image_path)
|
| 159 |
+
text = "What does the text in the image say? Choose the option that best matches: A. VILA B. AIIV C. ALIV."
|
| 160 |
+
prompt = [media, text]
|
| 161 |
+
# Generate response
|
| 162 |
+
with torch.no_grad():
|
| 163 |
+
response = model.benchmark(prompt, args.quant_llm)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
main()
|
llm-awq/tinychat/split_ckpt.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def split(
|
| 8 |
+
ckpt_path: str,
|
| 9 |
+
out_folder_path: str,
|
| 10 |
+
):
|
| 11 |
+
os.system(f"mkdir -p {out_folder_path}")
|
| 12 |
+
ckpt = torch.load(ckpt_path)
|
| 13 |
+
count = 0
|
| 14 |
+
for key, value in ckpt.items():
|
| 15 |
+
output_dict = {key: value}
|
| 16 |
+
output_name = out_folder_path + "/" + key + ".pt"
|
| 17 |
+
torch.save(output_dict, output_name)
|
| 18 |
+
count += 1
|
| 19 |
+
print(f"Finished splitting the original checkpoint into {count} shards.")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def ckpt_folder_reader(ckpt_folder_path: str):
|
| 23 |
+
file_list = [f for f in os.listdir(ckpt_folder_path) if f.endswith(".pt")]
|
| 24 |
+
for ckpt in file_list:
|
| 25 |
+
print(ckpt)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if __name__ == "__main__":
|
| 29 |
+
parser = argparse.ArgumentParser()
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--input_path",
|
| 32 |
+
type=str,
|
| 33 |
+
default=None,
|
| 34 |
+
help="Path to the original checkpoint (ends with *.pt)",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--output_path",
|
| 38 |
+
type=str,
|
| 39 |
+
default=None,
|
| 40 |
+
help="Folder to store the splitted checkpoint shards",
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
assert (
|
| 45 |
+
args.input_path is not None
|
| 46 |
+
), "Please specify the path to the original checkpoint."
|
| 47 |
+
if args.output_path is None:
|
| 48 |
+
suffix = r"\.pt$"
|
| 49 |
+
args.output_path = re.sub(suffix, "", args.input_path)
|
| 50 |
+
|
| 51 |
+
split(args.input_path, args.output_path)
|
llm-awq/tinychat/vila15_demo.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from transformers import AutoConfig, AutoTokenizer
|
| 8 |
+
from accelerate import load_checkpoint_and_dispatch
|
| 9 |
+
|
| 10 |
+
from tinychat.utils.tune import (
|
| 11 |
+
device_warmup,
|
| 12 |
+
tune_all_wqlinears,
|
| 13 |
+
tune_llava_patch_embedding,
|
| 14 |
+
)
|
| 15 |
+
from tinychat.utils.prompt_templates import (
|
| 16 |
+
get_prompter,
|
| 17 |
+
get_stop_token_ids,
|
| 18 |
+
get_image_token,
|
| 19 |
+
)
|
| 20 |
+
from tinychat.utils.llava_image_processing import (
|
| 21 |
+
process_images,
|
| 22 |
+
load_images,
|
| 23 |
+
vis_images,
|
| 24 |
+
)
|
| 25 |
+
import tinychat.utils.constants
|
| 26 |
+
|
| 27 |
+
# from tinychat.models.llava_llama import LlavaLlamaForCausalLM
|
| 28 |
+
from tinychat.models.vila_llama import VilaLlamaForCausalLM
|
| 29 |
+
from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator
|
| 30 |
+
from tinychat.utils.conversation_utils import gen_params, stream_output, TimeStats
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def image_parser(args):
|
| 38 |
+
out = args.image_file.split(args.im_sep)
|
| 39 |
+
return out
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def skip(*args, **kwargs):
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main(args):
|
| 47 |
+
# Accelerate model initialization
|
| 48 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
| 49 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
| 50 |
+
torch.nn.init.kaiming_uniform_ = skip
|
| 51 |
+
torch.nn.init.kaiming_normal_ = skip
|
| 52 |
+
torch.nn.init.uniform_ = skip
|
| 53 |
+
torch.nn.init.normal_ = skip
|
| 54 |
+
|
| 55 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 56 |
+
os.path.join(args.model_path, "llm"), use_fast=False
|
| 57 |
+
)
|
| 58 |
+
tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
|
| 59 |
+
tokenizer.convert_tokens_to_ids(
|
| 60 |
+
[tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
|
| 61 |
+
)[0]
|
| 62 |
+
)
|
| 63 |
+
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
|
| 64 |
+
model = VilaLlamaForCausalLM(config).half()
|
| 65 |
+
tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
|
| 66 |
+
tokenizer.convert_tokens_to_ids(
|
| 67 |
+
[tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
|
| 68 |
+
)[0]
|
| 69 |
+
)
|
| 70 |
+
vision_tower = model.get_vision_tower()
|
| 71 |
+
# if not vision_tower.is_loaded:
|
| 72 |
+
# vision_tower.load_model()
|
| 73 |
+
image_processor = vision_tower.image_processor
|
| 74 |
+
# vision_tower = vision_tower.half()
|
| 75 |
+
|
| 76 |
+
if args.precision == "W16A16":
|
| 77 |
+
pbar = tqdm(range(1))
|
| 78 |
+
pbar.set_description("Loading checkpoint shards")
|
| 79 |
+
for i in pbar:
|
| 80 |
+
model.llm = load_checkpoint_and_dispatch(
|
| 81 |
+
model.llm,
|
| 82 |
+
os.path.join(args.model_path, "llm"),
|
| 83 |
+
no_split_module_classes=[
|
| 84 |
+
"OPTDecoderLayer",
|
| 85 |
+
"LlamaDecoderLayer",
|
| 86 |
+
"BloomBlock",
|
| 87 |
+
"MPTBlock",
|
| 88 |
+
"DecoderLayer",
|
| 89 |
+
"CLIPEncoderLayer",
|
| 90 |
+
],
|
| 91 |
+
).to(args.device)
|
| 92 |
+
model = model.to(args.device)
|
| 93 |
+
|
| 94 |
+
elif args.precision == "W4A16":
|
| 95 |
+
from tinychat.utils.load_quant import load_awq_model
|
| 96 |
+
|
| 97 |
+
model.llm = load_awq_model(model.llm, args.quant_path, 4, 128, args.device)
|
| 98 |
+
from tinychat.modules import (
|
| 99 |
+
make_quant_norm,
|
| 100 |
+
make_quant_attn,
|
| 101 |
+
make_fused_mlp,
|
| 102 |
+
make_fused_vision_attn,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if args.flash_attn:
|
| 106 |
+
print("Enabling flash-attention!")
|
| 107 |
+
make_quant_attn(model.llm, args.device, 1)
|
| 108 |
+
else:
|
| 109 |
+
print("Disabling flash-attention!")
|
| 110 |
+
make_quant_attn(model.llm, args.device)
|
| 111 |
+
make_quant_norm(model.llm)
|
| 112 |
+
# make_fused_mlp(model)
|
| 113 |
+
# make_fused_vision_attn(model,args.device)
|
| 114 |
+
model = model.to(args.device)
|
| 115 |
+
|
| 116 |
+
else:
|
| 117 |
+
raise NotImplementedError(f"Precision {args.precision} is not supported.")
|
| 118 |
+
|
| 119 |
+
image_files = image_parser(args)
|
| 120 |
+
image_num = len(image_files)
|
| 121 |
+
images = load_images(image_files)
|
| 122 |
+
if args.vis_image:
|
| 123 |
+
print("=" * 50)
|
| 124 |
+
print("Input Image:")
|
| 125 |
+
vis_images(image_files)
|
| 126 |
+
# Similar operation in model_worker.py
|
| 127 |
+
image_tensor = process_images(images, image_processor, model.config)
|
| 128 |
+
if type(image_tensor) is list:
|
| 129 |
+
image_tensor = [
|
| 130 |
+
image.to(args.device, dtype=torch.float16) for image in image_tensor
|
| 131 |
+
]
|
| 132 |
+
else:
|
| 133 |
+
image_tensor = image_tensor.to(args.device, dtype=torch.float16)
|
| 134 |
+
|
| 135 |
+
device_warmup(args.device)
|
| 136 |
+
tune_llava_patch_embedding(vision_tower, device=args.device)
|
| 137 |
+
|
| 138 |
+
stream_generator = LlavaStreamGenerator
|
| 139 |
+
|
| 140 |
+
if args.max_seq_len <= 1024:
|
| 141 |
+
short_prompt = True
|
| 142 |
+
else:
|
| 143 |
+
short_prompt = False
|
| 144 |
+
model_prompter = get_prompter(
|
| 145 |
+
args.model_type, args.model_path, short_prompt, args.empty_prompt
|
| 146 |
+
)
|
| 147 |
+
stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
|
| 148 |
+
count = 0
|
| 149 |
+
|
| 150 |
+
if args.empty_prompt:
|
| 151 |
+
input_indicator = "Input: "
|
| 152 |
+
output_indicator = "Generated: "
|
| 153 |
+
else:
|
| 154 |
+
input_indicator = "USER: "
|
| 155 |
+
output_indicator = "ASSISTANT: "
|
| 156 |
+
|
| 157 |
+
model.eval()
|
| 158 |
+
time_stats = TimeStats()
|
| 159 |
+
start_pos = 0
|
| 160 |
+
while True:
|
| 161 |
+
# Get input from the user
|
| 162 |
+
print("=" * 50)
|
| 163 |
+
input_prompt = input(input_indicator)
|
| 164 |
+
print("-" * 50)
|
| 165 |
+
if input_prompt == "":
|
| 166 |
+
print("EXIT...")
|
| 167 |
+
time_stats.show()
|
| 168 |
+
break
|
| 169 |
+
if count == 0: # Insert image here
|
| 170 |
+
image_token = get_image_token(model, args.model_path)
|
| 171 |
+
image_token_holder = (
|
| 172 |
+
tinychat.utils.constants.LLAVA_DEFAULT_IM_TOKEN_PLACE_HOLDER
|
| 173 |
+
)
|
| 174 |
+
im_token_count = input_prompt.count(image_token_holder)
|
| 175 |
+
if im_token_count == 0:
|
| 176 |
+
model_prompter.insert_prompt(image_token * image_num + input_prompt)
|
| 177 |
+
else:
|
| 178 |
+
assert im_token_count == image_num
|
| 179 |
+
input_prompt = input_prompt.replace(image_token_holder, image_token)
|
| 180 |
+
model_prompter.insert_prompt(input_prompt)
|
| 181 |
+
else:
|
| 182 |
+
model_prompter.insert_prompt(input_prompt)
|
| 183 |
+
if args.chunk_prefilling:
|
| 184 |
+
image_tensor = None # Can insert more images in future
|
| 185 |
+
output_stream = stream_generator(
|
| 186 |
+
model,
|
| 187 |
+
tokenizer,
|
| 188 |
+
model_prompter.model_input,
|
| 189 |
+
start_pos,
|
| 190 |
+
gen_params,
|
| 191 |
+
device=args.device,
|
| 192 |
+
stop_token_ids=stop_token_ids,
|
| 193 |
+
image_tensor=image_tensor,
|
| 194 |
+
chunk_prefilling=args.chunk_prefilling,
|
| 195 |
+
)
|
| 196 |
+
print(output_indicator, end="", flush=True)
|
| 197 |
+
if count == 0:
|
| 198 |
+
outputs, total_tokens = stream_output(output_stream, time_stats)
|
| 199 |
+
else:
|
| 200 |
+
outputs, total_tokens = stream_output(output_stream)
|
| 201 |
+
if args.chunk_prefilling:
|
| 202 |
+
start_pos += total_tokens
|
| 203 |
+
if (
|
| 204 |
+
args.single_round is not True and args.max_seq_len > 512
|
| 205 |
+
): # Only memorize previous conversations when kv_cache_size > 512
|
| 206 |
+
model_prompter.update_template(outputs, args.chunk_prefilling)
|
| 207 |
+
count += 1
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
|
| 211 |
+
parser = argparse.ArgumentParser()
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--model_type", type=str, default="LLaMa", help="type of the model"
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--model-path", type=str, default="/data/llm/checkpoints/llava/llava-v1.5-7b"
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--quant-path",
|
| 220 |
+
type=str,
|
| 221 |
+
default="/data/llm/checkpoints/llava/llava-v1.5-7b-w4-g128-awq.pt",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--precision", type=str, default="W4A16", help="compute precision"
|
| 225 |
+
)
|
| 226 |
+
parser.add_argument(
|
| 227 |
+
"--image-file",
|
| 228 |
+
type=str,
|
| 229 |
+
default="https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg",
|
| 230 |
+
)
|
| 231 |
+
parser.add_argument(
|
| 232 |
+
"--im-sep",
|
| 233 |
+
type=str,
|
| 234 |
+
default=",",
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 237 |
+
parser.add_argument("--max_seq_len", type=int, default=2048)
|
| 238 |
+
parser.add_argument(
|
| 239 |
+
"--single_round",
|
| 240 |
+
action="store_true",
|
| 241 |
+
help="whether to memorize previous conversations",
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--vis-image",
|
| 245 |
+
action="store_true",
|
| 246 |
+
help="whether to visualize the image while chatting",
|
| 247 |
+
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--empty-prompt",
|
| 250 |
+
action="store_true",
|
| 251 |
+
help="whether to use empty prompt template",
|
| 252 |
+
)
|
| 253 |
+
parser.add_argument(
|
| 254 |
+
"--flash_attn",
|
| 255 |
+
action="store_true",
|
| 256 |
+
help="whether to use flash attention",
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--chunk_prefilling",
|
| 260 |
+
action="store_true",
|
| 261 |
+
help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
|
| 262 |
+
)
|
| 263 |
+
args = parser.parse_args()
|
| 264 |
+
main(args)
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_sot.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: sot
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_sot_prompt_2
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_yor.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: yor
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_yor_prompt_2
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_ibo.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: ibo
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_ibo_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_kin.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: kin
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_kin_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sna.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: sna
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_sna_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sot.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: sot
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_sot_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_xho.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: xho
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_xho_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tag:
|
| 2 |
+
- afrimgsm_tasks
|
| 3 |
+
- afrimgsm_tasks_prompt_3
|
| 4 |
+
dataset_path: masakhane/afrimgsm
|
| 5 |
+
output_type: generate_until
|
| 6 |
+
test_split: test
|
| 7 |
+
doc_to_target: '{% if answer is not none %}{{answer[21:]}}{% else %}{{answer_number|string}}{% endif %}'
|
| 8 |
+
doc_to_text: "Solve the following math question \n\nQuestion: {{question}} \nAnswer: "
|
| 9 |
+
target_delimiter: ""
|
| 10 |
+
generation_kwargs:
|
| 11 |
+
do_sample: false
|
| 12 |
+
until:
|
| 13 |
+
- 'Question:'
|
| 14 |
+
- </s>
|
| 15 |
+
- <|im_end|>
|
| 16 |
+
filter_list:
|
| 17 |
+
- name: remove_whitespace
|
| 18 |
+
filter:
|
| 19 |
+
- function: remove_whitespace
|
| 20 |
+
- function: take_first
|
| 21 |
+
- filter:
|
| 22 |
+
- function: regex
|
| 23 |
+
group_select: -1
|
| 24 |
+
regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
|
| 25 |
+
- function: take_first
|
| 26 |
+
name: flexible-extract
|
| 27 |
+
metric_list:
|
| 28 |
+
- metric: exact_match
|
| 29 |
+
aggregation: mean
|
| 30 |
+
higher_is_better: true
|
| 31 |
+
ignore_case: true
|
| 32 |
+
ignore_punctuation: true
|
| 33 |
+
metadata:
|
| 34 |
+
version: 2.0
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yor.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: yor
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_yor_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_zul.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: zul
|
| 3 |
+
include: afrimgsm_yaml
|
| 4 |
+
task: afrimgsm_zul_prompt_3
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_ibo.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: ibo
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_ibo_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lin.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: lin
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_lin_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lug.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: lug
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_lug_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_orm.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: orm
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_orm_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sna.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: sna
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_sna_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sot.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: sot
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_sot_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_swa.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: swa
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_swa_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_twi.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: twi
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_twi_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_vai.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: vai
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_vai_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_wol.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: wol
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_wol_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_xho.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: xho
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_xho_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_yor.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: yor
|
| 3 |
+
doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
|
| 4 |
+
\ that the response is clear and without any supplementary information. \n\nQuestion:\
|
| 5 |
+
\ {{question}} \nAnswer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_yor_prompt_4
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_amh.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: amh
|
| 3 |
+
doc_to_text: "For mathematical questions provided in Amharic language. Supply the\
|
| 4 |
+
\ accurate numeric answer to the provided question. \n\nQuestion: {{question}} \n\
|
| 5 |
+
Answer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_amh_prompt_5
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_eng.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: eng
|
| 3 |
+
doc_to_text: "For mathematical questions provided in English language. Supply the\
|
| 4 |
+
\ accurate numeric answer to the provided question. \n\nQuestion: {{question}} \n\
|
| 5 |
+
Answer: "
|
| 6 |
+
include: afrimgsm_yaml
|
| 7 |
+
task: afrimgsm_eng_prompt_5
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_ewe.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: ewe
|
| 3 |
+
doc_to_text: "For mathematical questions provided in Ewe language. Supply the accurate\
|
| 4 |
+
\ numeric answer to the provided question. \n\nQuestion: {{question}} \nAnswer: "
|
| 5 |
+
include: afrimgsm_yaml
|
| 6 |
+
task: afrimgsm_ewe_prompt_5
|
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_fra.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by utils.py
|
| 2 |
+
dataset_name: fra
|
| 3 |
+
doc_to_text: "For mathematical questions provided in French language. Supply the accurate\
|
| 4 |
+
\ numeric answer to the provided question. \n\nQuestion: {{question}} \nAnswer: "
|
| 5 |
+
include: afrimgsm_yaml
|
| 6 |
+
task: afrimgsm_fra_prompt_5
|