# Author: AliHamza852
# Commit 637717f (verified): "Create app.py"
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
import gradio as gr
# ============================================================
# CONFIG
# ============================================================
MODEL_NAME = "t5-small"  # base checkpoint: supplies the tokenizer and the architecture
WEIGHTS_PATH = "t5_weights.pt"  # your uploaded fine-tuned state_dict (relative to the working directory)
# ============================================================
# LOAD MODEL & TOKENIZER
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)

# Build the base architecture, then overwrite its parameters with the fine-tuned ones.
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
# - map_location="cpu": the model's parameters are still on the CPU here, so
#   mapping the checkpoint straight to CUDA would only create a throwaway GPU
#   copy that load_state_dict immediately copies back to the CPU.
# - weights_only=True (torch >= 1.13): restricts unpickling to tensors and
#   plain containers, so a tampered weights file cannot run arbitrary code.
model.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
)
model.to(device)
model.eval()  # inference mode for layers such as dropout
# ============================================================
# SUMMARIZATION FUNCTION
# ============================================================
def summarize(text):
    """Summarize a news article with the fine-tuned T5 model.

    Args:
        text: Article text from the Gradio textbox. ``None`` and
            whitespace-only input are treated as empty.

    Returns:
        The decoded summary string, or a warning message when *text* is empty.
    """
    # Guard against None as well as blank input (presumably possible via API
    # calls that send no value); calling .strip() on None would raise.
    if not text or not text.strip():
        return "⚠️ Please enter some text."

    # T5 picks the task from a text prefix; input beyond 512 tokens is
    # truncated. No padding: a single sequence needs none, and padding to
    # max_length made the encoder process 512 tokens on every request.
    inputs = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    summary_ids = model.generate(
        **inputs,
        max_new_tokens=150,  # upper bound on summary length
        num_beams=4,  # beam search instead of greedy decoding
        early_stopping=True,  # finish once num_beams complete candidates exist
        no_repeat_ngram_size=3,  # forbid repeating any trigram in the output
    )
    # Batch of one: decode the top sequence and strip <pad>/</s> tokens.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# ============================================================
# GRADIO UI
# ============================================================
demo = gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=10, label="Enter News Article"),
    outputs=gr.Textbox(lines=8, label="Generated Summary"),
    title="📰 T5 News Summarizer",
    # The weights are loaded once at startup from WEIGHTS_PATH and the UI has
    # no upload control, so describe only what users can actually do here.
    description="Paste a news article to generate a concise summary with a fine-tuned T5 model.",
    # One inner list per example, holding one value per input component.
    examples=[
        ["The US economy grew at an annual rate of 2.4% last quarter, surprising analysts..."],
        ["Scientists have discovered a new species of bird in the Amazon rainforest..."],
    ],
)
# ============================================================
# RUN APP
# ============================================================
# Launch the server only when run as a script; importing the module just builds `demo`.
if __name__ == "__main__":
    demo.launch()