nazdridoy commited on
Commit
9a50492
Β·
verified Β·
1 Parent(s): 43333ad

feat(auth): add Hugging Face OAuth access control

Browse files

- [docs] Add documentation for Hugging Face OAuth configuration and allowed organizations (README.md:6,44-45,240-244)
- [feat] Add Gradio LoginButton and import Gradio (app.py:5,27-28)
- [feat] Integrate OAuth token handling and organization access checks into chat submission and retry functions (chat_handler.py:5,17-18,170,175-182,199,204-211)
- [feat] Integrate OAuth token handling and organization access checks into image generation functions (image_handler.py:5,18-19,279,284-288,309,314-318)
- [feat] Integrate OAuth token handling and organization access checks into text-to-speech generation (tts_handler.py:5,19-20,156,161-165)
- [feat] Add utility functions for parsing allowed organizations, fetching Hugging Face identity, checking organization access, and formatting access denied messages (utils.py:6,231-237,241-257,260-276,279-281)

Files changed (6) hide show
  1. README.md +9 -0
  2. app.py +3 -0
  3. chat_handler.py +25 -3
  4. image_handler.py +18 -3
  5. tts_handler.py +11 -2
  6. utils.py +64 -0
README.md CHANGED
@@ -6,6 +6,7 @@ colorTo: blue
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
 
9
  ---
10
 
11
  # πŸš€ HF-Inferoxy AI Hub
@@ -44,6 +45,9 @@ Add the following secrets to your HuggingFace Space:
44
  - **Key**: `PROXY_URL`
45
  - **Value**: Your HF-Inferoxy proxy server URL (e.g., `https://hf-proxy.example.com`)
46
 
 
 
 
47
  ### 2. HF-Inferoxy Server
48
 
49
  The app will use the HF-Inferoxy server URL specified in the `PROXY_URL` secret.
@@ -240,6 +244,11 @@ Prompt: "Help me debug this Python code: [paste code]"
240
 
241
  ## πŸ”’ Security & Authentication
242
 
 
 
 
 
 
243
  ### RBAC System
244
  - All operations require authentication with the HF-Inferoxy proxy server
245
  - API keys are managed securely through HuggingFace Space secrets
 
6
  sdk: gradio
7
  app_file: app.py
8
  pinned: false
9
+ hf_oauth: true
10
  ---
11
 
12
  # πŸš€ HF-Inferoxy AI Hub
 
45
  - **Key**: `PROXY_URL`
46
  - **Value**: Your HF-Inferoxy proxy server URL (e.g., `https://hf-proxy.example.com`)
47
 
48
+ - **Key**: `ALLOWED_ORGS`
49
+ - **Value**: Comma- or space-separated list of org names allowed to use this Space (e.g., `acme, acme-research`)
50
+
51
  ### 2. HF-Inferoxy Server
52
 
53
  The app will use the HF-Inferoxy server URL specified in the `PROXY_URL` secret.
 
244
 
245
  ## πŸ”’ Security & Authentication
246
 
247
+ ### Hugging Face OAuth (no inference scope)
248
+ - Login is required. The app uses Hugging Face OAuth and automatically injects an access token.
249
+ - We do not request the `inference-api` scope; the token is used only to call `whoami-v2` to verify org membership.
250
+ - Inference calls continue to use tokens provisioned by your HF-Inferoxy proxy.
251
+
252
  ### RBAC System
253
  - All operations require authentication with the HF-Inferoxy proxy server
254
  - API keys are managed securely through HuggingFace Space secrets
app.py CHANGED
@@ -23,6 +23,9 @@ def create_app():
23
 
24
  # Create the main Gradio interface with tabs
25
  with gr.Blocks(title="HF-Inferoxy AI Hub", theme=get_gradio_theme()) as demo:
 
 
 
26
 
27
  # Main header
28
  create_main_header()
 
23
 
24
  # Create the main Gradio interface with tabs
25
  with gr.Blocks(title="HF-Inferoxy AI Hub", theme=get_gradio_theme()) as demo:
26
+ # Sidebar with HF OAuth login/logout
27
+ with gr.Sidebar():
28
+ gr.LoginButton()
29
 
30
  # Main header
31
  create_main_header()
