# Nexus_NLP_model/app.py
import sys
from threading import Thread

import streamlit as st
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from streamlit import runtime
from streamlit.web import cli
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# FastAPI app served alongside the Streamlit UI
api_app = FastAPI()

# Load the model and tokenizer (cached so they load once per process)
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small", use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
    model.eval()
    return tokenizer, model
# Prediction function: returns a ("FAKE"/"REAL", confidence) pair
def predict_news(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_label].item()
    # Label index 1 is taken to mean FAKE, matching the checkpoint's training labels
    return "FAKE" if predicted_label == 1 else "REAL", confidence
# FastAPI request model
class NewsInput(BaseModel):
    text: str

# FastAPI route for POST requests
@api_app.post("/classify")
async def classify_news(data: NewsInput):
    tokenizer, model = load_model()
    prediction, confidence = predict_news(data.text, tokenizer, model)
    return {
        "prediction": prediction,
        "confidence": f"{confidence*100:.2f}%",
    }
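# Example request against the endpoint above (a sketch assuming the defaults in
# this file: host 0.0.0.0, port 8502; the article text is a placeholder):
#   curl -X POST http://localhost:8502/classify \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Your news article text here"}'
# Response shape: {"prediction": "FAKE" | "REAL", "confidence": "NN.NN%"}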
# Streamlit app
def run_streamlit():
    st.title("News Classifier")
    # Load model
    tokenizer, model = load_model()
    # Text input
    news_text = st.text_area("Enter news text to analyze:", height=200)
    if st.button("Classify"):
        if news_text:
            with st.spinner("Analyzing..."):
                prediction, confidence = predict_news(news_text, tokenizer, model)
            # Display results
            if prediction == "FAKE":
                st.error(f"⚠️ {prediction} NEWS")
            else:
                st.success(f"✅ {prediction} NEWS")
            st.info(f"Confidence: {confidence*100:.2f}%")
# Threaded execution for FastAPI and Streamlit
def start_fastapi():
    import uvicorn
    uvicorn.run(api_app, host="0.0.0.0", port=8502)

if runtime.exists():
    # The script is being executed by the Streamlit runtime: render the UI.
    run_streamlit()
elif __name__ == "__main__":
    # Launched as `python app.py`: start the API in a daemon thread, then
    # boot this same file under the Streamlit CLI.
    Thread(target=start_fastapi, daemon=True).start()
    sys.argv = ["streamlit", "run", __file__]
    cli.main()
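# Launch notes (assuming Streamlit >= 1.12, where streamlit.runtime.exists()
# is available):
#   python app.py        -> FastAPI on port 8502 plus the Streamlit UI
#   streamlit run app.py -> UI only; the /classify API thread is not started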
# --- Alternative: standalone FastAPI-only version (kept for reference) ---
# from fastapi import FastAPI, HTTPException
# from pydantic import BaseModel
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# import torch
# from fastapi.middleware.cors import CORSMiddleware
#
# # Define the FastAPI app
# app = FastAPI()
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # Update with your frontend's URL for security
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
#
# # Define the input data schema
# class InputText(BaseModel):
#     text: str
#
# # Load the model and tokenizer (ensure these paths are correct in your Space)
# tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
# model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
# model.eval()
#
# # Prediction function
# def predict_news(text: str):
#     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
#     with torch.no_grad():
#         outputs = model(**inputs)
#     probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
#     predicted_label = torch.argmax(probabilities, dim=-1).item()
#     confidence = probabilities[0][predicted_label].item()
#     return {
#         "prediction": "FAKE" if predicted_label == 1 else "REAL",
#         "confidence": round(confidence * 100, 2)  # Return confidence as a percentage
#     }
#
# # Define the POST endpoint
# @app.post("/predict")
# async def classify_news(input_text: InputText):
#     try:
#         result = predict_news(input_text.text)
#         return result
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))
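# To serve this alternative on its own, move it uncommented into its own
# module and run it with uvicorn (module name and port here are assumptions,
# not taken from this repo):
#   uvicorn app:app --host 0.0.0.0 --port 8000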