chen459664 commited on
Commit
075eaa3
·
verified ·
1 Parent(s): 3803ea9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. llm-awq/awq.egg-info/SOURCES.txt +81 -0
  2. llm-awq/awq.egg-info/dependency_links.txt +1 -0
  3. llm-awq/awq.egg-info/requires.txt +16 -0
  4. llm-awq/awq.egg-info/top_level.txt +3 -0
  5. llm-awq/awq/quantize/qmodule.py +235 -0
  6. llm-awq/awq/quantize/w8a8_linear.py +276 -0
  7. llm-awq/awq/utils/lm_eval_adaptor.py +116 -0
  8. llm-awq/awq/utils/utils.py +51 -0
  9. llm-awq/examples/convert_to_hf.py +69 -0
  10. llm-awq/examples/llava_demo.ipynb +0 -0
  11. llm-awq/figures/vila-logo.jpg +0 -0
  12. llm-awq/scripts/codellama_example.sh +25 -0
  13. llm-awq/scripts/llama2_example.sh +25 -0
  14. llm-awq/scripts/llama3_example.sh +25 -0
  15. llm-awq/scripts/llama_example.sh +25 -0
  16. llm-awq/scripts/opt_example.sh +25 -0
  17. llm-awq/scripts/qwen_example.sh +25 -0
  18. llm-awq/scripts/starcoder_example.sh +25 -0
  19. llm-awq/scripts/vicuna_example.sh +25 -0
  20. llm-awq/tinychat/benchmark.py +379 -0
  21. llm-awq/tinychat/demo.py +283 -0
  22. llm-awq/tinychat/internvl_benchmark.py +167 -0
  23. llm-awq/tinychat/split_ckpt.py +51 -0
  24. llm-awq/tinychat/vila15_demo.py +264 -0
  25. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_sot.yaml +4 -0
  26. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_yor.yaml +4 -0
  27. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_ibo.yaml +4 -0
  28. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_kin.yaml +4 -0
  29. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sna.yaml +4 -0
  30. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sot.yaml +4 -0
  31. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_xho.yaml +4 -0
  32. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yaml +34 -0
  33. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yor.yaml +4 -0
  34. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_zul.yaml +4 -0
  35. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_ibo.yaml +7 -0
  36. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lin.yaml +7 -0
  37. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lug.yaml +7 -0
  38. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_orm.yaml +7 -0
  39. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sna.yaml +7 -0
  40. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sot.yaml +7 -0
  41. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_swa.yaml +7 -0
  42. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_twi.yaml +7 -0
  43. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_vai.yaml +7 -0
  44. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_wol.yaml +7 -0
  45. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_xho.yaml +7 -0
  46. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_yor.yaml +7 -0
  47. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_amh.yaml +7 -0
  48. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_eng.yaml +7 -0
  49. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_ewe.yaml +6 -0
  50. lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_fra.yaml +6 -0