chat_handler.py CHANGED
@@ -4,6 +4,7 @@ Handles chat completion requests with streaming responses.
4
  """
5
 
6
  import os
 
7
  import time
8
  import threading
9
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
@@ -14,7 +15,9 @@ from hf_token_utils import get_proxy_token, report_token_status
14
  from utils import (
15
  validate_proxy_key,
16
  parse_model_and_provider,
17
- format_error_message
 
 
18
  )
19
 
20
  # Timeout configuration for inference requests
@@ -164,7 +167,7 @@ def chat_respond(
164
  yield format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
165
 
166
 
167
- def handle_chat_submit(message, history, system_msg, model_name, max_tokens, temperature, top_p):
168
  """
169
  Handle chat submission and manage conversation history with streaming.
170
  """
@@ -172,6 +175,16 @@ def handle_chat_submit(message, history, system_msg, model_name, max_tokens, tem
172
  yield history, ""
173
  return
174
 
 
 
 
 
 
 
 
 
 
 
175
  # Add user message to history
176
  history = history + [{"role": "user", "content": message}]
177
 
@@ -195,11 +208,20 @@ def handle_chat_submit(message, history, system_msg, model_name, max_tokens, tem
195
  yield current_history, ""
196
 
197
 
198
- def handle_chat_retry(history, system_msg, model_name, max_tokens, temperature, top_p, retry_data=None):
199
  """
200
  Retry the assistant response for the selected message.
201
  Works with gr.Chatbot.retry() which provides retry_data.index for the message.
202
  """
 
 
 
 
 
 
 
 
 
203
  # Guard: empty history
204
  if not history:
205
  yield history
 
4
  """
5
 
6
  import os
7
+ import gradio as gr
8
  import time
9
  import threading
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
 
15
  from utils import (
16
  validate_proxy_key,
17
  parse_model_and_provider,
18
+ format_error_message,
19
+ check_org_access,
20
+ format_access_denied_message
21
  )
22
 
23
  # Timeout configuration for inference requests
 
167
  yield format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
168
 
169
 
170
+ def handle_chat_submit(message, history, system_msg, model_name, max_tokens, temperature, top_p, hf_token: gr.OAuthToken = None):
171
  """
172
  Handle chat submission and manage conversation history with streaming.
173
  """
 
175
  yield history, ""
176
  return
177
 
178
+ # Enforce org-based access control via HF OAuth token
179
+ access_token = getattr(hf_token, "token", None) if hf_token is not None else None
180
+ is_allowed, access_msg, _username, _matched = check_org_access(access_token)
181
+ if not is_allowed:
182
+ # Show access denied as assistant message
183
+ assistant_response = format_access_denied_message(access_msg)
184
+ current_history = history + [{"role": "assistant", "content": assistant_response}]
185
+ yield current_history, ""
186
+ return
187
+
188
  # Add user message to history
189
  history = history + [{"role": "user", "content": message}]
190
 
 
208
  yield current_history, ""
209
 
210
 
211
+ def handle_chat_retry(history, system_msg, model_name, max_tokens, temperature, top_p, hf_token: gr.OAuthToken = None, retry_data=None):
212
  """
213
  Retry the assistant response for the selected message.
214
  Works with gr.Chatbot.retry() which provides retry_data.index for the message.
215
  """
216
+ # Enforce org-based access control via HF OAuth token
217
+ access_token = getattr(hf_token, "token", None) if hf_token is not None else None
218
+ is_allowed, access_msg, _username, _matched = check_org_access(access_token)
219
+ if not is_allowed:
220
+ # Show access denied as assistant message
221
+ assistant_response = format_access_denied_message(access_msg)
222
+ current_history = (history or []) + [{"role": "assistant", "content": assistant_response}]
223
+ yield current_history
224
+ return
225
  # Guard: empty history
226
  if not history:
227
  yield history
image_handler.py CHANGED
@@ -4,6 +4,7 @@ Handles text-to-image generation with multiple providers.
4
  """
5
 
6
  import os
 
7
  import time
8
  import threading
9
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
@@ -15,7 +16,9 @@ from utils import (
15
  IMAGE_CONFIG,
16
  validate_proxy_key,
17
  format_error_message,
18
- format_success_message
 
 
19
  )
20
 
21
  # Timeout configuration for image generation
@@ -276,7 +279,7 @@ def generate_image_to_image(
276
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
277
 
278
 
279
- def handle_image_to_image_generation(input_image_val, prompt_val, model_val, provider_val, negative_prompt_val, steps_val, guidance_val, seed_val):
280
  """
281
  Handle image-to-image generation request with validation.
