from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from lime.lime_text import LimeTextExplainer
import numpy as np

app = FastAPI(title="MedGuard API")

# --- CORS (wide open for debugging) ---
# Allow every origin, method, and header so CORS can be ruled out as the problem.
# Tighten this before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- CONFIGURATION ---
MODEL_PATH = "./model"
DEVICE = "cpu"

print(f"🔄 Loading Model from {MODEL_PATH}...")
model = None
tokenizer = None
try:
    # Prefer the locally fine-tuned checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()
    print("✅ Model Loaded Successfully!")
except Exception as e:
    # Fall back to the base checkpoint. Note that the classification head is
    # freshly initialized here, so its predictions are untrained.
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
    model.to(DEVICE)
    model.eval()


# --- DATA MODELS ---
class QueryRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    label: str
    confidence: float
    probs: dict
    explanation: Optional[list] = None


LABELS = ["Highly Relevant", "Partially Relevant", "Not Relevant"]


def predict_proba_lime(texts):
    """Batch prediction function for LIME: list of texts -> class probability matrix."""
    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=128
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    return F.softmax(outputs.logits, dim=-1).cpu().numpy()


@app.get("/")
def health_check():
    return {"status": "active", "model": "MedGuard v1.0"}


@app.post("/predict", response_model=PredictionResponse)
def predict(request: QueryRequest):
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        inputs = tokenizer(
            request.text, return_tensors="pt", truncation=True, max_length=128
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        pred_idx = int(np.argmax(probs))

        # LIME explanation, reduced to 20 perturbation samples for speed.
        # Raise num_samples for more stable explanations at the cost of latency.
        explainer = LimeTextExplainer(
            class_names=LABELS, split_expression=lambda x: x.split()
        )
        exp = explainer.explain_instance(
            request.text,
            predict_proba_lime,
            num_features=6,
            num_samples=20,
            labels=[pred_idx],
        )
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": LABELS[pred_idx],
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"Server Error: {e}")  # Log the error to the backend terminal
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    # Bind to localhost specifically
    uvicorn.run(app, host="127.0.0.1", port=8000)
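
# Example client call (a minimal sketch; it assumes the server is running on the
# host/port configured above and that the `requests` package is installed — it is
# not a dependency of this file):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/predict",
#       json={"text": "..."},  # the query string to classify
#   )
#   print(resp.json())  # label, confidence (%), per-class probs, LIME explanation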