xieli committed
Commit 3f373d0 · 1 Parent(s): 0b420f3

feat: remove awq pkg

Files changed (2)
  1. model_loader.py +19 -67
  2. requirements.txt +0 -1
model_loader.py CHANGED
@@ -7,7 +7,6 @@ import threading
 from typing import Optional, Dict, Any, Tuple
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from awq import AutoAWQForCausalLM
 from funasr_detach import AutoModel
 
 # Global cache for downloaded models to avoid repeated downloads
@@ -106,7 +105,7 @@ class UnifiedModelLoader:
         Prepare quantization configuration for model loading
 
         Args:
-            quantization_config: Quantization type ('int4', 'int8', 'int4_offline_awq', or None)
+            quantization_config: Quantization type ('int4', 'int8', or None)
             torch_dtype: PyTorch data type for compute operations
 
         Returns:
@@ -117,12 +116,7 @@
 
         quantization_config = quantization_config.lower()
 
-        if quantization_config == "int4_offline_awq":
-            # For pre-quantized AWQ models, no additional quantization needed
-            self.logger.info("🔧 Loading pre-quantized AWQ 4-bit model (offline)")
-            return {}, True  # Load pre-quantized model normally, allow torch_dtype setting
-
-        elif quantization_config == "int8":
+        if quantization_config == "int8":
             # Use user-specified torch_dtype for compute, default to bfloat16
             compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
             self.logger.info(f"🔧 INT8 quantization: using {compute_dtype} for compute operations")
@@ -149,7 +143,7 @@
                 "quantization_config": bnb_config
             }, False  # INT4 quantization handles torch_dtype internally, don't set it again
         else:
-            raise ValueError(f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8', 'int4_offline_awq'")
+            raise ValueError(f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8'")
 
     def load_transformers_model(
         self,
@@ -164,7 +158,7 @@
         Args:
             model_path: Model path or ID
             source: Model source, auto means auto-detect
-            quantization_config: Quantization configuration ('int4', 'int8', 'int4_offline_awq', or None for no quantization)
+            quantization_config: Quantization configuration ('int4', 'int8', or None for no quantization)
             **kwargs: Other parameters (torch_dtype, device_map, etc.)
 
         Returns:
@@ -196,25 +190,11 @@
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
            load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        # Check if using AWQ quantization
-        if quantization_config and quantization_config.lower() == "int4_offline_awq":
-            # Use AWQ loading for pre-quantized AWQ models
-            awq_model_path = os.path.join(model_path, "awq_quantized")
-            if not os.path.exists(awq_model_path):
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
@@ -240,25 +220,11 @@
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
            load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        # Check if using AWQ quantization
-        if quantization_config and quantization_config.lower() == "int4_offline_awq":
-            # Use AWQ loading for pre-quantized AWQ models
-            awq_model_path = os.path.join(model_path, "awq_quantized")
-            if not os.path.exists(awq_model_path):
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = MSAutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = MSAutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = MSAutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
@@ -282,25 +248,11 @@
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
            load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        # Check if using AWQ quantization
-        if quantization_config and quantization_config.lower() == "int4_offline_awq":
-            # Use AWQ loading for pre-quantized AWQ models
-            awq_model_path = os.path.join(model_path, "awq_quantized")
-            if not os.path.exists(awq_model_path):
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
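
With the AWQ path removed, the only quantization options left in model_loader.py are the bitsandbytes-backed 'int4' and 'int8' configs passed to transformers. For reference, a minimal standalone sketch of that remaining path; the model ID and the NF4 flags below are illustrative assumptions, not values taken from this repo's 'int4' branch:

# Sketch only: model_id and the NF4 settings are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical example model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit weight quantization via bitsandbytes
    bnb_4bit_quant_type="nf4",              # common default; the repo's actual choice is not shown in this diff
    bnb_4bit_compute_dtype=torch.bfloat16,  # mirrors the loader's bfloat16 compute default
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)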
requirements.txt CHANGED
@@ -22,4 +22,3 @@ gradio>=5.16.0
 nvidia-cuda-nvrtc-cu12==12.8.93
 spaces==0.42.1
 matplotlib==3.10.7
-autoawq==0.2.9