282
  """
@@ -284,6 +287,12 @@ def handle_image_to_image_generation(input_image_val, prompt_val, model_val, pro
284
  if input_image_val is None:
285
  return None, format_error_message("Validation Error", "Please upload an input image")
286
 
 
 
 
 
 
 
287
  # Generate image-to-image
288
  return generate_image_to_image(
289
  input_image=input_image_val,
@@ -297,7 +306,7 @@ def handle_image_to_image_generation(input_image_val, prompt_val, model_val, pro
297
  )
298
 
299
 
300
- def handle_image_generation(prompt_val, model_val, provider_val, negative_prompt_val, width_val, height_val, steps_val, guidance_val, seed_val):
301
  """
302
  Handle image generation request with validation.
303
  """
@@ -306,6 +315,12 @@ def handle_image_generation(prompt_val, model_val, provider_val, negative_prompt
306
  if not is_valid:
307
  return None, format_error_message("Validation Error", error_msg)
308
 
 
 
 
 
 
 
309
  # Generate image
310
  return generate_image(
311
  prompt=prompt_val,
 
4
  """
5
 
6
  import os
7
+ import gradio as gr
8
  import time
9
  import threading
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
 
16
  IMAGE_CONFIG,
17
  validate_proxy_key,
18
  format_error_message,
19
+ format_success_message,
20
+ check_org_access,
21
+ format_access_denied_message
22
  )
23
 
24
  # Timeout configuration for image generation
 
279
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
280
 
281
 
282
+ def handle_image_to_image_generation(input_image_val, prompt_val, model_val, provider_val, negative_prompt_val, steps_val, guidance_val, seed_val, hf_token: gr.OAuthToken = None):
283
  """
284
  Handle image-to-image generation request with validation.
285
  """
 
287
  if input_image_val is None:
288
  return None, format_error_message("Validation Error", "Please upload an input image")
289
 
290
+ # Enforce org-based access control via HF OAuth token
291
+ access_token = getattr(hf_token, "token", None) if hf_token is not None else None
292
+ is_allowed, access_msg, _username, _matched = check_org_access(access_token)
293
+ if not is_allowed:
294
+ return None, format_access_denied_message(access_msg)
295
+
296
  # Generate image-to-image
297
  return generate_image_to_image(
298
  input_image=input_image_val,
 
306
  )
307
 
308
 
309
+ def handle_image_generation(prompt_val, model_val, provider_val, negative_prompt_val, width_val, height_val, steps_val, guidance_val, seed_val, hf_token: gr.OAuthToken = None):
310
  """
311
  Handle image generation request with validation.
312
  """
 
315
  if not is_valid:
316
  return None, format_error_message("Validation Error", error_msg)
317
 
318
+ # Enforce org-based access control via HF OAuth token
319
+ access_token = getattr(hf_token, "token", None) if hf_token is not None else None
320
+ is_allowed, access_msg, _username, _matched = check_org_access(access_token)
321
+ if not is_allowed:
322
+ return None, format_access_denied_message(access_msg)
323
+
324
  # Generate image
325
  return generate_image(
326
  prompt=prompt_val,
tts_handler.py CHANGED
@@ -4,6 +4,7 @@ Handles text-to-speech generation with multiple providers.
4
  """
5
 
6
  import os
 
7
  import time
8
  import threading
9
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
@@ -16,7 +17,9 @@ from utils import (
16
  validate_proxy_key,
17
  format_error_message,
18
  format_success_message,
19
- TTS_MODEL_CONFIGS
 
 
20
  )
21
 
22
  # Timeout configuration for TTS generation
@@ -153,7 +156,7 @@ def generate_text_to_speech(
153
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
154
 
155
 
156
- def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_val, speed_val, audio_url_val, exaggeration_val, temperature_val, cfg_val):
157
  """
158
  Handle text-to-speech generation request with validation.
159
  """
@@ -165,6 +168,12 @@ def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_va
165
  if len(text_val) > 5000:
166
  return None, format_error_message("Validation Error", "Text is too long. Please keep it under 5000 characters.")
167
 
 
 
 
 
 
 
168
  # Generate speech
169
  return generate_text_to_speech(
170
  text=text_val.strip(),
 
4
  """
5
 
6
  import os
7
+ import gradio as gr
8
  import time
9
  import threading
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
 
17
  validate_proxy_key,
18
  format_error_message,
19
  format_success_message,
20
+ TTS_MODEL_CONFIGS,
21
+ check_org_access,
22
+ format_access_denied_message
23
  )
24
 
