# CLX_Finetuned / app.py
# Author: Suramya
# Last commit: 507d0d6 "Update app.py" (verified)
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
pipeline,
BitsAndBytesConfig
)
from peft import PeftModel
# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "NousResearch/Llama-2-7b-chat-hf"
ADAPTER = "Suramya/Llama-2-7b-CloudLex-Intent-Detection"
# Intent names in class-id order: LABEL_NAMES[i] names class i.
# This order MUST match the label ids used when the adapter was trained.
LABEL_NAMES = [
    "Buying",
    "Support",
    "Careers",
    "Partnership",
    "Explore",
    "Others",
]
# Derived from LABEL_NAMES (rather than hard-coded) so the classification-head
# size can never disagree with the label list. Must match training (6 classes).
NUM_LABELS = len(LABEL_NAMES)
# ============================================================
# 4-bit quantization settings (NF4, as used for QLoRA)
# ============================================================
# Handing a BitsAndBytesConfig to from_pretrained is the supported
# replacement for the deprecated bare `load_in_4bit=` argument.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4 bits
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization data type
    bnb_4bit_use_double_quant=True,        # nested quantization of the quant constants
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for the actual matmuls
)
# ============================================================
# Load tokenizer, model + LoRA adapter
# ============================================================
# The tokenizer is loaded first so its pad-token id can be copied into the
# model config below.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
# Reuse EOS as the pad token (the Llama-2 tokenizer defines none by default).
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=NUM_LABELS,  # CRITICAL: must match the trained classification head
    # Name the classes so pipeline outputs carry intent names, not LABEL_i.
    id2label=dict(enumerate(LABEL_NAMES)),
    label2id={name: idx for idx, name in enumerate(LABEL_NAMES)},
    device_map="auto",
    quantization_config=bnb_config,
)
# The sequence-classification head pools the hidden state of the last
# non-padding token, which it locates via config.pad_token_id. Keep that id in
# sync with the tokenizer's pad token so padded (batched) inputs are pooled at
# the correct position.
base_model.config.pad_token_id = tokenizer.pad_token_id

model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
)
# ============================================================
# Pipeline
# ============================================================
# All-class scores are requested per call (top_k=None in predict_intent),
# which replaces the deprecated `return_all_scores=True` constructor flag.
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
)
# ============================================================
# Inference function
# ============================================================
def predict_intent(message: str):
    """Score *message* against every CloudLex intent.

    Args:
        message: Free-text user message from the Gradio textbox.

    Returns:
        A dict mapping each intent name from ``LABEL_NAMES`` to its score,
        in the shape ``gr.Label`` expects. ``None``, empty, or
        whitespace-only input returns ``{}`` without running the model.
    """
    if not message or not message.strip():
        return {}
    # Passing top_k=None at call time makes the pipeline return one flat list
    # with a {"label", "score"} dict per class, sorted by descending score.
    # List position therefore says nothing about which class an entry is.
    outputs = clf(message, top_k=None)
    # Resolve each label string back to its class id via the model config,
    # then to the human-readable intent name.
    label2id = clf.model.config.label2id
    return {
        LABEL_NAMES[label2id[item["label"]]]: float(item["score"])
        for item in outputs
    }
# ============================================================
# Gradio UI
# ============================================================
# The class count and the intent list in the description are derived from
# LABEL_NAMES so the UI cannot drift from the model's label set.
demo = gr.Interface(
    fn=predict_intent,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Type a CloudLex-related message...",
    ),
    outputs=gr.Label(num_top_classes=len(LABEL_NAMES)),
    title="CloudLex Intent Detection",
    description=(
        "Llama-2-7B fine-tuned with QLoRA for CloudLex intent classification.\n\n"
        f"Intents: {', '.join(LABEL_NAMES)}"
    ),
    # One example message per intent, in LABEL_NAMES order.
    examples=[
        ["I'd like to schedule a demo for our law firm"],
        ["My CloudLex account isn't loading properly"],
        ["Are you hiring software engineers?"],
        ["We want to partner with CloudLex"],
        ["What features does CloudLex offer?"],
        ["Just browsing"],
    ],
)
demo.launch()