Jiaqi-hkust committed on
Commit
4f2a894
·
verified ·
1 Parent(s): b0ac577

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +25 -16
app.py CHANGED
@@ -40,12 +40,6 @@ if not is_spaces:
40
 
41
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
42
 
43
- def gpu_decorator(func):
44
- """条件应用 GPU 装饰器"""
45
- if spaces_available and GPU is not None:
46
- return GPU(func)
47
- return func
48
-
49
  print(f"==========================================")
50
  print(f"Initializing application...")
51
  print(f"==========================================")
@@ -55,27 +49,38 @@ class ModelHandler:
55
  self.model_path = model_path
56
  self.model = None
57
  self.processor = None
58
- self._load_model()
59
 
60
  def _load_model(self):
 
 
 
 
61
  try:
62
  print(f"⏳ Loading model weights, this may take a few minutes...")
63
 
64
  self.processor = AutoProcessor.from_pretrained(self.model_path)
65
 
66
- if torch.cuda.is_available():
67
- device_capability = torch.cuda.get_device_capability()
68
- use_flash_attention = device_capability[0] >= 8
69
- print(f"🔧 CUDA available, device capability: {device_capability}")
70
- else:
71
- use_flash_attention = False
72
- print(f"🔧 Using CPU or non-CUDA device")
 
 
 
 
 
 
 
 
73
 
74
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
75
  self.model_path,
76
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
77
  device_map="auto",
78
- # attn_implementation="flash_attention_2" if use_flash_attention else "eager",
79
  attn_implementation="sdpa",
80
  trust_remote_code=True
81
  )
@@ -85,6 +90,10 @@ class ModelHandler:
85
  raise e
86
 
87
  def predict(self, message_dict, history, temperature, max_tokens):
 
 
 
 
88
  text = message_dict.get("text", "")
89
  files = message_dict.get("files", [])
90
 
 
40
 
41
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
42
 
 
 
 
 
 
 
43
  print(f"==========================================")
44
  print(f"Initializing application...")
45
  print(f"==========================================")
 
49
  self.model_path = model_path
50
  self.model = None
51
  self.processor = None
52
+ # 不在 __init__ 中加载模型,延迟到实际使用时
53
 
54
  def _load_model(self):
55
+ """延迟加载模型,在 GPU 装饰器函数内部调用"""
56
+ if self.model is not None:
57
+ return # 已经加载过了
58
+
59
  try:
60
  print(f"⏳ Loading model weights, this may take a few minutes...")
61
 
62
  self.processor = AutoProcessor.from_pretrained(self.model_path)
63
 
64
+ # 在 ZeroGPU 环境中,避免过早检查 CUDA
65
+ # 让 device_map="auto" 自动处理设备分配
66
+ try:
67
+ cuda_available = torch.cuda.is_available()
68
+ if cuda_available:
69
+ device_capability = torch.cuda.get_device_capability()
70
+ print(f"🔧 CUDA available, device capability: {device_capability}")
71
+ torch_dtype = torch.bfloat16
72
+ else:
73
+ print(f"🔧 Using CPU or non-CUDA device")
74
+ torch_dtype = torch.float32
75
+ except RuntimeError:
76
+ # ZeroGPU 环境中可能暂时无法检查 CUDA
77
+ print(f"🔧 CUDA check skipped (ZeroGPU environment)")
78
+ torch_dtype = torch.bfloat16 # 假设有 GPU,让 device_map 处理
79
 
80
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
81
  self.model_path,
82
+ torch_dtype=torch_dtype,
83
  device_map="auto",
 
84
  attn_implementation="sdpa",
85
  trust_remote_code=True
86
  )
 
90
  raise e
91
 
92
  def predict(self, message_dict, history, temperature, max_tokens):
93
+ # 确保模型已加载
94
+ if self.model is None:
95
+ self._load_model()
96
+
97
  text = message_dict.get("text", "")
98
  files = message_dict.get("files", [])
99