arahrooh committed on
Commit
084bec8
·
1 Parent(s): 272b3bb

Add HF_TOKEN support for gated models

Browse files
Files changed (1) hide show
  1. bot.py +20 -2
bot.py CHANGED
@@ -258,10 +258,20 @@ class RAGBot:
258
  logger.info(f"Loading model: {model_name}...")
259
  from transformers import AutoTokenizer, AutoModelForCausalLM
260
 
 
 
 
261
  # Load tokenizer
 
 
 
 
 
 
 
262
  self.tokenizer = AutoTokenizer.from_pretrained(
263
  model_name,
264
- trust_remote_code=True
265
  )
266
 
267
  # Determine appropriate torch dtype based on device and model
@@ -280,6 +290,10 @@ class RAGBot:
280
  "trust_remote_code": True,
281
  }
282
 
 
 
 
 
283
  # For MPS, use device_map; for CUDA, let it auto-detect
284
  if self.device == "mps":
285
  model_kwargs["device_map"] = self.device
@@ -309,7 +323,11 @@ class RAGBot:
309
  except Exception as e:
310
  logger.error(f"Failed to load model {self.args.model}: {e}")
311
  logger.error("Make sure the model name is correct and you have access to it on HuggingFace")
312
- logger.error("For private models, ensure you're logged in: huggingface-cli login")
 
 
 
 
313
  sys.exit(2)
314
 
315
  def _setup_vector_retriever(self):
 
258
  logger.info(f"Loading model: {model_name}...")
259
  from transformers import AutoTokenizer, AutoModelForCausalLM
260
 
261
+ # Get Hugging Face token from environment (for gated models)
262
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
263
+
264
  # Load tokenizer
265
+ tokenizer_kwargs = {
266
+ "trust_remote_code": True
267
+ }
268
+ if hf_token:
269
+ tokenizer_kwargs["token"] = hf_token
270
+ logger.info("Using HF_TOKEN for authentication")
271
+
272
  self.tokenizer = AutoTokenizer.from_pretrained(
273
  model_name,
274
+ **tokenizer_kwargs
275
  )
276
 
277
  # Determine appropriate torch dtype based on device and model
 
290
  "trust_remote_code": True,
291
  }
292
 
293
+ # Add token if available (for gated models)
294
+ if hf_token:
295
+ model_kwargs["token"] = hf_token
296
+
297
  # For MPS, use device_map; for CUDA, let it auto-detect
298
  if self.device == "mps":
299
  model_kwargs["device_map"] = self.device
 
323
  except Exception as e:
324
  logger.error(f"Failed to load model {self.args.model}: {e}")
325
  logger.error("Make sure the model name is correct and you have access to it on HuggingFace")
326
+ logger.error("For gated models (like Llama), you need to:")
327
+ logger.error(" 1. Request access at: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct")
328
+ logger.error(" 2. Add HF_TOKEN as a secret in your Hugging Face Space settings")
329
+ logger.error(" 3. Get your token from: https://huggingface.co/settings/tokens")
330
+ logger.error("For local use, ensure you're logged in: huggingface-cli login")
331
  sys.exit(2)
332
 
333
  def _setup_vector_retriever(self):