#!/usr/bin/env python3
"""
Textilindo AI Assistant - Hugging Face Spaces FastAPI Application
Simplified version for HF Spaces deployment
"""
import os
import asyncio
import logging
from pathlib import Path
from datetime import datetime
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Textilindo AI Assistant",
    description="AI Assistant for Textilindo textile company",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
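# NOTE: wildcard origins with credentials enabled is convenient for a public
# demo Space, but should be restricted to known origins in production.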
# Request/Response models
class ChatRequest(BaseModel):
    message: str
    conversation_id: Optional[str] = None

class ChatResponse(BaseModel):
    response: str
    conversation_id: str
    status: str = "success"

class HealthResponse(BaseModel):
    status: str
    message: str
    version: str = "1.0.0"

class TrainingRequest(BaseModel):
    model_name: str = "distilgpt2"
    dataset_path: str = "data/lora_dataset_20250910_145055.jsonl"
    config_path: str = "configs/training_config.yaml"
    max_samples: int = 20
    epochs: int = 1
    batch_size: int = 1
    learning_rate: float = 5e-5

class TrainingResponse(BaseModel):
    success: bool
    message: str
    training_id: str
    status: str

# Training status storage
training_status = {
    "is_training": False,
    "progress": 0,
    "status": "idle",
    "current_step": 0,
    "total_steps": 0,
    "loss": 0.0,
    "start_time": None,
    "end_time": None,
    "error": None
}
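# NOTE: this in-memory status dict is reset on every restart and is not shared
# across worker processes; it only suits a single-process demo deployment.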
class TextilindoAI:
    """Textilindo AI Assistant using HuggingFace Inference API"""

    def __init__(self):
        self.api_key = os.getenv('HUGGINGFACE_API_KEY')
        self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
        self.system_prompt = self.load_system_prompt()
        if not self.api_key:
            logger.warning("HUGGINGFACE_API_KEY not found. Using mock responses.")
            self.client = None
        else:
            try:
                from huggingface_hub import InferenceClient
                self.client = InferenceClient(
                    token=self.api_key,
                    model=self.model
                )
                logger.info(f"Initialized with model: {self.model}")
            except Exception as e:
                logger.error(f"Failed to initialize InferenceClient: {e}")
                self.client = None

    def load_system_prompt(self) -> str:
        """Load system prompt from config file"""
        try:
            prompt_path = Path("configs/system_prompt.md")
            if prompt_path.exists():
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Extract system prompt from markdown
                if 'SYSTEM_PROMPT = """' in content:
                    start = content.find('SYSTEM_PROMPT = """') + len('SYSTEM_PROMPT = """')
                    end = content.find('"""', start)
                    return content[start:end].strip()
                else:
                    # Fallback: use entire content
                    return content.strip()
            else:
                return self.get_default_system_prompt()
        except Exception as e:
            logger.error(f"Error loading system prompt: {e}")
            return self.get_default_system_prompt()
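    # configs/system_prompt.md is expected either to embed the prompt in a
    # Python-style block (SYSTEM_PROMPT = """...""") or to be plain text,
    # in which case the whole file is used verbatim.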
    def get_default_system_prompt(self) -> str:
        """Default system prompt if file not found"""
        return """You are a friendly and helpful AI assistant for Textilindo, a textile company.
Always respond in Indonesian (Bahasa Indonesia).
Keep responses short and direct.
Be friendly and helpful.
Use exact information from the knowledge base.
The company uses yards for sales.
Minimum purchase is 1 roll (67-70 yards)."""

    def generate_response(self, user_message: str) -> str:
        """Generate response using HuggingFace Inference API"""
        if not self.client:
            return self.get_mock_response(user_message)
        try:
            # Create full prompt with system prompt
            full_prompt = f"<|system|>\n{self.system_prompt}\n<|user|>\n{user_message}\n<|assistant|>\n"
            # Generate response
            response = self.client.text_generation(
                full_prompt,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                top_k=40,
                repetition_penalty=1.1,
                stop_sequences=["<|end|>", "<|user|>"]
            )
            # Extract only the assistant's response
            if "<|assistant|>" in response:
                assistant_response = response.split("<|assistant|>")[-1].strip()
                assistant_response = assistant_response.replace("<|end|>", "").strip()
                return assistant_response
            else:
                return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self.get_mock_response(user_message)

    def get_mock_response(self, user_message: str) -> str:
        """Mock responses for testing without API key"""
        mock_responses = {
            "dimana lokasi textilindo": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
            "jam berapa textilindo beroperasional": "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00.",
            "berapa ketentuan pembelian": "Minimal order 1 roll per jenis kain",
            "bagaimana dengan pembayarannya": "Pembayaran dapat dilakukan via transfer bank atau cash on delivery",
            "apa ada gratis ongkir": "Gratis ongkir untuk order minimal 5 roll.",
            "apa bisa dikirimkan sample": "hallo kak untuk sampel kita bisa kirimkan gratis ya kak 😊"
        }
        # Simple keyword matching
        user_lower = user_message.lower()
        for key, response in mock_responses.items():
            if any(word in user_lower for word in key.split()):
                return response
        return "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
# Initialize AI assistant
ai_assistant = TextilindoAI()

# Routes
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main chat interface"""
    try:
        with open("templates/chat.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(content="""
<!DOCTYPE html>
<html>
<head>
    <title>Textilindo AI Assistant</title>
    <meta charset="utf-8">
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 20px; margin: 20px 0; }
        .message { margin: 10px 0; padding: 10px; border-radius: 5px; }
        .user { background-color: #e3f2fd; text-align: right; }
        .assistant { background-color: #f5f5f5; }
        input[type="text"] { width: 70%; padding: 10px; border: 1px solid #ddd; border-radius: 5px; }
        button { padding: 10px 20px; background-color: #2196f3; color: white; border: none; border-radius: 5px; cursor: pointer; }
    </style>
</head>
<body>
    <h1>🤖 Textilindo AI Assistant</h1>
    <div class="chat-container">
        <div id="chat-messages"></div>
        <div style="margin-top: 20px;">
            <input type="text" id="message-input" placeholder="Tulis pesan Anda..." onkeypress="handleKeyPress(event)">
            <button onclick="sendMessage()">Kirim</button>
        </div>
    </div>
    <script>
        async function sendMessage() {
            const input = document.getElementById('message-input');
            const message = input.value.trim();
            if (!message) return;
            // Add user message
            addMessage(message, 'user');
            input.value = '';
            // Get AI response
            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ message: message })
                });
                const data = await response.json();
                addMessage(data.response, 'assistant');
            } catch (error) {
                addMessage('Maaf, terjadi kesalahan. Silakan coba lagi.', 'assistant');
            }
        }
        function addMessage(text, sender) {
            const messages = document.getElementById('chat-messages');
            const div = document.createElement('div');
            div.className = `message ${sender}`;
            div.textContent = text;
            messages.appendChild(div);
            messages.scrollTop = messages.scrollHeight;
        }
        function handleKeyPress(event) {
            if (event.key === 'Enter') {
                sendMessage();
            }
        }
    </script>
</body>
</html>
""")
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat endpoint"""
    try:
        response = ai_assistant.generate_response(request.message)
        return ChatResponse(
            response=response,
            conversation_id=request.conversation_id or "default",
            status="success"
        )
    except Exception as e:
        logger.error(f"Error in chat endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
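# Example request (assuming the app is running locally on port 7860):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "dimana lokasi textilindo?"}'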
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        message="Textilindo AI Assistant is running",
        version="1.0.0"
    )

@app.get("/info")  # path assumed; this route is not listed in the endpoint map below
async def get_info():
    """Get application information"""
    return {
        "name": "Textilindo AI Assistant",
        "version": "1.0.0",
        "model": ai_assistant.model,
        "has_api_key": bool(ai_assistant.api_key),
        "client_initialized": bool(ai_assistant.client),
        "endpoints": {
            "training": {
                "start": "POST /api/train/start",
                "status": "GET /api/train/status",
                "data": "GET /api/train/data",
                "gpu": "GET /api/train/gpu",
                "test": "POST /api/train/test"
            },
            "chat": {
                "chat": "POST /chat",
                "health": "GET /health"
            }
        }
    }
# Training API endpoints (simplified for HF Spaces)
@app.post("/api/train/start", response_model=TrainingResponse)
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Start training process (simplified for HF Spaces)"""
    global training_status
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training already in progress")

    # For HF Spaces, we'll simulate training
    training_id = f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Update status to show training started
    training_status.update({
        "is_training": True,
        "status": "started",
        "progress": 0,
        "start_time": datetime.now().isoformat(),
        "error": None
    })

    # Simulate training completion after a delay
    background_tasks.add_task(simulate_training_completion)

    return TrainingResponse(
        success=True,
        message="Training started successfully (simulated for HF Spaces)",
        training_id=training_id,
        status="started"
    )

async def simulate_training_completion():
    """Simulate training completion for HF Spaces"""
    await asyncio.sleep(10)  # Simulate 10 seconds of training
    global training_status
    training_status.update({
        "is_training": False,
        "status": "completed",
        "progress": 100,
        "end_time": datetime.now().isoformat()
    })
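# Clients can poll GET /api/train/status to watch the simulated run move from
# "started" to "completed" roughly ten seconds after POST /api/train/start.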
@app.get("/api/train/status")
async def get_training_status():
    """Get current training status"""
    return training_status

@app.get("/api/train/data")
async def get_training_data_info():
    """Get information about available training data"""
    data_dir = Path("data")
    if not data_dir.exists():
        return {"files": [], "count": 0}
    jsonl_files = list(data_dir.glob("*.jsonl"))
    files_info = []
    for file in jsonl_files:
        try:
            with open(file, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            files_info.append({
                "name": file.name,
                "size": file.stat().st_size,
                "lines": len(lines)
            })
        except Exception as e:
            files_info.append({
                "name": file.name,
                "error": str(e)
            })
    return {
        "files": files_info,
        "count": len(jsonl_files)
    }
@app.get("/api/train/gpu")
async def get_gpu_info():
    """Get GPU information (simulated for HF Spaces)"""
    return {
        "available": False,
        "message": "GPU not available in HF Spaces free tier",
        "recommendation": "Use local training or upgrade to paid tier"
    }

@app.post("/api/train/test")
async def test_trained_model():
    """Test the trained model (simulated)"""
    return {
        "success": True,
        "message": "Model testing simulated for HF Spaces",
        "test_prompt": "dimana lokasi textilindo?",
        "response": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
        "note": "This is a simulated response for HF Spaces demo"
    }
if __name__ == "__main__":
    # Get port from environment variable (Hugging Face Spaces uses 7860)
    port = int(os.getenv("PORT", 7860))

    # Run the application
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        log_level="info"
    )
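# Quick smoke test (assuming this module is saved as app.py, which the
# "app:app" import string above requires):
#   python app.py
#   curl http://localhost:7860/health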