Spaces:
Sleeping
Sleeping
| # # import streamlit as st | |
| # # import torch | |
| # # from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| # # # Load the model and tokenizer | |
| # # # @st.cache_resource | |
| # # # def load_model(): | |
| # # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small') | |
| # # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
| # # # model.eval() | |
| # # # return tokenizer, model | |
| # # @st.cache_resource | |
| # # def load_model(): | |
| # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False) | |
| # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
| # # model.eval() | |
| # # return tokenizer, model | |
| # # def predict_news(text, tokenizer, model): | |
| # # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) | |
| # # with torch.no_grad(): | |
| # # outputs = model(**inputs) | |
| # # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| # # predicted_label = torch.argmax(probabilities, dim=-1).item() | |
| # # confidence = probabilities[0][predicted_label].item() | |
| # # return "FAKE" if predicted_label == 1 else "REAL", confidence | |
| # # def main(): | |
| # # st.title("News Classifier") | |
| # # # Load model | |
| # # tokenizer, model = load_model() | |
| # # # Text input | |
| # # news_text = st.text_area("Enter news text to analyze:", height=200) | |
| # # if st.button("Classify"): | |
| # # if news_text: | |
| # # with st.spinner('Analyzing...'): | |
| # # prediction, confidence = predict_news(news_text, tokenizer, model) | |
| # # # Display results | |
| # # if prediction == "FAKE": | |
| # # st.error(f"⚠️ {prediction} NEWS") | |
| # # else: | |
| # # st.success(f"✅ {prediction} NEWS") | |
| # # st.info(f"Confidence: {confidence*100:.2f}%") | |
| # # if __name__ == "__main__": | |
| # # main() | |
import sys
from functools import lru_cache
from threading import Thread

import streamlit as st
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from streamlit import runtime
from streamlit.web import cli
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# FastAPI app
# Module-level FastAPI application object; `start_fastapi` serves it with
# uvicorn from a background thread.
api_app = FastAPI()
# Load the model and tokenizer
@lru_cache(maxsize=1)
def load_model():
    """Load the tokenizer and fine-tuned classifier once per process.

    The tokenizer is the slow (``use_fast=False``) tokenizer of the
    ``microsoft/deberta-v3-small`` checkpoint. The classification weights
    come from the local ``./results/checkpoint-753`` directory. The model is
    switched to eval mode before it is returned.

    The pair is memoised with ``lru_cache``. Both callers share one copy of
    the weights instead of re-reading them from disk each time:
    - the Streamlit page calls this on every script rerun;
    - the API handler calls it on every request.
    Two threads racing on the very first call may each load once; every
    later call returns the cached pair.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small", use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
    model.eval()
    return tokenizer, model
# Prediction function
def predict_news(text, tokenizer, model):
    """Classify one piece of news text with a sequence-classification model.

    Args:
        text: The news text to classify.
        tokenizer: Tokenizer callable. It is invoked with
            ``return_tensors="pt"``, truncation and padding enabled, and
            ``max_length=512``.
        model: Classifier whose call result exposes ``.logits``.

    Returns:
        tuple: ``(label, confidence)``.
        - ``label`` is ``"FAKE"`` when class index 1 has the highest
          probability, otherwise ``"REAL"``.
        - ``confidence`` is the softmax probability of that winning class,
          as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits
    class_probs = logits.softmax(dim=-1)
    best_idx = class_probs.argmax(dim=-1).item()
    best_prob = class_probs[0, best_idx].item()
    # Class index 1 is treated as FAKE; any other index maps to REAL.
    if best_idx == 1:
        return "FAKE", best_prob
    return "REAL", best_prob
# FastAPI request model
class NewsInput(BaseModel):
    """Request body for the news classification endpoint.

    Attributes:
        text: The raw news text to classify.
    """

    text: str
# FastAPI route for POST requests
@api_app.post("/predict")
async def classify_news(data: NewsInput):
    """Classify ``data.text`` as FAKE or REAL news.

    Registered as ``POST /predict`` on ``api_app``. Without the decorator the
    function was never attached to the app, so the API exposed no route.

    Args:
        data: Request body carrying the text to classify.

    Returns:
        dict: ``{"prediction": "FAKE" | "REAL", "confidence": "NN.NN%"}``.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only.
    """
    if not data.text.strip():
        raise HTTPException(status_code=400, detail="Field 'text' must not be empty.")

    # Model loading and the forward pass are blocking CPU/disk work. Running
    # them in the threadpool keeps this async endpoint from stalling the
    # event loop for every other request.
    tokenizer, model = await run_in_threadpool(load_model)
    prediction, confidence = await run_in_threadpool(
        predict_news, data.text, tokenizer, model
    )
    return {
        "prediction": prediction,
        "confidence": f"{confidence*100:.2f}%",
    }
