Spaces:
Build error
Build error
| import gradio as gr | |
| import openai | |
| import os | |
| import requests | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| import torch | |
| import faiss | |
| import numpy as np | |
| import json | |
def clean_payload(payload):
    """Parse a line with an optional ``data:`` prefix into a JSON value.

    Presumably used for server-sent-event style lines such as
    ``data: {"key": 1}`` -- confirm against the caller.

    Args:
        payload: The raw line. A leading literal ``data:`` is removed once;
            trailing newline characters are removed.

    Returns:
        The decoded JSON value, or ``None`` if the remaining text is not valid
        JSON (the decode error is printed).

    NOTE(review): a later module-level ``def clean_payload`` rebinds this name,
    so this version is shadowed at import time -- consider removing one of them.
    """
    prefix = "data:"
    # Remove the literal prefix exactly once. str.lstrip("data:") strips any
    # leading run of the characters d, a, t and ':' and would turn "data:true"
    # into "rue".
    cleaned_payload = payload[len(prefix):] if payload.startswith(prefix) else payload
    cleaned_payload = cleaned_payload.rstrip("\n")
    try:
        json_payload = json.loads(cleaned_payload)
    except json.JSONDecodeError as e:
        print(f"JSON decoding error: {e}")
        json_payload = None
    return json_payload
| from huggingface_hub import InferenceClient # Keeping Hugging Face Client as requested | |
def clean_payload(payload):
    """Strip an optional ``data:`` prefix and trailing newlines, then parse JSON.

    NOTE(review): this definition rebinds the ``clean_payload`` defined earlier
    in the module, so it is the one in effect. Unlike the earlier version, it
    lets decode errors propagate instead of returning ``None``.

    Args:
        payload: The raw line. A leading literal ``data:`` is removed once.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If the remaining text is not valid JSON.
    """
    prefix = "data:"
    # Remove the literal prefix once; str.lstrip("data:") strips a character
    # set, not a prefix, and would mangle payloads such as "data:true".
    cleaned_payload = payload[len(prefix):] if payload.startswith(prefix) else payload
    return json.loads(cleaned_payload.rstrip("\n"))
# Credentials come from the environment so no secrets live in the source;
# os.getenv yields None for any variable that is not set.
openai.organization = os.getenv("OPENAI_ORG_ID")
openai.api_key = os.getenv("OPENAI_API_KEY")
# Serper search API key. NOTE(review): not referenced elsewhere in the visible code.
serper_api_key = os.getenv("SERPER_API_KEY")
# Tokenizer and two-label sequence classifier, both loaded from the same
# PubMedBERT checkpoint on the Hugging Face Hub.
# NOTE(review): unless this checkpoint ships a trained 2-label classification
# head, from_pretrained initialises that head fresh, and the labels produced
# by handle_fda_query are not meaningful -- confirm.
tokenizer = BertTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
model = BertForSequenceClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    num_labels=2,
)
# Vector-search index: FAISS IndexFlatL2 performs exact (brute-force) L2 search.
# The width is presumably chosen to match the 768-dim hidden size of the
# BERT-base model loaded above -- confirm if the checkpoint changes.
# NOTE(review): nothing in the visible code adds vectors to this index.
dimension: int = 768
index = faiss.IndexFlatL2(dimension)
def embed_text(text):
    """Embed *text* as the attention-masked mean of the model's last hidden layer.

    Args:
        text: A string (or list of strings) accepted by the module-level
            ``tokenizer``. Inputs are truncated to 512 tokens.

    Returns:
        A ``numpy.ndarray`` with one row per input text; each row is the mean
        of the last-layer hidden states over that text's real (non-padding)
        tokens.
    """
    # padding=True pads only to the longest input; padded positions are masked
    # out of the average below, so the padding length does not affect the result.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)
    # An unmasked mean over the sequence would let padding vectors dilute (or,
    # with max_length padding, dominate) the embedding; average real tokens only.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()
def handle_fda_query(query):
    """Classify *query* with the module-level PubMedBERT sequence classifier.

    Args:
        query: The query text; truncated/padded to 512 tokens.

    Returns:
        ``"FDA Query Processed: Contains regulatory info."`` when the arg-max
        label is 1, otherwise ``"FDA Query Processed: General."``.

    NOTE(review): the label is only meaningful if the classifier head was
    actually trained for this task -- confirm the checkpoint.
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = torch.argmax(logits, dim=1).item()
    if predicted_label == 1:
        return "FDA Query Processed: Contains regulatory info."
    return "FDA Query Processed: General."
def enhance_with_gpt4o(fda_response):
    """Ask GPT-4o-mini to expand on *fda_response*.

    Args:
        fda_response: Text interpolated into the user prompt.

    Returns:
        The reply text of the first choice (at most 150 tokens requested), or
        ``"Error: <message>"`` if building the request, the call, or reading
        the response raises.

    NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 SDK interface;
    confirm the pinned ``openai`` version supports it.
    """
    try:
        conversation = [
            {"role": "system", "content": "You are an expert FDA assistant."},
            {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
        ]
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=conversation,
            max_tokens=150,
        )
        first_choice = completion['choices'][0]
        return first_choice['message']['content']
    except Exception as exc:
        return f"Error: {str(exc)}"
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer an FDA query: classify it with PubMedBERT, then expand it with GPT-4o-mini.

    Args:
        message: The user's query text.
        system_message: System prompt for the chat model. An empty or
            whitespace-only value falls back to a default FDA-assistant prompt.
        max_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling threshold forwarded to the API.

    Returns:
        A string containing the PubMedBERT label and GPT-4o-mini's reply, or
        ``"Error: <message>"`` if any step raises. This is the UI boundary, so
        errors are shown to the user instead of propagating.
    """
    default_system_prompt = "You are an expert FDA assistant."
    try:
        # First classify the query via PubMedBERT.
        fda_response = handle_fda_query(message)
        # Honour the caller-supplied system prompt instead of silently ignoring it.
        if system_message and system_message.strip():
            system_prompt = system_message
        else:
            system_prompt = default_system_prompt
        response = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                # Forward the user's actual question: handle_fda_query returns one
                # of two fixed label strings that carry none of the query's content.
                {"role": "user", "content": f"User query: {message}\n\nEnhance this FDA info: {fda_response}"},
            ],
            # The API expects an integer; UI sliders may deliver a float.
            max_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        )
        # One completion is requested (no `n`), so read the first choice; guard
        # against a None content field.
        enhanced_response = response['choices'][0]['message']['content'] or ""
        # Return both the PubMedBERT result and the enhanced version.
        return f"Original Info from PubMedBERT: {fda_response}\n\nEnhanced Info via GPT-4o-mini: {enhanced_response}"
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI. The five input components are passed positionally to `respond`:
# query, system message, max tokens, temperature, top-p.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related."),
        gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # The OpenAI Chat Completions API accepts temperature only in [0, 2];
        # slider values above 2 would be rejected by the API.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()