25
  # Timeout configuration for TTS generation
 
156
  return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
157
 
158
 
159
+ def handle_text_to_speech_generation(text_val, model_val, provider_val, voice_val, speed_val, audio_url_val, exaggeration_val, temperature_val, cfg_val, hf_token: gr.OAuthToken = None):
160
  """
161
  Handle text-to-speech generation request with validation.
162
  """
 
168
  if len(text_val) > 5000:
169
  return None, format_error_message("Validation Error", "Text is too long. Please keep it under 5000 characters.")
170
 
171
+ # Enforce org-based access control via HF OAuth token
172
+ access_token = getattr(hf_token, "token", None) if hf_token is not None else None
173
+ is_allowed, access_msg, _username, _matched = check_org_access(access_token)
174
+ if not is_allowed:
175
+ return None, format_access_denied_message(access_msg)
176
+
177
  # Generate speech
178
  return generate_text_to_speech(
179
  text=text_val.strip(),
utils.py CHANGED
@@ -4,6 +4,7 @@ Contains configuration constants and helper functions.
4
  """
5
 
6
  import os
 
7
 
8
 
9
  # Configuration constants
@@ -226,3 +227,66 @@ def get_gradio_theme():
226
  return gr.themes.Soft()
227
  except ImportError:
228
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import os
7
+ import requests
8
 
9
 
10
  # Configuration constants
 
227
  return gr.themes.Soft()
228
  except ImportError:
229
  return None
230
+
231
+
232
+ # -----------------------------
233
+ # OAuth / Org Access Utilities
234
+ # -----------------------------
235
+
236
+ def _parse_allowed_orgs() -> list[str]:
237
+ """Parse comma/space separated ALLOWED_ORGS env var into a list of lowercase names."""
238
+ raw = os.getenv("ALLOWED_ORGS", "").strip()
239
+ if not raw:
240
+ return []
241
+ # support comma or whitespace separated
242
+ parts = [p.strip().lower() for p in raw.replace("\n", ",").replace(" ", ",").split(",") if p.strip()]
243
+ return list(dict.fromkeys(parts)) # dedupe while preserving order
244
+
245
+
246
+ def fetch_hf_identity(access_token: str) -> tuple[bool, dict | None, str]:
247
+ """
248
+ Call whoami-v2 to get user identity and orgs.
249
+ Returns (success, data, error_message).
250
+ """
251
+ if not access_token:
252
+ return False, None, "Missing access token"
253
+ try:
254
+ resp = requests.get(
255
+ "https://huggingface.co/api/whoami-v2",
256
+ headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
257
+ timeout=15,
258
+ )
259
+ if resp.status_code != 200:
260
+ return False, None, f"HF whoami-v2 HTTP {resp.status_code}"
261
+ return True, resp.json(), ""
262
+ except requests.exceptions.RequestException as e:
263
+ return False, None, f"HF whoami-v2 error: {str(e)}"
264
+
265
+
266
+ def check_org_access(access_token: str) -> tuple[bool, str, str | None, list[str]]:
267
+ """
268
+ Validate that the logged-in user belongs to any of ALLOWED_ORGS.
269
+ Returns (is_allowed, message, username, matched_orgs).
270
+ """
271
+ allowed_orgs = _parse_allowed_orgs()
272
+ if not access_token:
273
+ return False, "πŸ”’ Please log in with Hugging Face to continue.", None, []
274
+ if not allowed_orgs:
275
+ return False, "❌ Access denied: ALLOWED_ORGS is not configured in Space secrets.", None, []
276
+
277
+ ok, data, err = fetch_hf_identity(access_token)
278
+ if not ok or not data:
279
+ return False, f"❌ Failed to verify identity: {err}", None, []
280
+
281
+ username = data.get("name") or data.get("fullname") or data.get("id")
282
+ org_objs = data.get("orgs", []) or []
283
+ user_org_names = [str(org.get("name", "")).lower() for org in org_objs if org.get("name")]
284
+ matched = sorted(list(set(user_org_names).intersection(set(allowed_orgs))))
285
+ if matched:
286
+ return True, f"βœ… Access granted for @{username} in org(s): {', '.join(matched)}", username, matched
287
+ return False, f"🚫 Access denied for @{username}. Required org(s): {', '.join(allowed_orgs)}", username, []
288
+
289
+
290
+ def format_access_denied_message(message: str) -> str:
291
+ """Return a standardized access denied message for UI display."""
292
+ return format_error_message("Access Denied", message)