# Streamlit app
def run_streamlit():
    """Render the News Classifier Streamlit page for the current script run.

    The page shows a text area and a "Classify" button. On click:
    - the text is classified with ``predict_news``;
    - the label is shown as an error banner (FAKE) or a success banner
      (REAL), followed by the confidence percentage.
    Clicking with empty or whitespace-only text shows a warning instead of
    silently doing nothing.
    """
    st.title("News Classifier")

    # Load model
    tokenizer, model = load_model()

    # Text input
    news_text = st.text_area("Enter news text to analyze:", height=200)
    if not st.button("Classify"):
        return
    if not news_text.strip():
        # Give feedback rather than ignoring the click.
        st.warning("Please enter some news text to classify.")
        return

    with st.spinner('Analyzing...'):
        prediction, confidence = predict_news(news_text, tokenizer, model)

    # Display results
    if prediction == "FAKE":
        st.error(f"⚠️ {prediction} NEWS")
    else:
        st.success(f"✅ {prediction} NEWS")
    st.info(f"Confidence: {confidence*100:.2f}%")
# Threaded execution for FastAPI and Streamlit
def start_fastapi(host="0.0.0.0", port=8502):
    """Serve ``api_app`` with uvicorn. This call blocks while the server runs.

    It is meant to be used as a ``Thread`` target so the API runs alongside
    the Streamlit UI.

    Args:
        host: Interface to bind. Defaults to ``"0.0.0.0"``, the previously
            hard-coded value.
        port: TCP port to listen on. Defaults to ``8502``, the previously
            hard-coded value.
    """
    import uvicorn  # imported lazily: only needed when the API server is started

    uvicorn.run(api_app, host=host, port=port)
@st.cache_resource
def _start_fastapi_once():
    """Start the FastAPI server in a background daemon thread, once per process.

    ``streamlit run`` re-executes this script on every rerun, i.e. every
    widget interaction. Caching the thread with ``st.cache_resource`` stops
    later reruns from launching another uvicorn server on the same port.
    """
    api_thread = Thread(target=start_fastapi, daemon=True)
    api_thread.start()
    return api_thread


if __name__ == "__main__":
    if runtime.exists():
        # Running under `streamlit run <this file>`. Streamlit executes the
        # script as __main__ on every rerun, so render the UI here. The
        # original never called run_streamlit(), leaving the page blank.
        _start_fastapi_once()
        run_streamlit()
    else:
        # Running as `python <this file>`. Relaunch through the Streamlit
        # CLI, which re-executes this script inside the Streamlit runtime
        # (the branch above). A bare cli.main() only printed CLI usage.
        sys.argv = ["streamlit", "run", sys.argv[0]]
        sys.exit(cli.main())
| # from fastapi import FastAPI, HTTPException | |
| # from pydantic import BaseModel | |
| # from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| # import torch | |
| # from fastapi.middleware.cors import CORSMiddleware | |
| # # Define the FastAPI app | |
| # app = FastAPI() | |
| # app.add_middleware( | |
| # CORSMiddleware, | |
| # allow_origins=["*"], # Update with your frontend's URL for security | |
| # allow_credentials=True, | |
| # allow_methods=["*"], | |
| # allow_headers=["*"], | |
| # ) | |
| # # Define the input data schema | |
| # class InputText(BaseModel): | |
| # text: str | |
| # # Load the model and tokenizer (ensure these paths are correct in your Space) | |
| # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False) | |
| # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
| # model.eval() | |
| # # Prediction function | |
| # def predict_news(text: str): | |
| # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) | |
| # with torch.no_grad(): | |
| # outputs = model(**inputs) | |
| # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
| # predicted_label = torch.argmax(probabilities, dim=-1).item() | |
| # confidence = probabilities[0][predicted_label].item() | |
| # return { | |
| # "prediction": "FAKE" if predicted_label == 1 else "REAL", | |
| # "confidence": round(confidence * 100, 2) # Return confidence as a percentage | |
| # } | |
| # # Define the POST endpoint | |
| # @app.post("/predict") | |
| # async def classify_news(input_text: InputText): | |
| # try: | |
| # result = predict_news(input_text.text) | |
| # return result | |
| # except Exception as e: | |
| # raise HTTPException(status_code=500, detail=str(e)) | |