Spaces: Running on Zero
xieli committed
Commit · 3f373d0
1 Parent(s): 0b420f3

feat: remove awq pkg
Files changed:
- model_loader.py (+19, -67)
- requirements.txt (+0, -1)
model_loader.py CHANGED

@@ -7,7 +7,6 @@ import threading
 from typing import Optional, Dict, Any, Tuple
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from awq import AutoAWQForCausalLM
 from funasr_detach import AutoModel
 
 # Global cache for downloaded models to avoid repeated downloads
@@ -106,7 +105,7 @@ class UnifiedModelLoader:
         Prepare quantization configuration for model loading
 
         Args:
-            quantization_config: Quantization type ('int4', 'int8', 'awq', or None)
+            quantization_config: Quantization type ('int4', 'int8', or None)
             torch_dtype: PyTorch data type for compute operations
 
         Returns:

@@ -117,12 +116,7 @@ class UnifiedModelLoader:
 
         quantization_config = quantization_config.lower()
 
-        if quantization_config == "awq":
-            # For pre-quantized AWQ models, no additional quantization needed
-            self.logger.info("🔧 Loading pre-quantized AWQ 4-bit model (offline)")
-            return {}, True  # Load pre-quantized model normally, allow torch_dtype setting
-
-        elif quantization_config == "int8":
+        if quantization_config == "int8":
             # Use user-specified torch_dtype for compute, default to bfloat16
             compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
             self.logger.info(f"🔧 INT8 quantization: using {compute_dtype} for compute operations")

@@ -149,7 +143,7 @@ class UnifiedModelLoader:
                 "quantization_config": bnb_config
             }, False  # INT4 quantization handles torch_dtype internally, don't set it again
         else:
-            raise ValueError(f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8', 'awq'")
+            raise ValueError(f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8'")
 
     def load_transformers_model(
         self,
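For context, here is a minimal sketch of what the two surviving branches of _prepare_quantization_config plausibly build. The exact bitsandbytes flags are not visible in the hunks above, so the BitsAndBytesConfig arguments below (nf4, bfloat16 compute) and the return-flag semantics are illustrative assumptions, not the Space's confirmed settings:

import torch
from transformers import BitsAndBytesConfig

def prepare_quantization_config(quantization_config, torch_dtype=None):
    # Sketch only: mirrors the control flow shown in the diff above.
    quantization_config = quantization_config.lower()
    if quantization_config == "int8":
        # The diff logs a compute dtype for INT8, defaulting to bfloat16.
        compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
        print(f"INT8 quantization: using {compute_dtype} for compute operations")
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        return {"quantization_config": bnb_config}, True  # assumed: caller may still set torch_dtype
    elif quantization_config == "int4":
        # Assumed NF4 settings; the actual flags are outside this hunk.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return {"quantization_config": bnb_config}, False  # dtype handled internally
    raise ValueError(
        f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8'"
    )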
@@ -164,7 +158,7 @@ class UnifiedModelLoader:
         Args:
             model_path: Model path or ID
             source: Model source, auto means auto-detect
-            quantization_config: Quantization configuration ('int4', 'int8', 'awq', or None for no quantization)
+            quantization_config: Quantization configuration ('int4', 'int8', or None for no quantization)
             **kwargs: Other parameters (torch_dtype, device_map, etc.)
 
         Returns:

@@ -196,25 +190,11 @@ class UnifiedModelLoader:
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
             load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        #
-
-
-
-
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
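After this change, the Hugging Face branch reduces to a single from_pretrained call. A hedged usage sketch; the model ID and kwargs are placeholders, not taken from the Space:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model ID
load_kwargs = {"device_map": "auto", "trust_remote_code": True}

# Standard loading, as in the new branch above
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)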
@@ -240,25 +220,11 @@ class UnifiedModelLoader:
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
             load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        #
-
-
-
-
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = MSAutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = MSAutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = MSAutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
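The same simplification is applied to the ModelScope branch. The MS-prefixed classes are presumably aliases of the ModelScope auto classes; the import itself is outside this diff, but likely along these lines:

# Assumed aliasing of the ModelScope auto classes (not shown in the diff)
from modelscope import AutoModelForCausalLM as MSAutoModelForCausalLM
from modelscope import AutoTokenizer as MSAutoTokenizer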
@@ -282,25 +248,11 @@ class UnifiedModelLoader:
         if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
             load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
 
-        #
-
-
-
-
-                raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
-
-            self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
-            model = AutoAWQForCausalLM.from_quantized(
-                awq_model_path,
-                device_map=kwargs.get("device_map", "auto"),
-                trust_remote_code=True
-            )
-        else:
-            # Standard loading
-            model = AutoModelForCausalLM.from_pretrained(
-                model_path,
-                **load_kwargs
-            )
+        # Standard loading
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            **load_kwargs
+        )
         tokenizer = AutoTokenizer.from_pretrained(
             model_path,
             trust_remote_code=True,
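Net effect for callers of load_transformers_model: 'int4' and 'int8' still work, while 'awq' now falls through to the ValueError branch instead of loading a pre-quantized model. A sketch; the no-arg construction and the returned (model, tokenizer) tuple are assumptions based on the signatures above:

loader = UnifiedModelLoader()  # assumed no-arg construction

# Still supported after this commit
model, tokenizer = loader.load_transformers_model(
    "some/model-id",  # placeholder
    quantization_config="int8",
)

# No longer supported: 'awq' now raises instead of loading a pre-quantized model
try:
    loader.load_transformers_model("some/model-id", quantization_config="awq")
except ValueError as err:
    print(err)  # Unsupported quantization config: awq. Supported: 'int4', 'int8'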
requirements.txt CHANGED

@@ -22,4 +22,3 @@ gradio>=5.16.0
 nvidia-cuda-nvrtc-cu12==12.8.93
 spaces==0.42.1
 matplotlib==3.10.7
-autoawq==0.2.9