"""Gradio demo that summarizes news articles with a fine-tuned T5 model.

The ``t5-small`` architecture and tokenizer are loaded via ``from_pretrained``,
then the model's parameters are replaced with the state dict stored at
``WEIGHTS_PATH``.
"""

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast

# ============================================================
# CONFIG
# ============================================================
MODEL_NAME = "t5-small"
WEIGHTS_PATH = "t5_weights.pt"  # your uploaded weights file

# ============================================================
# LOAD MODEL & TOKENIZER
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# The freshly built model is still on the CPU here, so map the saved tensors to
# the CPU too; `.to(device)` below moves everything in one step. Mapping them to
# `device` directly would stage a redundant GPU copy of the weights.
# `weights_only=True` restricts unpickling to tensors and plain containers, so a
# tampered weights file cannot run arbitrary code during loading.
model.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
)
model.to(device)
model.eval()  # inference mode for layers such as dropout


# ============================================================
# SUMMARIZATION FUNCTION
# ============================================================
def summarize(text: str) -> str:
    """Return a generated summary of ``text``.

    Args:
        text: The article to summarize. After the ``"summarize: "`` task
            prefix is prepended, input beyond 512 tokens is truncated.

    Returns:
        The decoded summary, or a warning message when ``text`` is ``None``,
        empty, or whitespace-only.
    """
    if not text or not text.strip():
        return "⚠️ Please enter some text."

    # T5 picks its task from a text prefix. A single input needs no padding,
    # so only truncate; padding to 512 tokens would just add masked work.
    inputs = tokenizer(
        "summarize: " + text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)

    summary_ids = model.generate(
        **inputs,
        max_new_tokens=150,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,  # forbid any trigram from appearing twice
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


# ============================================================
# GRADIO UI
# ============================================================
demo = gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=10, label="Enter News Article"),
    outputs=gr.Textbox(lines=8, label="Generated Summary"),
    title="📰 T5 News Summarizer",
    description="Paste a news article to get a short summary from a fine-tuned T5 model.",
    examples=[
        ["The US economy grew at an annual rate of 2.4% last quarter, surprising analysts..."],
        ["Scientists have discovered a new species of bird in the Amazon rainforest..."],
    ],
)

# ============================================================
# RUN APP
# ============================================================
if __name__ == "__main__":
    demo.launch()