import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
    BitsAndBytesConfig,
)
from peft import PeftModel
# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "NousResearch/Llama-2-7b-chat-hf"
ADAPTER = "Suramya/Llama-2-7b-CloudLex-Intent-Detection"
NUM_LABELS = 6  # MUST match training (Buying, Support, Careers, Partnership, Explore, Others)
LABEL_NAMES = [
    "Buying",
    "Support",
    "Careers",
    "Partnership",
    "Explore",
    "Others",
]
# ============================================================
# Quantization config (replaces the deprecated load_in_4bit kwarg)
# ============================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
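# Note: 4-bit loading via bitsandbytes requires a CUDA GPU. If this Space
# might run on CPU hardware, one possible fallback (an assumption, not part
# of the original app) is to skip quantization when no GPU is present:
#
#   load_kwargs = (
#       {"quantization_config": bnb_config, "device_map": "auto"}
#       if torch.cuda.is_available()
#       else {"torch_dtype": torch.float32}
#   )
#   # ...then pass **load_kwargs to from_pretrained below.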
# ============================================================
# Load base model + LoRA adapter
# ============================================================
base_model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=NUM_LABELS,  # must match the classification head used in training
    device_map="auto",
    quantization_config=bnb_config,
)
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER,
)
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)
tokenizer.pad_token = tokenizer.eos_token
# Llama-2 ships without a pad token; sequence classification also needs
# pad_token_id on the model config so it can locate the last non-pad token.
model.config.pad_token_id = tokenizer.pad_token_id
# Attach human-readable label names so the pipeline reports "Buying" etc.
# instead of generic "LABEL_0".."LABEL_5" ids.
model.config.id2label = dict(enumerate(LABEL_NAMES))
model.config.label2id = {name: i for i, name in enumerate(LABEL_NAMES)}
# ============================================================
# Pipeline
# ============================================================
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # return scores for every label (return_all_scores is deprecated)
)
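# No `device=` argument is passed above: the model was already placed by
# accelerate via device_map="auto", and the pipeline respects that placement.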
# ============================================================
# Inference function
# ============================================================
def predict_intent(message: str):
    if not message or not message.strip():
        return {}
    # With top_k=None the pipeline yields a list of {label, score} dicts
    # per input (sorted by score); some versions nest it one level deeper.
    outputs = clf(message)
    scores = outputs[0] if outputs and isinstance(outputs[0], list) else outputs
    # Labels are already human-readable thanks to config.id2label above.
    return {item["label"]: float(item["score"]) for item in scores}
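# Quick local sanity check (uncomment to run; scores are illustrative):
# print(predict_intent("My CloudLex account isn't loading properly"))
# -> {"Support": 0.91, "Others": 0.04, ...}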
# ============================================================
# Gradio UI
# ============================================================
demo = gr.Interface(
    fn=predict_intent,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Type a CloudLex-related message...",
    ),
    outputs=gr.Label(num_top_classes=NUM_LABELS),
    title="CloudLex Intent Detection",
    description=(
        "Llama-2-7B fine-tuned with QLoRA for CloudLex intent classification.\n\n"
        "Intents: Buying, Support, Careers, Partnership, Explore, Others"
    ),
    examples=[
        ["I'd like to schedule a demo for our law firm"],
        ["My CloudLex account isn't loading properly"],
        ["Are you hiring software engineers?"],
        ["We want to partner with CloudLex"],
        ["What features does CloudLex offer?"],
        ["Just browsing"],
    ],
)
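# The default launch() is sufficient on Hugging Face Spaces; a self-hosted
# container would typically need demo.launch(server_name="0.0.0.0") instead.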
demo.launch()