llm-awq/awq.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ awq/entry.py
5
+ awq.egg-info/PKG-INFO
6
+ awq.egg-info/SOURCES.txt
7
+ awq.egg-info/dependency_links.txt
8
+ awq.egg-info/requires.txt
9
+ awq.egg-info/top_level.txt
10
+ awq/kernels/setup.py
11
+ awq/kernels/csrc/attention/setup.py
12
+ awq/quantize/__init__.py
13
+ awq/quantize/auto_clip.py
14
+ awq/quantize/auto_scale.py
15
+ awq/quantize/pre_quant.py
16
+ awq/quantize/qmodule.py
17
+ awq/quantize/quantizer.py
18
+ awq/quantize/smooth.py
19
+ awq/quantize/w8a8_linear.py
20
+ awq/utils/__init__.py
21
+ awq/utils/calib_data.py
22
+ awq/utils/lm_eval_adaptor.py
23
+ awq/utils/module.py
24
+ awq/utils/parallel.py
25
+ awq/utils/utils.py
26
+ tinychat/benchmark.py
27
+ tinychat/demo.py
28
+ tinychat/internvl_benchmark.py
29
+ tinychat/internvl_demo.py
30
+ tinychat/nvila_benchmark.py
31
+ tinychat/nvila_demo.py
32
+ tinychat/offline-weight-repacker.py
33
+ tinychat/split_ckpt.py
34
+ tinychat/vila10_demo.py
35
+ tinychat/vila15_demo.py
36
+ tinychat/models/__init__.py
37
+ tinychat/models/falcon.py
38
+ tinychat/models/internvl3.py
39
+ tinychat/models/llama.py
40
+ tinychat/models/llava_llama.py
41
+ tinychat/models/mpt.py
42
+ tinychat/models/nvila_qwen2.py
43
+ tinychat/models/qwen2.py
44
+ tinychat/models/vila_llama.py
45
+ tinychat/models/internvl/configuration_internvl.py
46
+ tinychat/models/internvl/conversation.py
47
+ tinychat/models/internvl/internvit.py
48
+ tinychat/models/internvl/media.py
49
+ tinychat/models/llava_base/llava_arch.py
50
+ tinychat/models/llava_base/multimodal_encoder/builder.py
51
+ tinychat/models/llava_base/multimodal_encoder/clip_encoder.py
52
+ tinychat/models/llava_base/multimodal_projector/builder.py
53
+ tinychat/models/nvila/builder.py
54
+ tinychat/models/nvila/configuration_llava.py
55
+ tinychat/models/nvila/llava_arch.py
56
+ tinychat/modules/__init__.py
57
+ tinychat/modules/fused_attn.py
58
+ tinychat/modules/fused_internencoder.py
59
+ tinychat/modules/fused_mlp.py
60
+ tinychat/modules/fused_norm.py
61
+ tinychat/modules/fused_siglipdecoder.py
62
+ tinychat/modules/fused_vision_attn.py
63
+ tinychat/serve/controller.py
64
+ tinychat/serve/gradio_web_server.py
65
+ tinychat/serve/llava_conv.py
66
+ tinychat/serve/model_worker.py
67
+ tinychat/serve/model_worker_new.py
68
+ tinychat/stream_generators/NVILA_stream_gen.py
69
+ tinychat/stream_generators/__init__.py
70
+ tinychat/stream_generators/internvl_stream_gen.py
71
+ tinychat/stream_generators/llava_stream_gen.py
72
+ tinychat/stream_generators/stream_gen.py
73
+ tinychat/utils/__init__.py
74
+ tinychat/utils/constants.py
75
+ tinychat/utils/conversation_utils.py
76
+ tinychat/utils/input_metadata.py
77
+ tinychat/utils/llava_image_processing.py
78
+ tinychat/utils/load_quant.py
79
+ tinychat/utils/log_utils.py
80
+ tinychat/utils/prompt_templates.py
81
+ tinychat/utils/tune.py
llm-awq/awq.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
llm-awq/awq.egg-info/requires.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.34.2
2
+ sentencepiece
3
+ tokenizers>=0.12.1
4
+ torch==2.3.0
5
+ torchvision==0.18.0
6
+ transformers==4.46.0
7
+ lm_eval==0.3.0
8
+ texttable
9
+ toml
10
+ attributedict
11
+ protobuf
12
+ gradio==3.35.2
13
+ gradio_client==0.2.9
14
+ fastapi
15
+ uvicorn
16
+ pydantic==1.10.19
llm-awq/awq.egg-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ awq
2
+ figures
3
+ tinychat
llm-awq/awq/quantize/qmodule.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import awq_inference_engine # with CUDA kernels
5
+
6
+
7
+ def make_divisible(c, divisor):
8
+ return (c + divisor - 1) // divisor
9
+
10
+
11
+ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
12
+ if group_size >= 128:
13
+ size_multiplier = 1
14
+ elif group_size == 64:
15
+ size_multiplier = 2
16
+ elif group_size == 32:
17
+ size_multiplier = 4
18
+ else:
19
+ raise NotImplementedError
20
+
21
+ base_width = make_divisible(in_features // group_size, pack_num)
22
+ base_width = make_divisible(base_width, size_multiplier) * size_multiplier
23
+ return base_width
24
+
25
+
26
+ def pack_intweight(unpacked_qweight, interleave, kstride):
27
+ # unpacked_qweight: [N, K]
28
+ N = unpacked_qweight.shape[0]
29
+ K = unpacked_qweight.shape[1]
30
+
31
+ Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
32
+ # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
33
+ Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
34
+ Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
35
+
36
+ # reorder each 8 weights for fast dequantization
37
+ # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
38
+ Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
39
+ Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
40
+ Packed_Kernel = Packed_Kernel.reshape(N, K)
41
+
42
+ # interleaving every four rows
43
+ Packed_Kernel = Packed_Kernel.reshape(
44
+ N // interleave, interleave, K // kstride, kstride
45
+ )
46
+ # N // 4, K // 64, 4, 64
47
+ Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
48
+ Packed_Kernel = Packed_Kernel.reshape(
49
+ N // interleave, K // kstride, kstride, interleave
50
+ )
51
+ # Packing -> (N // 4, K // 64, 64)
52
+ Packed_Kernel = (
53
+ Packed_Kernel[..., 0]
54
+ | (Packed_Kernel[..., 1] << 4)
55
+ | (Packed_Kernel[..., 2] << 8)
56
+ | (Packed_Kernel[..., 3] << 12)
57
+ )
58
+ # reshape to (N // 4, K), FP16 format
59
+ Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
60
+ qweight = (
61
+ torch.tensor(Packed_Kernel.astype("int16"))
62
+ .to(unpacked_qweight.device)
63
+ .contiguous()
64
+ )
65
+ return qweight
66
+
67
+
68
+ class ScaledActivation(nn.Module):
69
+ def __init__(self, module, scales):
70
+ super().__init__()
71
+ self.act = module
72
+ self.scales = nn.Parameter(scales.data)
73
+
74
+ def forward(self, x):
75
+ return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
76
+
77
+
78
+ class WQLinear(nn.Module):
79
+ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, dtype=torch.float16):
80
+ super().__init__()
81
+
82
+ if w_bit not in [4]:
83
+ raise NotImplementedError("Only 4-bit are supported for now.")
84
+
85
+ self.in_features = in_features
86
+ self.out_features = out_features
87
+ self.w_bit = w_bit
88
+ self.group_size = group_size if group_size != -1 else in_features
89
+ self.split_k_iters = 8
90
+ self.interleave = 4
91
+ # quick sanity check (make sure aligment)
92
+ assert self.in_features % self.group_size == 0
93
+ assert out_features % (32 // self.w_bit) == 0
94
+ pack_num = 32 // self.w_bit
95
+ int16_pack_num = 16 // self.w_bit
96
+
97
+ assert out_features % (self.interleave) == 0
98
+ self.register_buffer(
99
+ "qweight",
100
+ torch.zeros(
101
+ (
102
+ out_features // self.interleave,
103
+ in_features // int16_pack_num * self.interleave,
104
+ ),
105
+ dtype=torch.int16,
106
+ device=dev,
107
+ ),
108
+ )
109
+ self.register_buffer(
110
+ "scales",
111
+ torch.zeros(
112
+ (
113
+ calculate_zeros_width(in_features, self.group_size) * pack_num,
114
+ out_features,
115
+ ),
116
+ dtype=dtype,
117
+ device=dev,
118
+ ),
119
+ )
120
+ self.register_buffer(
121
+ "scaled_zeros",
122
+ torch.zeros(
123
+ (
124
+ calculate_zeros_width(in_features, self.group_size) * pack_num,
125
+ out_features,
126
+ ),
127
+ dtype=dtype,
128
+ device=dev,
129
+ ),
130
+ )
131
+
132
+ if bias:
133
+ self.register_buffer(
134
+ "bias", torch.zeros((out_features), dtype=dtype, device=dev)
135
+ )
136
+ else:
137
+ self.bias = None
138
+
139
+ @classmethod
140
+ def from_linear(
141
+ cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
142
+ ):
143
+ awq_linear = cls(
144
+ w_bit,
145
+ group_size,
146
+ linear.in_features,
147
+ linear.out_features,
148
+ linear.bias is not None,
149
+ linear.weight.device,
150
+ dtype=linear.weight.data.dtype
151
+ )
152
+ if init_only: # just prepare for loading sd
153
+ return awq_linear
154
+
155
+ # need scales and zeros info for real quantization
156
+ assert scales is not None and zeros is not None
157
+ scale_zeros = zeros * scales
158
+
159
+ dtype = scales.dtype
160
+
161
+ pack_num = 32 // awq_linear.w_bit
162
+ qscales = torch.zeros(
163
+ (
164
+ scales.shape[0],
165
+ calculate_zeros_width(linear.in_features, group_size) * pack_num,
166
+ ),
167
+ dtype=dtype,
168
+ device=scales.device,
169
+ )
170
+ qscales[:, : scales.shape[1]] = scales
171
+ # awq_linear.scales = scales.clone().half()
172
+ awq_linear.scales = qscales.transpose(1, 0).contiguous()
173
+ if linear.bias is not None:
174
+ awq_linear.bias = linear.bias.clone().to(dtype)
175
+
176
+ intweight = []
177
+ for idx in range(awq_linear.in_features):
178
+ intweight.append(
179
+ torch.round(
180
+ (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
181
+ / qscales[:, idx // group_size]
182
+ ).to(torch.int)[:, None]
183
+ )
184
+ intweight = torch.cat(intweight, dim=1)
185
+ # intweight = intweight.t().contiguous()
186
+ intweight = intweight.to(dtype=torch.int32)
187
+ awq_linear.qweight = pack_intweight(
188
+ intweight.contiguous(), interleave=4, kstride=64
189
+ )
190
+
191
+ zeros = zeros.to(dtype=torch.int32)
192
+ scaled_zeros = torch.zeros_like(qscales)
193
+ # scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
194
+ scaled_zeros[:, : scales.shape[1]] = -(
195
+ qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
196
+ ).to(dtype)
197
+ awq_linear.scaled_zeros = scaled_zeros.transpose(1, 0).contiguous()
198
+
199
+ return awq_linear
200
+
201
+ @torch.no_grad()
202
+ def forward(self, x):
203
+ # out_shape = x.shape[:-1] + (self.out_features,)
204
+ # inputs = x.reshape(-1, x.shape[-1])
205
+ inputs = x
206
+ if inputs.numel() / inputs.shape[-1] < 8:
207
+ out = awq_inference_engine.gemv_forward_cuda_new(
208
+ inputs,
209
+ self.qweight,
210
+ self.scales,
211
+ self.scaled_zeros,
212
+ inputs.numel() // inputs.shape[-1],
213
+ self.out_features,
214
+ self.in_features,
215
+ self.group_size,
216
+ )
217
+ else:
218
+ out = awq_inference_engine.gemm_forward_cuda_new(
219
+ inputs, self.qweight, self.scales, self.scaled_zeros
220
+ ) # - 8.0 * self.scales)
221
+ out = out + self.bias if self.bias is not None else out
222
+ # print(out)
223
+ # assert 0
224
+ return out
225
+
226
+ def extra_repr(self) -> str:
227
+ return (
228
+ "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
229
+ self.in_features,
230
+ self.out_features,
231
+ self.bias is not None,
232
+ self.w_bit,
233
+ self.group_size,
234
+ )
235
+ )
llm-awq/awq/quantize/w8a8_linear.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from qserve (https://github.com/mit-han-lab/qserve/tree/main) and modified by Yuming Lou
2
+
3
+
4
+ from typing import Optional, Union
5
+ from torch.nn import Parameter
6
+ import awq_inference_engine
7
+ import torch
8
+ import gc
9
+ from awq.utils.module import set_op_by_name
10
+ from tqdm import tqdm
11
+
12
+
13
+ class W8A8OF16LinearStaticScale(torch.nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ out_features: int,
18
+ bias: bool = True,
19
+ scale: Union[torch.tensor, float] = 1.0,
20
+ params_dtype: Optional[torch.dtype] = None,
21
+ ):
22
+ super().__init__()
23
+
24
+ # Keep input parameters
25
+ self.in_features = in_features
26
+ self.out_features = out_features
27
+ # size [1] or size [oc]
28
+ self.register_buffer(
29
+ "dequant_scale", torch.ones(out_features, dtype=torch.half)
30
+ )
31
+ # Parameters.
32
+ # NOTE: torch.nn.functional.linear performs XA^T + b and as a result
33
+ # we allocate the transpose.
34
+ self.create_weights()
35
+
36
+ if bias:
37
+ self.bias = torch.empty(
38
+ self.out_features,
39
+ device=torch.cuda.current_device(),
40
+ dtype=torch.float16,
41
+ )
42
+ else:
43
+ self.register_parameter("bias", None)
44
+
45
+ def create_weights(self) -> None:
46
+ self.register_buffer(
47
+ "weight",
48
+ torch.empty(
49
+ self.out_features,
50
+ self.in_features,
51
+ dtype=torch.int8,
52
+ requires_grad=False,
53
+ ),
54
+ )
55
+
56
+ def apply_weights(
57
+ self,
58
+ x: torch.Tensor,
59
+ bias: Optional[torch.Tensor],
60
+ ) -> torch.Tensor:
61
+ raise NotImplementedError
62
+
63
+ def forward(self, input_):
64
+ # Matrix multiply.
65
+ output = self.apply_weights(input_, self.bias)
66
+ output_bias = self.bias
67
+ return output, output_bias
68
+
69
+
70
+ class W8A8OF16LinearDynamicInputScale(W8A8OF16LinearStaticScale):
71
+ def __init__(
72
+ self,
73
+ in_features: int,
74
+ out_features: int,
75
+ bias: bool = True,
76
+ scale: Union[torch.tensor, float] = 1.0,
77
+ params_dtype: Optional[torch.dtype] = None,
78
+ ):
79
+ super().__init__(
80
+ in_features=in_features,
81
+ out_features=out_features,
82
+ bias=bias,
83
+ scale=scale,
84
+ params_dtype=params_dtype,
85
+ )
86
+ if bias:
87
+ self.apply_weights = self.apply_weights_bias
88
+ else:
89
+ self.apply_weights = self.apply_weights_no_bias
90
+
91
+ #W bias. Fused bias and W8A8 GEMM
92
+ def apply_weights_bias(
93
+ self,
94
+ # [batch, tokens, channels]
95
+ x: torch.Tensor,
96
+ # [batch * tokens]
97
+ input_scale: torch.Tensor,
98
+ output_buffer: torch.Tensor,
99
+ bias: torch.Tensor = None,
100
+ ):
101
+ x_shape = x.shape
102
+ if len(x.shape) > 2:
103
+ assert 0, "Not implemented"
104
+ x = x.view(-1, x_shape[-1])
105
+ # If use awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda
106
+ awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda(
107
+ x, self.weight, self.dequant_scale, input_scale, output_buffer, bias
108
+ )
109
+ if len(x.shape) > 2:
110
+ assert 0, "Not implemented 2"
111
+ output_buffer = output_buffer.view(*x_shape[:-1], -1)
112
+
113
+ #W/H bias. W8A8 GEMM
114
+ def apply_weights_no_bias(
115
+ self,
116
+ # [batch, tokens, channels]
117
+ x: torch.Tensor,
118
+ # [batch * tokens]
119
+ input_scale: torch.Tensor,
120
+ output_buffer: torch.Tensor,
121
+ bias: torch.Tensor = None,
122
+ ):
123
+ x_shape = x.shape
124
+ if len(x.shape) > 2:
125
+ assert 0, "Not implemented"
126
+ x = x.view(-1, x_shape[-1])
127
+ # If use awq_inference_engine.w8a8_gemm_forward_cuda
128
+ awq_inference_engine.w8a8_gemm_forward_cuda(
129
+ x, self.weight, self.dequant_scale, input_scale, output_buffer
130
+ )
131
+ if len(x.shape) > 2:
132
+ assert 0, "Not implemented 2"
133
+ output_buffer = output_buffer.view(*x_shape[:-1], -1)
134
+
135
+ def forward(self, input_, input_scale, output_buffer):
136
+ # Matrix multiply.
137
+ self.apply_weights(input_, input_scale, output_buffer, self.bias)
138
+
139
+ @classmethod
140
+ def from_linear(
141
+ cls,
142
+ linear,
143
+ init_only=False,
144
+ s1_scale=None,
145
+ fc1=False,
146
+ ):
147
+ q_linear = cls(
148
+ linear.in_features,
149
+ linear.out_features,
150
+ linear.bias is not None,
151
+ )
152
+ if init_only: # just prepare for loading sd
153
+ return q_linear
154
+ if s1_scale is None:
155
+ s1_scale, _ = torch.max(abs(linear.weight.data), dim=-1, keepdim=True)
156
+ s1_scale = s1_scale.clamp_(min=1e-5).div_(127)
157
+
158
+ if linear.bias is not None:
159
+ q_linear.bias = linear.bias.clone().half().contiguous().cuda()
160
+ ## Quantize the weights
161
+ # ---- Quantize the weights to int8 ---- #
162
+ linear_weight = linear.weight.data # OC, IC
163
+ linear_weight = linear_weight.div_(s1_scale.to(linear_weight.device))
164
+ linear_weight = linear_weight.round_().to(torch.int8)
165
+
166
+ q_linear.weight.data[:, :] = linear_weight.half().contiguous().cuda()
167
+
168
+ # ---- Pack the scales ---- #
169
+ q_linear.dequant_scale.data[:] = (
170
+ s1_scale.reshape(-1).half().contiguous().cuda()
171
+ )
172
+ return q_linear.cuda()
173
+
174
+ @classmethod
175
+ def from_qkv(
176
+ cls,
177
+ q,
178
+ k,
179
+ v,
180
+ init_only=False,
181
+ s1_scale=None,
182
+ ):
183
+ q_linear = cls(
184
+ q.in_features,
185
+ q.out_features + k.out_features + v.out_features,
186
+ q.bias is not None,
187
+ )
188
+ if init_only: # just prepare for loading sd
189
+ return q_linear
190
+ weight = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
191
+
192
+ if s1_scale is None:
193
+ s1_scale, _ = torch.max(abs(weight), dim=-1, keepdim=True)
194
+ s1_scale = s1_scale.clamp_(min=1e-5).div_(127)
195
+
196
+ if q.bias is not None:
197
+ bias = torch.cat([q.bias, k.bias, v.bias], dim=0)
198
+ q_linear.bias = bias.clone().half().contiguous().cuda()
199
+ # ---- Quantize the weights to int8 ---- #
200
+ weight = weight.div_(s1_scale.to(weight.device))
201
+ weight = weight.round_().to(torch.int8)
202
+
203
+ q_linear.weight.data[:, :] = weight.contiguous().cuda()
204
+
205
+ # ---- Pack the scales ---- #
206
+ q_linear.dequant_scale.data[:] = (
207
+ s1_scale.reshape(q.out_features + k.out_features + v.out_features)
208
+ .half()
209
+ .contiguous().cuda()
210
+ )
211
+ return q_linear.cuda()
212
+
213
+
214
+ class FakeW8A8Linear(torch.nn.Module):
215
+ def __init__(
216
+ self, in_features: int, out_features: int, bias: bool = True, wbit: int = 8
217
+ ):
218
+ super().__init__()
219
+ self.weight = torch.nn.Parameter(
220
+ torch.empty(out_features, in_features, dtype=torch.half)
221
+ )
222
+ if bias:
223
+ self.bias = torch.nn.Parameter(
224
+ torch.empty(1, out_features, dtype=torch.half)
225
+ )
226
+ else:
227
+ self.bias = None
228
+ self.wbit = wbit
229
+ self.maxv = 2 ** (wbit - 1) - 1
230
+
231
+ def forward(self, input):
232
+ t_shape = input.shape
233
+ input.view(-1, t_shape[-1])
234
+ scales = input.abs().max(dim=-1, keepdim=True)[0]
235
+ scales.clamp_(min=1e-5).div_(self.maxv)
236
+ input.div_(scales).round_().mul_(scales)
237
+ output = torch.functional.F.linear(input, self.weight, self.bias)
238
+ return output
239
+
240
+ @classmethod
241
+ def from_linear(cls, linear: torch.nn.Linear, wbit=8):
242
+ fake_linear = cls(
243
+ linear.in_features, linear.out_features, linear.bias is not None, wbit
244
+ )
245
+ maxv = 2 ** (wbit - 1) - 1
246
+ scale = (
247
+ torch.max(abs(linear.weight.data.detach()), -1, keepdim=True)[0]
248
+ .clamp_(min=1e-5)
249
+ .div_(maxv)
250
+ )
251
+ weight = linear.weight.data / scale
252
+ weight = weight.round_()
253
+ weight = weight * scale
254
+ fake_linear.weight.copy_(weight.contiguous())
255
+ if linear.bias is not None:
256
+ fake_linear.bias.copy_(
257
+ linear.bias.detach().half().reshape(1, linear.out_features).contiguous()
258
+ )
259
+ else:
260
+ linear.bias = None
261
+ del linear, scale, weight
262
+ torch.cuda.empty_cache()
263
+ return fake_linear
264
+
265
+
266
+ def fake_quant(model, wbit=8):
267
+ for name, m in tqdm(
268
+ model.named_modules(),
269
+ desc="Fake quantizing",
270
+ total=len(list(model.named_modules())),
271
+ ):
272
+ if isinstance(m, torch.nn.Linear):
273
+ FQlinear = FakeW8A8Linear.from_linear(m, wbit)
274
+ del m
275
+ torch.cuda.empty_cache()
276
+ set_op_by_name(model, name, FQlinear)
llm-awq/awq/utils/lm_eval_adaptor.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ from lm_eval.base import BaseLM
4
+ import fnmatch
5
+
6
+
7
+ class LMEvalAdaptor(BaseLM):
8
+ def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
9
+ super().__init__()
10
+
11
+ assert isinstance(batch_size, int)
12
+
13
+ self.model_name = model_name
14
+ self.model = model
15
+ self.model.eval()
16
+
17
+ self.tokenizer = tokenizer
18
+
19
+ # assert isinstance(self.tokenizer, (
20
+ # transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
21
+ # transformers.T5Tokenizer, transformers.T5TokenizerFast,
22
+ # )), "this tokenizer has not been checked for compatibility yet!"
23
+
24
+ self.vocab_size = self.tokenizer.vocab_size
25
+
26
+ self._batch_size = batch_size
27
+
28
+ self._max_length = max_length
29
+
30
+ @property
31
+ def eot_token_id(self):
32
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
33
+ return self.tokenizer.eos_token_id
34
+
35
+ @property
36
+ def max_length(self):
37
+ if self._max_length != -1:
38
+ return self._max_length
39
+ if hasattr(self.model.config, "n_ctx"):
40
+ return self.model.config.n_ctx
41
+ elif hasattr(self.model.config, "max_position_embeddings"):
42
+ return self.model.config.max_position_embeddings
43
+ elif hasattr(self.model.config, "n_positions"):
44
+ return self.model.config.n_positions
45
+ elif "bloom" in self.model_name:
46
+ return 2048
47
+ elif "llama" in self.model_name:
48
+ return 2048 # TODO: did not check this
49
+ elif "mpt" in self.model_name:
50
+ return 2048
51
+ elif "falcon" in self.model_name:
52
+ return 2048
53
+ else:
54
+ print(self.model.config)
55
+ raise NotImplementedError
56
+
57
+ @property
58
+ def max_gen_toks(self):
59
+ return 256
60
+
61
+ @property
62
+ def batch_size(self):
63
+ return self._batch_size
64
+
65
+ @property
66
+ def device(self):
67
+ return "cuda"
68
+
69
+ def tok_encode(self, string: str):
70
+ return self.tokenizer.encode(string, add_special_tokens=False)
71
+
72
+ def tok_decode(self, tokens):
73
+ return self.tokenizer.decode(tokens)
74
+
75
+ def _model_call(self, inps):
76
+ """
77
+ inps: a torch tensor of shape [batch, sequence]
78
+ the size of sequence may vary from call to call
79
+
80
+ returns: a torch tensor of shape [batch, sequence, vocab] with the
81
+ logits returned from the model
82
+ """
83
+ with torch.no_grad():
84
+ if isinstance(
85
+ self.model,
86
+ transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
87
+ ):
88
+ dec_inps = torch.cat(
89
+ [
90
+ torch.tensor(
91
+ self.model.generation_config.decoder_start_token_id,
92
+ )
93
+ .tile(len(inps), 1)
94
+ .to(inps),
95
+ inps,
96
+ ],
97
+ dim=1,
98
+ )
99
+
100
+ kwargs = {
101
+ "decoder_input_ids": dec_inps,
102
+ }
103
+ else:
104
+ kwargs = {}
105
+ out = self.model(inps, **kwargs)[0]
106
+ if (
107
+ "opt" in self.model_name
108
+ ): # there are a few extra tokens in opt, which we should omit
109
+ return out[:, :, :50257]
110
+ else:
111
+ return out # [:, :, :self.tokenizer.vocab_size]
112
+
113
+ def _model_generate(self, context, max_length, eos_token_id):
114
+ return self.model.generate(
115
+ context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
116
+ )
llm-awq/awq/utils/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import accelerate
3
+
4
+
5
+ def get_module_by_name_suffix(model, module_name: str):
6
+ for name, module in model.named_modules():
7
+ if name.endswith(module_name):
8
+ return module
9
+
10
+
11
+ def simple_dispatch_model(model, device_map):
12
+ from accelerate.hooks import add_hook_to_module, AlignDevicesHook
13
+
14
+ if "" in device_map:
15
+ d = device_map[""]
16
+ model = model.to(torch.device(d))
17
+ model.hf_device_map = device_map
18
+ return model
19
+
20
+ tied_params = accelerate.utils.modeling.find_tied_parameters(model)
21
+ if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {
22
+ "cpu",
23
+ "disk",
24
+ }:
25
+ main_device = "cpu"
26
+ else:
27
+ main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
28
+
29
+ cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
30
+ prev_hook = None
31
+ for idx, (n, d) in enumerate(cpu_offload_group):
32
+ m = get_module_by_name_suffix(model, n)
33
+ _, prev_hook = accelerate.cpu_offload_with_hook(
34
+ m, execution_device=main_device, prev_module_hook=prev_hook
35
+ )
36
+ # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
37
+ if len(cpu_offload_group) > 1:
38
+ get_module_by_name_suffix(
39
+ model, cpu_offload_group[0][0]
40
+ )._hf_hook.prev_module_hook = prev_hook
41
+
42
+ for n, d in device_map.items():
43
+ m = get_module_by_name_suffix(model, n)
44
+ if d != "cpu":
45
+ d = torch.device(d)
46
+ hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
47
+ add_hook_to_module(m, hook)
48
+ accelerate.utils.modeling.retie_parameters(model, tied_params)
49
+ model.hf_device_map = device_map
50
+
51
+ return model
llm-awq/examples/convert_to_hf.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script demonstrates how you can convert your model into HF format
2
+ # easily and push the quantized weights on the Hub using simple tools.
3
+ # Make sure to have transformers > 4.34 and that you have ran
4
+ # `huggingface-cli login` on your terminal before running this
5
+ # script
6
+ import os
7
+ import argparse
8
+
9
+ # This demo only support single GPU for now
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
11
+
12
+ from transformers import AutoConfig, AwqConfig, AutoTokenizer
13
+ from huggingface_hub import HfApi
14
+
15
+ api = HfApi()
16
+
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--model_path", type=str, help="path of the original hf model", required=True
20
+ )
21
+ parser.add_argument(
22
+ "--quantized_model_path",
23
+ type=str,
24
+ help="path of the quantized AWQ model",
25
+ required=True,
26
+ )
27
+ parser.add_argument(
28
+ "--quantized_model_hub_path",
29
+ type=str,
30
+ help="path of the quantized AWQ model to push on the Hub",
31
+ required=True,
32
+ )
33
+ parser.add_argument("--w_bit", type=int, default=4, help="")
34
+ parser.add_argument("--q_group_size", default=128, type=int)
35
+ parser.add_argument("--no_zero_point", action="store_true")
36
+
37
+ args = parser.parse_args()
38
+
39
+ original_model_path = args.model_path
40
+ quantized_model_path = args.quantized_model_path
41
+ quantized_model_hub_path = args.quantized_model_hub_path
42
+
43
+ # Load the corresponding AWQConfig
44
+ quantization_config = AwqConfig(
45
+ bits=args.w_bit,
46
+ group_size=args.q_group_size,
47
+ zero_point=not args.no_zero_point,
48
+ backend="llm-awq",
49
+ version="gemv",
50
+ )
51
+
52
+ # Set the attribute `quantization_config` in model's config
53
+ config = AutoConfig.from_pretrained(original_model_path)
54
+ config.quantization_config = quantization_config
55
+
56
+ # Load tokenizer
57
+ tok = AutoTokenizer.from_pretrained(original_model_path)
58
+
59
+ # Push config and tokenizer
60
+ config.push_to_hub(quantized_model_hub_path)
61
+ tok.push_to_hub(quantized_model_hub_path)
62
+
63
+ # Upload model weights
64
+ api.upload_file(
65
+ path_or_fileobj=quantized_model_path,
66
+ path_in_repo="pytorch_model.bin",
67
+ repo_id=quantized_model_hub_path,
68
+ repo_type="model",
69
+ )
llm-awq/examples/llava_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
llm-awq/figures/vila-logo.jpg ADDED
llm-awq/scripts/codellama_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=CodeLlama-13b-Instruct
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/codellama-hf/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/llama2_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=llama-2-7b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/llama2-hf/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/llama3_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=llama3-8b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/models/llama3/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/llama_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=llama-7b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/opt_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=opt-6.7b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/opt/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/opt/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/opt/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/opt/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/qwen_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=qwen2.5-7b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/models/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/models/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/models/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/models/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/starcoder_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=starcoder
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/starcoder-hf/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/scripts/vicuna_example.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL=vicuna-7b
2
+
3
+ # run AWQ search (optional; we provided the pre-computed results)
4
+ python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
5
+ --w_bit 4 --q_group_size 128 \
6
+ --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
7
+
8
+ # evaluate the AWQ quantize model (simulated pseudo quantization)
9
+ python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
10
+ --tasks wikitext \
11
+ --w_bit 4 --q_group_size 128 \
12
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
13
+ --q_backend fake
14
+
15
+ # generate real quantized weights (w4)
16
+ python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
17
+ --w_bit 4 --q_group_size 128 \
18
+ --load_awq awq_cache/$MODEL-w4-g128.pt \
19
+ --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
20
+
21
+ # load and evaluate the real quantized model (smaller gpu memory usage)
22
+ python -m awq.entry --model_path /dataset/vicuna-hf/$MODEL \
23
+ --tasks wikitext \
24
+ --w_bit 4 --q_group_size 128 \
25
+ --load_quant quant_cache/$MODEL-w4-g128-awq.pt
llm-awq/tinychat/benchmark.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Usage:
2
+ # Please first install awq/kernels
3
+ # then directly run CUDA_VISIBLE_DEVICES=0 python benchmark.py
4
+ import argparse
5
+ import torch
6
+ import time
7
+ import numpy as np
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
9
+ import tinychat.utils.constants
10
+ from tinychat.utils.load_quant import load_awq_model
11
+ from awq.quantize.quantizer import real_quantize_model_weight
12
+ from tinychat.utils.tune import (
13
+ tune_all_wqlinears,
14
+ device_warmup,
15
+ tune_llava_patch_embedding,
16
+ )
17
+ from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
18
+
19
+
20
+ def skip(*args, **kwargs):
21
+ pass
22
+
23
+
24
+ def main():
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument(
27
+ "--model_type", type=str, default="LLaMa", help="type of the model"
28
+ )
29
+ parser.add_argument(
30
+ "--model_path",
31
+ type=str,
32
+ default="/data/llm/checkpoints/vicuna-hf/vicuna-7b",
33
+ help="path to the model",
34
+ )
35
+ parser.add_argument("--q_group_size", type=int, default=128)
36
+ parser.add_argument(
37
+ "--verbose",
38
+ default=False,
39
+ action="store_true",
40
+ help="Wheter to print more information.",
41
+ )
42
+ parser.add_argument(
43
+ "--max_seq_len",
44
+ type=int,
45
+ default=8192,
46
+ help="maximum sequence length for kv cache",
47
+ )
48
+ parser.add_argument(
49
+ "--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
50
+ )
51
+ parser.add_argument(
52
+ "--flash_attn",
53
+ action="store_true",
54
+ help="whether to use flash attention",
55
+ )
56
+ parser.add_argument(
57
+ "--chunk_prefilling",
58
+ action="store_true",
59
+ help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
60
+ )
61
+ parser.add_argument(
62
+ "--context_length",
63
+ type=list,
64
+ nargs="+",
65
+ help="The length of input. And if chunk_prefilling used, this serves as the length of tokens from history rounds.",
66
+ )
67
+ parser.add_argument(
68
+ "--question_length",
69
+ type=list,
70
+ nargs="+",
71
+ help="The length of new input. Only useful and necessary when benchmarking chunk_prefilling method",
72
+ )
73
+ parser.add_argument(
74
+ "--precision", type=str, default="W4A16", help="compute precision"
75
+ )
76
+ args = parser.parse_args()
77
+ # some checks
78
+ assert (args.question_length is not None and args.chunk_prefilling) or (
79
+ not args.chunk_prefilling
80
+ ), "If you want to benchmark chunk prefilling, you need specify the question length and context length"
81
+ assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
82
+ token_num = 256
83
+ # We support fixing a certain kind of length
84
+ if args.chunk_prefilling:
85
+ if len(args.context_length) == 1 and len(args.question_length) > 1:
86
+ args.context_length = [
87
+ args.context_length[0] for _ in range(len(args.question_length))
88
+ ]
89
+ elif len(args.question_length) == 1 and len(args.context_length) > 1:
90
+ args.question_length = [
91
+ args.question_length[0] for _ in range(len(args.context_length))
92
+ ]
93
+ elif len(args.question_length) != len(args.context_length):
94
+ raise ValueError(
95
+ "The number of items in the question_length and context_length is expected to be either one or equal!"
96
+ )
97
+ tinychat.utils.constants.max_batch_size = args.max_batch_size
98
+ tinychat.utils.constants.max_seq_len = args.max_seq_len
99
+ from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM
100
+ from tinychat.models.vila_llama import VilaLlamaForCausalLM
101
+
102
+ modeling_utils._init_weights = False
103
+ torch.nn.init.kaiming_uniform_ = skip
104
+ torch.nn.init.kaiming_normal_ = skip
105
+ torch.nn.init.uniform_ = skip
106
+ torch.nn.init.normal_ = skip
107
+
108
+ device = "cuda:0"
109
+ model_type_dict = {
110
+ "llama": LlamaForCausalLM,
111
+ "falcon": FalconForCausalLM,
112
+ "mpt": MPTForCausalLM,
113
+ }
114
+
115
+ config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
116
+ assert args.model_type.lower() in [
117
+ "llama",
118
+ "falcon",
119
+ "mpt",
120
+ "vila",
121
+ ], "We only support llama & falcon & mpt & vila now"
122
+ if "vila" in args.model_type.lower():
123
+ model = VilaLlamaForCausalLM(config).half()
124
+ print(model)
125
+ if args.precision in ["W4A16"]:
126
+ real_quantize_model_weight(
127
+ model.llm,
128
+ w_bit=4,
129
+ q_config=dict(q_group_size=args.q_group_size, zero_point=True),
130
+ init_only=True,
131
+ )
132
+ make_quant_attn(model.llm, device, args.flash_attn)
133
+ make_quant_norm(model.llm)
134
+ make_fused_mlp(model.llm)
135
+ model = model.to(device)
136
+ device_warmup(device)
137
+ tune_llava_patch_embedding(model.get_vision_tower(), device=device)
138
+ if not args.chunk_prefilling:
139
+ image_num = [
140
+ int(int("".join(i)) * 1 / 196) for i in args.context_length
141
+ ] # consider about three thirds of the history tokens are images
142
+ if sum(image_num) > 0:
143
+ image_tensor = 2 * torch.rand((max(image_num), 3, 384, 384)) - 1
144
+ image_tensor = image_tensor.half().to(device)
145
+ else:
146
+ image_tensor = None
147
+
148
+ print("huggingface ckpt loaded")
149
+
150
+ # warming up
151
+ input_ids = [1 for _ in range(2048)]
152
+ inputs = torch.as_tensor([input_ids], device=device)
153
+ out = model(
154
+ inputs, start_pos=0, chunk_prefilling=args.chunk_prefilling
155
+ ) # warmup
156
+
157
+ if not args.chunk_prefilling:
158
+ for i, context_length in enumerate(args.context_length):
159
+ context_length = int("".join(context_length))
160
+ time_lis = []
161
+ if image_num[i]:
162
+ images = image_tensor[0 : image_num[i], :, :, :]
163
+ input_ids = [-200 for _ in range(image_num[i])] + [
164
+ 1 for _ in range(context_length - 196 * image_num[i])
165
+ ]
166
+ else:
167
+ images = None
168
+ input_ids = [1 for _ in range(context_length)]
169
+ print("-" * 80)
170
+ print(
171
+ "Context length: {} with {} pictures".format(
172
+ context_length, image_num[i]
173
+ )
174
+ )
175
+ with torch.inference_mode():
176
+ for i in range(10): # Run ten times and get the average value
177
+ start_pos = 0
178
+ torch.cuda.synchronize()
179
+ t_st = time.time()
180
+ inputs = torch.as_tensor([input_ids], device=device)
181
+ out = model(
182
+ inputs,
183
+ start_pos=start_pos,
184
+ chunk_prefilling=args.chunk_prefilling,
185
+ images=images,
186
+ )
187
+ start_pos += inputs.shape[1]
188
+ torch.cuda.synchronize()
189
+ t_ed = time.time()
190
+ token = out[:, -1].max(1)[1].unsqueeze(1)
191
+ time_lis.append(t_ed - t_st)
192
+ if args.verbose:
193
+ print(i, t_ed - t_st)
194
+ print(f"Time To First Token: {np.mean(time_lis):.5f} s.")
195
+ print("-" * 80)
196
+ else:
197
+ for i, (context_length, question_length) in enumerate(
198
+ zip(args.context_length, args.question_length)
199
+ ):
200
+ context_length = int("".join(context_length))
201
+ question_length = int("".join(question_length))
202
+ input_ids_old = [1 for _ in range(context_length)]
203
+ images = None
204
+ input_ids_new = [1 for _ in range(question_length)]
205
+ time_lis = []
206
+ print("-" * 80)
207
+ print(
208
+ "History length: {} ; Question length: {}".format(
209
+ context_length, question_length
210
+ )
211
+ )
212
+ with torch.inference_mode():
213
+ for i in range(10): # Run ten times and get the average value
214
+ # history rounds
215
+ start_pos = 0
216
+ if context_length > question_length:
217
+ inputs = torch.as_tensor([input_ids_old], device=device)
218
+ out = model(
219
+ inputs,
220
+ start_pos=start_pos,
221
+ chunk_prefilling=args.chunk_prefilling,
222
+ images=None,
223
+ )
224
+ start_pos += context_length
225
+
226
+ # the present round
227
+ torch.cuda.synchronize()
228
+ t_st = time.time()
229
+ inputs = torch.as_tensor([input_ids_new], device=device)
230
+ out = model(
231
+ inputs,
232
+ start_pos=start_pos,
233
+ chunk_prefilling=args.chunk_prefilling,
234
+ )
235
+ start_pos += inputs.shape[1]
236
+ torch.cuda.synchronize()
237
+ t_ed = time.time()
238
+
239
+ token = out[:, -1].max(1)[1].unsqueeze(1)
240
+ time_lis.append(t_ed - t_st)
241
+ if args.verbose:
242
+ print(i, t_ed - t_st)
243
+ print(
244
+ f"Time To First Token of this round: {np.mean(time_lis):.5f} s."
245
+ )
246
+ print("-" * 80)
247
+ else:
248
+ model = model_type_dict[args.model_type.lower()](config).half()
249
+ if args.precision in ["W4A16"]:
250
+ real_quantize_model_weight(
251
+ model,
252
+ w_bit=4,
253
+ q_config=dict(q_group_size=args.q_group_size, zero_point=True),
254
+ init_only=True,
255
+ )
256
+ model = model.to(device)
257
+
258
+ if args.precision in ["W4A16"]:
259
+ # tune_all_wqlinears(model)
260
+ make_quant_attn(model, device, args.flash_attn)
261
+ make_quant_norm(model)
262
+ make_fused_mlp(model)
263
+ device_warmup(device)
264
+
265
+ print("huggingface ckpt loaded")
266
+
267
+ # warming up
268
+ input_ids = [1 for _ in range(2048)]
269
+ inputs = torch.as_tensor([input_ids], device=device)
270
+ out = model(
271
+ inputs,
272
+ start_pos=0,
273
+ chunk_prefilling=args.chunk_prefilling,
274
+ quant=args.precision in ["W4A16"],
275
+ ) # warmup
276
+
277
+ if not args.chunk_prefilling:
278
+ for context_length in args.context_length:
279
+ context_length = int("".join(context_length))
280
+ input_ids = [1 for _ in range(context_length)]
281
+ time_lis = []
282
+ print("-" * 80)
283
+ print("Context length: {}".format(context_length))
284
+ with torch.inference_mode():
285
+ for i in range(10): # Run ten times and get the average value
286
+ start_pos = 0
287
+ torch.cuda.synchronize()
288
+ t_st = time.time()
289
+ inputs = torch.as_tensor([input_ids], device=device)
290
+ out = model(
291
+ inputs,
292
+ start_pos=start_pos,
293
+ chunk_prefilling=args.chunk_prefilling,
294
+ quant=args.precision in ["W4A16"],
295
+ )
296
+ start_pos += inputs.shape[1]
297
+ torch.cuda.synchronize()
298
+ t_ed = time.time()
299
+ token = torch.argmax(out, keepdim=True)[0]
300
+ time_lis.append(t_ed - t_st)
301
+ if args.verbose:
302
+ print(i, t_ed - t_st)
303
+ print(f"Time To First Token: {np.mean(time_lis):.5f} s.")
304
+ # decoing throughput
305
+ time_lis = []
306
+ start_pos = context_length
307
+ torch.cuda.synchronize()
308
+ t_st = time.time()
309
+ for i in range(token_num):
310
+ token = model(
311
+ token,
312
+ start_pos=start_pos,
313
+ chunk_prefilling=args.chunk_prefilling,
314
+ quant=args.precision in ["W4A16"],
315
+ )
316
+ start_pos += 1
317
+ token = torch.argmax(token, keepdim=True)[0]
318
+ torch.cuda.synchronize()
319
+ t_ed = time.time()
320
+ time_lis.append(t_ed - t_st)
321
+ print(
322
+ f"Decoding throughput: {token_num/sum(time_lis):.5f} token/s."
323
+ )
324
+ print("-" * 80)
325
+ else:
326
+ for context_length, question_length in zip(
327
+ args.context_length, args.question_length
328
+ ):
329
+ context_length = int("".join(context_length))
330
+ question_length = int("".join(question_length))
331
+ input_ids_old = [1 for _ in range(context_length)]
332
+ input_ids_new = [1 for _ in range(question_length)]
333
+ time_lis = []
334
+ print("-" * 80)
335
+ print(
336
+ "History length: {} ; Question length: {}".format(
337
+ context_length, question_length
338
+ )
339
+ )
340
+ with torch.inference_mode():
341
+ for i in range(10): # Run ten times and get the average value
342
+ # history rounds
343
+ start_pos = 0
344
+ if context_length > question_length:
345
+ inputs = torch.as_tensor([input_ids_old], device=device)
346
+ out = model(
347
+ inputs,
348
+ start_pos=start_pos,
349
+ chunk_prefilling=args.chunk_prefilling,
350
+ quant=args.precision in ["W4A16"],
351
+ )
352
+ start_pos += inputs.shape[1]
353
+
354
+ # the present round
355
+ torch.cuda.synchronize()
356
+ t_st = time.time()
357
+ inputs = torch.as_tensor([input_ids_new], device=device)
358
+ out = model(
359
+ inputs,
360
+ start_pos=start_pos,
361
+ chunk_prefilling=args.chunk_prefilling,
362
+ quant=args.precision in ["W4A16"],
363
+ )
364
+ start_pos += inputs.shape[1]
365
+ torch.cuda.synchronize()
366
+ t_ed = time.time()
367
+
368
+ token = out[:, -1].max(1)[1].unsqueeze(1)
369
+ time_lis.append(t_ed - t_st)
370
+ if args.verbose:
371
+ print(i, t_ed - t_st)
372
+ print(
373
+ f"Time To First Token of this round: {np.mean(time_lis):.5f} s."
374
+ )
375
+ print("-" * 80)
376
+
377
+
378
+ if __name__ == "__main__":
379
+ main()
llm-awq/tinychat/demo.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
7
+ from attributedict.collections import AttributeDict
8
+ from tinychat.stream_generators import StreamGenerator
9
+ import tinychat.utils.constants
10
+ from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
11
+ from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
12
+ from tinychat.utils.tune import device_warmup, tune_all_wqlinears
13
+
14
+ import os
15
+
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
17
+
18
+ # opt_params in TinyLLMEngine
19
+ gen_params = AttributeDict(
20
+ [
21
+ ("seed", -1), # RNG seed
22
+ ("n_threads", 1), # TODO: fix this
23
+ ("n_predict", 512), # new tokens to predict
24
+ ("n_parts", -1), # amount of model parts (-1: determine from model dimensions)
25
+ ("n_ctx", 512), # context size
26
+ ("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS)
27
+ ("n_keep", 0), # number of tokens to keep from initial prompt
28
+ ("n_vocab", 50272), # vocabulary size
29
+ # sampling parameters
30
+ ("logit_bias", dict()), # logit bias for specific tokens: <int, float>
31
+ ("top_k", 40), # <= 0 to use vocab size
32
+ ("top_p", 0.95), # 1.0 = disabled
33
+ ("tfs_z", 1.00), # 1.0 = disabled
34
+ ("typical_p", 1.00), # 1.0 = disabled
35
+ ("temp", 0.70), # 1.0 = disabled
36
+ ("repeat_penalty", 1.10), # 1.0 = disabled
37
+ (
38
+ "repeat_last_n",
39
+ 64,
40
+ ), # last n tokens to penalize (0 = disable penalty, -1 = context size)
41
+ ("frequency_penalty", 0.00), # 0.0 = disabled
42
+ ("presence_penalty", 0.00), # 0.0 = disabled
43
+ ("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
44
+ ("mirostat_tau", 5.00), # target entropy
45
+ ("mirostat_eta", 0.10), # learning rate
46
+ ]
47
+ )
48
+
49
+
50
+ def stream_output(output_stream):
51
+ print(f"ASSISTANT: ", end="", flush=True)
52
+ pre = 0
53
+ for outputs in output_stream:
54
+ output_text = outputs["text"]
55
+ output_text = output_text.strip().split(" ")
56
+ now = len(output_text) - 1
57
+ if now > pre:
58
+ print(" ".join(output_text[pre:now]), end=" ", flush=True)
59
+ pre = now
60
+ print(" ".join(output_text[pre:]), flush=True)
61
+ if "timing" in outputs and outputs["timing"] is not None:
62
+ timing = outputs["timing"]
63
+ context_tokens = timing["context_tokens"]
64
+ context_time = timing["context_time"]
65
+ total_tokens = timing["total_tokens"]
66
+ generation_time_list = timing["generation_time_list"]
67
+ generation_tokens = len(generation_time_list)
68
+ average_speed = (context_time + np.sum(generation_time_list)) / (
69
+ context_tokens + generation_tokens
70
+ )
71
+ print("=" * 50)
72
+ print("Speed of Inference")
73
+ print("-" * 50)
74
+ print(f"TTFT : { context_time:.3f} s for {context_tokens} tokens")
75
+ print(
76
+ f"Speed of Generation : {np.average(generation_time_list)*1000:.2f} ms/token"
77
+ )
78
+ print("=" * 50)
79
+ return " ".join(output_text), total_tokens
80
+
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument(
85
+ "--model_type", type=str, default="LLaMa", help="type of the model"
86
+ )
87
+ parser.add_argument(
88
+ "--dtype", type=str, default="float16", choices=["float16", "bfloat16"]
89
+ )
90
+ parser.add_argument(
91
+ "--model_path",
92
+ type=str,
93
+ help="path to the model",
94
+ )
95
+ parser.add_argument(
96
+ "--precision", type=str, default="W4A16", help="compute precision"
97
+ )
98
+ parser.add_argument("--device", type=str, default="cuda:0")
99
+ parser.add_argument("--q_group_size", type=int, default=128)
100
+ parser.add_argument(
101
+ "--load_quant",
102
+ type=str,
103
+ help="path to the pre-quanted 4-bit weights",
104
+ )
105
+ parser.add_argument(
106
+ "--max_seq_len",
107
+ type=int,
108
+ default=2048,
109
+ help="maximum sequence length for kv cache",
110
+ )
111
+ parser.add_argument(
112
+ "--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
113
+ )
114
+ parser.add_argument(
115
+ "--mem_efficient_load",
116
+ action="store_true",
117
+ help="enable mem_efficient_load mod",
118
+ )
119
+ parser.add_argument(
120
+ "--single_round",
121
+ action="store_true",
122
+ help="whether to memorize previous conversations",
123
+ )
124
+ parser.add_argument(
125
+ "--flash_attn",
126
+ action="store_true",
127
+ help="whether to use flash attention",
128
+ )
129
+ parser.add_argument(
130
+ "--chunk_prefilling",
131
+ action="store_true",
132
+ help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
133
+ )
134
+
135
+ args = parser.parse_args()
136
+ assert args.model_type.lower() in [
137
+ "llama",
138
+ "falcon",
139
+ "mpt",
140
+ "qwen",
141
+ ], "We only support llama & falcon & mpt now"
142
+ assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
143
+
144
+ gen_params.n_predict = 1024
145
+ gen_params.n_vocab = 32000
146
+ tinychat.utils.constants.max_batch_size = args.max_batch_size
147
+ tinychat.utils.constants.max_seq_len = args.max_seq_len
148
+ tinychat.utils.constants.mem_efficient_load = args.mem_efficient_load
149
+ if tinychat.utils.constants.mem_efficient_load:
150
+ print("=" * 80)
151
+ print(
152
+ "[Info] You have activated mem_efficient_load mode.\n Less on-chip memory will be consumed when loading the model.\n However, the loading process will take more time."
153
+ )
154
+ print("=" * 80)
155
+ # TODO (Haotian): a more elegant implementation here.
156
+ # We need to update these global variables before models use them.
157
+ from tinychat.models import (
158
+ FalconForCausalLM,
159
+ LlamaForCausalLM,
160
+ MPTForCausalLM,
161
+ Qwen2ForCausalLM,
162
+ )
163
+
164
+ def skip(*args, **kwargs):
165
+ pass
166
+
167
+ torch.nn.init.kaiming_uniform_ = skip
168
+ torch.nn.init.kaiming_normal_ = skip
169
+ torch.nn.init.uniform_ = skip
170
+ torch.nn.init.normal_ = skip
171
+
172
+ config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
173
+ if "mpt" in config.__class__.__name__.lower():
174
+ # config.init_device="meta"
175
+ tokenizer = AutoTokenizer.from_pretrained(
176
+ config.tokenizer_name, trust_remote_code=True
177
+ )
178
+ else:
179
+ tokenizer = AutoTokenizer.from_pretrained(
180
+ args.model_path, use_fast=False, trust_remote_code=True
181
+ )
182
+ torch_dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
183
+ modeling_utils._init_weights = False
184
+ torch.set_default_dtype(torch_dtype)
185
+
186
+ model_type_dict = {
187
+ "llama": LlamaForCausalLM,
188
+ "falcon": FalconForCausalLM,
189
+ "mpt": MPTForCausalLM,
190
+ "qwen": Qwen2ForCausalLM,
191
+ }
192
+
193
+ if args.precision == "W4A16":
194
+ if args.model_type.lower() == "llama":
195
+ model = model_type_dict["llama"](config).to(torch_dtype)
196
+ model = load_awq_llama_fast(
197
+ model, args.load_quant, 4, args.q_group_size, args.device
198
+ )
199
+ elif args.model_type.lower() == "qwen":
200
+ model = model_type_dict["qwen"](config).to(torch_dtype)
201
+ model = load_awq_llama_fast(
202
+ model, args.load_quant, 4, args.q_group_size, args.device
203
+ )
204
+ else:
205
+ model = model_type_dict[args.model_type.lower()](config).to(torch_dtype)
206
+ model = load_awq_model(
207
+ model, args.load_quant, 4, args.q_group_size, args.device
208
+ )
209
+ else:
210
+ loaded_model = AutoModelForCausalLM.from_pretrained(
211
+ args.model_path,
212
+ config=config,
213
+ torch_dtype=torch_dtype,
214
+ trust_remote_code=True,
215
+ )
216
+ model = (
217
+ model_type_dict[args.model_type.lower()](config)
218
+ .to(torch_dtype)
219
+ .to(args.device)
220
+ )
221
+ model.load_state_dict(loaded_model.state_dict())
222
+ # device warm up
223
+ device_warmup(args.device)
224
+
225
+ # autotune split_k_iters
226
+ # tune_all_wqlinears(model)
227
+
228
+ # TODO (Haotian): Verify if the StreamGenerator still works for the unmodified falcon impl.
229
+ stream_generator = StreamGenerator
230
+
231
+ # Optimize AWQ quantized model
232
+ if args.precision == "W4A16" and (
233
+ args.model_type.lower() == "llama" or args.model_type.lower() == "qwen"
234
+ ):
235
+ from tinychat.modules import make_quant_norm, make_quant_attn
236
+
237
+ if args.flash_attn:
238
+ make_quant_attn(model, args.device, args.flash_attn)
239
+ else:
240
+ make_quant_attn(model, args.device)
241
+ make_quant_norm(model)
242
+ model(
243
+ torch.randint(0, 1000, (1, 512), dtype=torch.int, device="cuda:0"),
244
+ 0,
245
+ quant=args.precision == "W4A16",
246
+ )
247
+ if args.max_seq_len <= 1024:
248
+ short_prompt = True
249
+ else:
250
+ short_prompt = False
251
+ model_prompter = get_prompter(args.model_type, args.model_path, short_prompt)
252
+ stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
253
+ count = 0
254
+ start_pos = 0
255
+ print("=" * 50)
256
+ while True:
257
+ # Get input from the user
258
+ input_prompt = input("USER: ")
259
+ if input_prompt == "":
260
+ print("EXIT...")
261
+ break
262
+ model_prompter.insert_prompt(input_prompt)
263
+ output_stream = stream_generator(
264
+ model,
265
+ tokenizer,
266
+ model_prompter.model_input,
267
+ start_pos,
268
+ gen_params,
269
+ device=args.device,
270
+ stop_token_ids=stop_token_ids,
271
+ chunk_prefilling=args.chunk_prefilling,
272
+ quant_llm=args.precision == "W4A16",
273
+ )
274
+ outputs, total_tokens = stream_output(output_stream)
275
+ if args.chunk_prefilling:
276
+ start_pos += total_tokens
277
+ else:
278
+ start_pos = 0
279
+ if (
280
+ args.single_round is not True and args.max_seq_len > 512
281
+ ): # Only memorize previous conversations when kv_cache_size > 512
282
+ model_prompter.update_template(outputs, args.chunk_prefilling)
283
+ count += 1
llm-awq/tinychat/internvl_benchmark.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from termcolor import colored
4
+
5
+ import llava
6
+ from llava import conversation as clib
7
+ from llava.media import Image, Video
8
+ import torch
9
+ from awq.quantize import fake_quant
10
+ from awq.quantize.quantizer import real_quantize_model_weight
11
+ from transformers import AutoConfig
12
+ import tinychat
13
+
14
+ from torchao.quantization import quantize_, Int4WeightOnlyConfig
15
+
16
+ import os
17
+
18
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
19
+
20
+ def skip(*args, **kwargs):
21
+ pass
22
+
23
+
24
+ def main() -> None:
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument(
27
+ "--model-path",
28
+ "-m",
29
+ type=str,
30
+ default="/home/yuming/workspace/qwen/models/nvila-internal-8b-v1",
31
+ )
32
+ parser.add_argument(
33
+ "--quant_path",
34
+ type=str,
35
+ default="/PATH/TO/QUANT",
36
+ )
37
+ # parser.add_argument("--model-path", "-m", type=str, default="Efficient-Large-Model/J65")
38
+ # parser.add_argument("--quant_path", type=str, default="/home/yuming/workspace/qwen/models/J65/llm/vila2-J65-w4-g128-awq-v2.pt")
39
+ parser.add_argument("--conv-mode", "-c", type=str, default="auto")
40
+ # parser.add_argument("--media", type=str, default="/home/yuming/workspace/space_woaudio.mp4")
41
+ parser.add_argument("--device", type=str, default="cuda:0")
42
+ parser.add_argument(
43
+ "--act_scale_path",
44
+ type=str,
45
+ default="/PATH/TO/SCALE",
46
+ )
47
+ # quantization options
48
+ parser.add_argument("--quant_llm", action="store_true")
49
+ parser.add_argument("--quant_VT", action="store_true")
50
+ # Four basic tasks
51
+ parser.add_argument("--video_caption", action="store_true")
52
+ parser.add_argument("--video_QA", action="store_true")
53
+ parser.add_argument("--image_caption", action="store_true")
54
+ parser.add_argument("--image_QA", action="store_true")
55
+
56
+ parser.add_argument(
57
+ "--all",
58
+ action="store_true",
59
+ help="Whether to quantize visiontower and llm, and test all 4 tasks",
60
+ )
61
+ parser.add_argument(
62
+ "--fakequant_VT",
63
+ action="store_true",
64
+ help="Use fake quant or real quant for VisionTower",
65
+ )
66
+ parser.add_argument(
67
+ "--all_task", action="store_true", help="Whether to test all 4 tasks"
68
+ )
69
+ parser.add_argument(
70
+ "--video_path", type=str, default="../figures/nvila_demo_video.mp4"
71
+ )
72
+ parser.add_argument("--image_path", type=str, default="../figures/vila-logo.jpg")
73
+ parser.add_argument("--max_seq_len", type=int, default=8192)
74
+ args = parser.parse_args()
75
+
76
+ torch.nn.init.kaiming_uniform_ = skip
77
+ torch.nn.init.kaiming_normal_ = skip
78
+ torch.nn.init.uniform_ = skip
79
+ torch.nn.init.normal_ = skip
80
+ import tinychat.utils.constants
81
+
82
+ tinychat.utils.constants.max_seq_len = args.max_seq_len
83
+ from transformers import modeling_utils
84
+
85
+ modeling_utils._init_weights = False
86
+
87
+ # Load model
88
+ from tinychat.models import InternVL3
89
+
90
+ config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
91
+ config.resume_path = args.model_path
92
+ model = InternVL3(config).half()
93
+ model.language_model = model.language_model.eval()
94
+ if args.quant_llm or args.all:
95
+ from tinychat.modules import (
96
+ make_quant_norm,
97
+ make_quant_attn,
98
+ make_fused_mlp,
99
+ make_fused_vision_attn,
100
+ )
101
+
102
+ real_quantize_model_weight(
103
+ model.language_model,
104
+ w_bit=4,
105
+ q_config=dict(q_group_size=128, zero_point=True),
106
+ init_only=True,
107
+ )
108
+ make_quant_attn(model.language_model, "cuda", True)
109
+ make_quant_norm(model.language_model)
110
+ make_fused_mlp(model.language_model)
111
+ model = model.to("cuda")
112
+ model = model.to(args.device)
113
+ if args.quant_VT or args.all:
114
+ from tinychat.modules import QuantInternVisionEncoder
115
+ model.vision_model.encoder = QuantInternVisionEncoder(model.vision_model.encoder)
116
+ model.vision_model.encoder = torch.compile(model.vision_model.encoder)
117
+
118
+ model = model.cuda().eval()
119
+
120
+ if args.video_caption or args.all or args.all_task:
121
+ print("-" * 80)
122
+ print("Video_Caption")
123
+ # Set conversation mode
124
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
125
+ media = Video(args.video_path)
126
+ text = "Elaborate on the visual and narrative elements of the video in detail." # + "1"+" 1"*3069
127
+ prompt = [media, text]
128
+ # Generate response
129
+ with torch.no_grad():
130
+ response = model.benchmark(prompt, args.quant_llm)
131
+ if args.video_QA or args.all or args.all_task:
132
+ print("-" * 80)
133
+ print("Video_QA")
134
+ # Set conversation mode
135
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
136
+ media = Video(args.video_path)
137
+ text = "What is the person in the video doing? Select the option that best describes their action: A. Folding paper B. Playing computer games C. Sleeping." # + "1"+" 1"*3069
138
+ prompt = [media, text]
139
+ # Generate response
140
+ with torch.no_grad():
141
+ response = model.benchmark(prompt, args.quant_llm)
142
+ if args.image_caption or args.all or args.all_task:
143
+ print("-" * 80)
144
+ print("Image_Caption")
145
+ # Set conversation mode
146
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
147
+ media = Image(args.image_path)
148
+ text = "Describe the image in detail."
149
+ prompt = [media, text]
150
+ # Generate response
151
+ with torch.no_grad():
152
+ response = model.benchmark(prompt, args.quant_llm)
153
+ if args.image_QA or args.all or args.all_task:
154
+ print("-" * 80)
155
+ print("Image_QA")
156
+ # Set conversation mode
157
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
158
+ media = Image(args.image_path)
159
+ text = "What does the text in the image say? Choose the option that best matches: A. VILA B. AIIV C. ALIV."
160
+ prompt = [media, text]
161
+ # Generate response
162
+ with torch.no_grad():
163
+ response = model.benchmark(prompt, args.quant_llm)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main()
llm-awq/tinychat/split_ckpt.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import argparse
5
+
6
+
7
+ def split(
8
+ ckpt_path: str,
9
+ out_folder_path: str,
10
+ ):
11
+ os.system(f"mkdir -p {out_folder_path}")
12
+ ckpt = torch.load(ckpt_path)
13
+ count = 0
14
+ for key, value in ckpt.items():
15
+ output_dict = {key: value}
16
+ output_name = out_folder_path + "/" + key + ".pt"
17
+ torch.save(output_dict, output_name)
18
+ count += 1
19
+ print(f"Finished splitting the original checkpoint into {count} shards.")
20
+
21
+
22
+ def ckpt_folder_reader(ckpt_folder_path: str):
23
+ file_list = [f for f in os.listdir(ckpt_folder_path) if f.endswith(".pt")]
24
+ for ckpt in file_list:
25
+ print(ckpt)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument(
31
+ "--input_path",
32
+ type=str,
33
+ default=None,
34
+ help="Path to the original checkpoint (ends with *.pt)",
35
+ )
36
+ parser.add_argument(
37
+ "--output_path",
38
+ type=str,
39
+ default=None,
40
+ help="Folder to store the splitted checkpoint shards",
41
+ )
42
+
43
+ args = parser.parse_args()
44
+ assert (
45
+ args.input_path is not None
46
+ ), "Please specify the path to the original checkpoint."
47
+ if args.output_path is None:
48
+ suffix = r"\.pt$"
49
+ args.output_path = re.sub(suffix, "", args.input_path)
50
+
51
+ split(args.input_path, args.output_path)
llm-awq/tinychat/vila15_demo.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+ from transformers import AutoConfig, AutoTokenizer
8
+ from accelerate import load_checkpoint_and_dispatch
9
+
10
+ from tinychat.utils.tune import (
11
+ device_warmup,
12
+ tune_all_wqlinears,
13
+ tune_llava_patch_embedding,
14
+ )
15
+ from tinychat.utils.prompt_templates import (
16
+ get_prompter,
17
+ get_stop_token_ids,
18
+ get_image_token,
19
+ )
20
+ from tinychat.utils.llava_image_processing import (
21
+ process_images,
22
+ load_images,
23
+ vis_images,
24
+ )
25
+ import tinychat.utils.constants
26
+
27
+ # from tinychat.models.llava_llama import LlavaLlamaForCausalLM
28
+ from tinychat.models.vila_llama import VilaLlamaForCausalLM
29
+ from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator
30
+ from tinychat.utils.conversation_utils import gen_params, stream_output, TimeStats
31
+
32
+ import os
33
+
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
35
+
36
+
37
+ def image_parser(args):
38
+ out = args.image_file.split(args.im_sep)
39
+ return out
40
+
41
+
42
+ def skip(*args, **kwargs):
43
+ pass
44
+
45
+
46
+ def main(args):
47
+ # Accelerate model initialization
48
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
49
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
50
+ torch.nn.init.kaiming_uniform_ = skip
51
+ torch.nn.init.kaiming_normal_ = skip
52
+ torch.nn.init.uniform_ = skip
53
+ torch.nn.init.normal_ = skip
54
+
55
+ tokenizer = AutoTokenizer.from_pretrained(
56
+ os.path.join(args.model_path, "llm"), use_fast=False
57
+ )
58
+ tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
59
+ tokenizer.convert_tokens_to_ids(
60
+ [tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
61
+ )[0]
62
+ )
63
+ config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
64
+ model = VilaLlamaForCausalLM(config).half()
65
+ tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
66
+ tokenizer.convert_tokens_to_ids(
67
+ [tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
68
+ )[0]
69
+ )
70
+ vision_tower = model.get_vision_tower()
71
+ # if not vision_tower.is_loaded:
72
+ # vision_tower.load_model()
73
+ image_processor = vision_tower.image_processor
74
+ # vision_tower = vision_tower.half()
75
+
76
+ if args.precision == "W16A16":
77
+ pbar = tqdm(range(1))
78
+ pbar.set_description("Loading checkpoint shards")
79
+ for i in pbar:
80
+ model.llm = load_checkpoint_and_dispatch(
81
+ model.llm,
82
+ os.path.join(args.model_path, "llm"),
83
+ no_split_module_classes=[
84
+ "OPTDecoderLayer",
85
+ "LlamaDecoderLayer",
86
+ "BloomBlock",
87
+ "MPTBlock",
88
+ "DecoderLayer",
89
+ "CLIPEncoderLayer",
90
+ ],
91
+ ).to(args.device)
92
+ model = model.to(args.device)
93
+
94
+ elif args.precision == "W4A16":
95
+ from tinychat.utils.load_quant import load_awq_model
96
+
97
+ model.llm = load_awq_model(model.llm, args.quant_path, 4, 128, args.device)
98
+ from tinychat.modules import (
99
+ make_quant_norm,
100
+ make_quant_attn,
101
+ make_fused_mlp,
102
+ make_fused_vision_attn,
103
+ )
104
+
105
+ if args.flash_attn:
106
+ print("Enabling flash-attention!")
107
+ make_quant_attn(model.llm, args.device, 1)
108
+ else:
109
+ print("Disabling flash-attention!")
110
+ make_quant_attn(model.llm, args.device)
111
+ make_quant_norm(model.llm)
112
+ # make_fused_mlp(model)
113
+ # make_fused_vision_attn(model,args.device)
114
+ model = model.to(args.device)
115
+
116
+ else:
117
+ raise NotImplementedError(f"Precision {args.precision} is not supported.")
118
+
119
+ image_files = image_parser(args)
120
+ image_num = len(image_files)
121
+ images = load_images(image_files)
122
+ if args.vis_image:
123
+ print("=" * 50)
124
+ print("Input Image:")
125
+ vis_images(image_files)
126
+ # Similar operation in model_worker.py
127
+ image_tensor = process_images(images, image_processor, model.config)
128
+ if type(image_tensor) is list:
129
+ image_tensor = [
130
+ image.to(args.device, dtype=torch.float16) for image in image_tensor
131
+ ]
132
+ else:
133
+ image_tensor = image_tensor.to(args.device, dtype=torch.float16)
134
+
135
+ device_warmup(args.device)
136
+ tune_llava_patch_embedding(vision_tower, device=args.device)
137
+
138
+ stream_generator = LlavaStreamGenerator
139
+
140
+ if args.max_seq_len <= 1024:
141
+ short_prompt = True
142
+ else:
143
+ short_prompt = False
144
+ model_prompter = get_prompter(
145
+ args.model_type, args.model_path, short_prompt, args.empty_prompt
146
+ )
147
+ stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
148
+ count = 0
149
+
150
+ if args.empty_prompt:
151
+ input_indicator = "Input: "
152
+ output_indicator = "Generated: "
153
+ else:
154
+ input_indicator = "USER: "
155
+ output_indicator = "ASSISTANT: "
156
+
157
+ model.eval()
158
+ time_stats = TimeStats()
159
+ start_pos = 0
160
+ while True:
161
+ # Get input from the user
162
+ print("=" * 50)
163
+ input_prompt = input(input_indicator)
164
+ print("-" * 50)
165
+ if input_prompt == "":
166
+ print("EXIT...")
167
+ time_stats.show()
168
+ break
169
+ if count == 0: # Insert image here
170
+ image_token = get_image_token(model, args.model_path)
171
+ image_token_holder = (
172
+ tinychat.utils.constants.LLAVA_DEFAULT_IM_TOKEN_PLACE_HOLDER
173
+ )
174
+ im_token_count = input_prompt.count(image_token_holder)
175
+ if im_token_count == 0:
176
+ model_prompter.insert_prompt(image_token * image_num + input_prompt)
177
+ else:
178
+ assert im_token_count == image_num
179
+ input_prompt = input_prompt.replace(image_token_holder, image_token)
180
+ model_prompter.insert_prompt(input_prompt)
181
+ else:
182
+ model_prompter.insert_prompt(input_prompt)
183
+ if args.chunk_prefilling:
184
+ image_tensor = None # Can insert more images in future
185
+ output_stream = stream_generator(
186
+ model,
187
+ tokenizer,
188
+ model_prompter.model_input,
189
+ start_pos,
190
+ gen_params,
191
+ device=args.device,
192
+ stop_token_ids=stop_token_ids,
193
+ image_tensor=image_tensor,
194
+ chunk_prefilling=args.chunk_prefilling,
195
+ )
196
+ print(output_indicator, end="", flush=True)
197
+ if count == 0:
198
+ outputs, total_tokens = stream_output(output_stream, time_stats)
199
+ else:
200
+ outputs, total_tokens = stream_output(output_stream)
201
+ if args.chunk_prefilling:
202
+ start_pos += total_tokens
203
+ if (
204
+ args.single_round is not True and args.max_seq_len > 512
205
+ ): # Only memorize previous conversations when kv_cache_size > 512
206
+ model_prompter.update_template(outputs, args.chunk_prefilling)
207
+ count += 1
208
+
209
+
210
+ if __name__ == "__main__":
211
+ parser = argparse.ArgumentParser()
212
+ parser.add_argument(
213
+ "--model_type", type=str, default="LLaMa", help="type of the model"
214
+ )
215
+ parser.add_argument(
216
+ "--model-path", type=str, default="/data/llm/checkpoints/llava/llava-v1.5-7b"
217
+ )
218
+ parser.add_argument(
219
+ "--quant-path",
220
+ type=str,
221
+ default="/data/llm/checkpoints/llava/llava-v1.5-7b-w4-g128-awq.pt",
222
+ )
223
+ parser.add_argument(
224
+ "--precision", type=str, default="W4A16", help="compute precision"
225
+ )
226
+ parser.add_argument(
227
+ "--image-file",
228
+ type=str,
229
+ default="https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg",
230
+ )
231
+ parser.add_argument(
232
+ "--im-sep",
233
+ type=str,
234
+ default=",",
235
+ )
236
+ parser.add_argument("--device", type=str, default="cuda")
237
+ parser.add_argument("--max_seq_len", type=int, default=2048)
238
+ parser.add_argument(
239
+ "--single_round",
240
+ action="store_true",
241
+ help="whether to memorize previous conversations",
242
+ )
243
+ parser.add_argument(
244
+ "--vis-image",
245
+ action="store_true",
246
+ help="whether to visualize the image while chatting",
247
+ )
248
+ parser.add_argument(
249
+ "--empty-prompt",
250
+ action="store_true",
251
+ help="whether to use empty prompt template",
252
+ )
253
+ parser.add_argument(
254
+ "--flash_attn",
255
+ action="store_true",
256
+ help="whether to use flash attention",
257
+ )
258
+ parser.add_argument(
259
+ "--chunk_prefilling",
260
+ action="store_true",
261
+ help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
262
+ )
263
+ args = parser.parse_args()
264
+ main(args)
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_sot.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: sot
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_sot_prompt_2
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_2/afrimgsm_yor.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: yor
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_yor_prompt_2
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_ibo.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: ibo
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_ibo_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_kin.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: kin
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_kin_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sna.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: sna
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_sna_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_sot.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: sot
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_sot_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_xho.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: xho
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_xho_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tag:
2
+ - afrimgsm_tasks
3
+ - afrimgsm_tasks_prompt_3
4
+ dataset_path: masakhane/afrimgsm
5
+ output_type: generate_until
6
+ test_split: test
7
+ doc_to_target: '{% if answer is not none %}{{answer[21:]}}{% else %}{{answer_number|string}}{% endif %}'
8
+ doc_to_text: "Solve the following math question \n\nQuestion: {{question}} \nAnswer: "
9
+ target_delimiter: ""
10
+ generation_kwargs:
11
+ do_sample: false
12
+ until:
13
+ - 'Question:'
14
+ - </s>
15
+ - <|im_end|>
16
+ filter_list:
17
+ - name: remove_whitespace
18
+ filter:
19
+ - function: remove_whitespace
20
+ - function: take_first
21
+ - filter:
22
+ - function: regex
23
+ group_select: -1
24
+ regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
25
+ - function: take_first
26
+ name: flexible-extract
27
+ metric_list:
28
+ - metric: exact_match
29
+ aggregation: mean
30
+ higher_is_better: true
31
+ ignore_case: true
32
+ ignore_punctuation: true
33
+ metadata:
34
+ version: 2.0
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_yor.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: yor
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_yor_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_3/afrimgsm_zul.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: zul
3
+ include: afrimgsm_yaml
4
+ task: afrimgsm_zul_prompt_3
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_ibo.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: ibo
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_ibo_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lin.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: lin
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_lin_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_lug.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: lug
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_lug_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_orm.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: orm
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_orm_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sna.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: sna
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_sna_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_sot.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: sot
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_sot_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_swa.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: swa
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_swa_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_twi.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: twi
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_twi_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_vai.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: vai
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_vai_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_wol.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: wol
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_wol_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_xho.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: xho
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_xho_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_4/afrimgsm_yor.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: yor
3
+ doc_to_text: "Answer the given question with the appropriate numerical value, ensuring\
4
+ \ that the response is clear and without any supplementary information. \n\nQuestion:\
5
+ \ {{question}} \nAnswer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_yor_prompt_4
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_amh.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: amh
3
+ doc_to_text: "For mathematical questions provided in Amharic language. Supply the\
4
+ \ accurate numeric answer to the provided question. \n\nQuestion: {{question}} \n\
5
+ Answer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_amh_prompt_5
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_eng.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: eng
3
+ doc_to_text: "For mathematical questions provided in English language. Supply the\
4
+ \ accurate numeric answer to the provided question. \n\nQuestion: {{question}} \n\
5
+ Answer: "
6
+ include: afrimgsm_yaml
7
+ task: afrimgsm_eng_prompt_5
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_ewe.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: ewe
3
+ doc_to_text: "For mathematical questions provided in Ewe language. Supply the accurate\
4
+ \ numeric answer to the provided question. \n\nQuestion: {{question}} \nAnswer: "
5
+ include: afrimgsm_yaml
6
+ task: afrimgsm_ewe_prompt_5
lm-evaluation-harness/lm_eval/tasks/afrimgsm/direct/prompt_5/afrimgsm_fra.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Generated by utils.py
2
+ dataset_name: fra
3
+ doc_to_text: "For mathematical questions provided in French language. Supply the accurate\
4
+ \ numeric answer to the provided question. \n\nQuestion: {{question}} \nAnswer: "
5
+ include: afrimgsm_yaml
6
+ task: afrimgsm_fra_prompt_5