Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Textilindo AI Assistant - Hugging Face Spaces FastAPI Application | |
| Main application file for deployment on Hugging Face Spaces | |
| """ | |
| import os | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from fastapi import FastAPI, HTTPException, Request, BackgroundTasks | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from huggingface_hub import InferenceClient | |
| import requests | |
| from datetime import datetime | |
# Logging: root handler at INFO level, plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Initialize FastAPI app
app = FastAPI(
    title="Textilindo AI Assistant",
    description="AI Assistant for Textilindo textile company",
    version="1.0.0",
)

# CORS: any origin may call the API, but without credentials. The app uses no
# cookies or auth. With allow_origins=["*"] plus allow_credentials=True,
# Starlette echoes back each caller's Origin together with
# "Access-Control-Allow-Credentials: true", which would let any site make
# credentialed requests. That is not needed here, so credentials stay off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/Response models
# Request body accepted by the ``chat`` endpoint.
class ChatRequest(BaseModel):
    message: str  # the user's chat message
    conversation_id: Optional[str] = None  # optional; ``chat`` echoes it back, or "default" when omitted
# Response body returned by the ``chat`` endpoint.
class ChatResponse(BaseModel):
    response: str  # the assistant's reply text
    conversation_id: str  # conversation id from the request (or "default")
    status: str = "success"
# Response body returned by the ``health_check`` endpoint.
class HealthResponse(BaseModel):
    status: str  # e.g. "healthy"
    message: str  # human-readable status text
    version: str = "1.0.0"  # app version reported to clients
class TextilindoAI:
    """Textilindo AI Assistant backed by the Hugging Face Inference API.

    Falls back to canned keyword-matched answers (``get_mock_response``) in
    three cases: no ``HUGGINGFACE_API_KEY`` is set, the client cannot be
    created, or a request fails or returns empty text.
    """

    # Default system-prompt file, relative to the current working directory.
    SYSTEM_PROMPT_PATH = "configs/system_prompt.md"
    # Opening marker of an embedded prompt inside the markdown file; the prompt
    # runs until the next triple double quote.
    _PROMPT_MARKER = 'SYSTEM_PROMPT = """'

    def __init__(self):
        self.api_key = os.getenv('HUGGINGFACE_API_KEY')
        self.model = os.getenv('DEFAULT_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
        self.system_prompt = self.load_system_prompt()
        if not self.api_key:
            logger.warning("HUGGINGFACE_API_KEY not found. Using mock responses.")
            self.client = None
            return
        try:
            self.client = InferenceClient(token=self.api_key, model=self.model)
            logger.info(f"Initialized with model: {self.model}")
        except Exception as e:
            logger.error(f"Failed to initialize InferenceClient: {e}")
            self.client = None

    def load_system_prompt(self, prompt_path: Optional[str] = None) -> str:
        """Load the system prompt from a markdown config file.

        If the file contains a ``SYSTEM_PROMPT = ...`` block delimited by
        triple double quotes, only the quoted text is used. Otherwise the
        whole file is used.

        Args:
            prompt_path: File to read; defaults to ``SYSTEM_PROMPT_PATH``.

        Returns:
            The stripped prompt text. Returns ``get_default_system_prompt()``
            if the file is missing, cannot be read, or yields an empty prompt.
        """
        path = Path(prompt_path if prompt_path is not None else self.SYSTEM_PROMPT_PATH)
        try:
            if not path.exists():
                return self.get_default_system_prompt()
            content = path.read_text(encoding='utf-8')
        except Exception as e:
            logger.error(f"Error loading system prompt: {e}")
            return self.get_default_system_prompt()

        marker_at = content.find(self._PROMPT_MARKER)
        if marker_at == -1:
            # No embedded block: the whole file is the prompt.
            prompt = content
        else:
            start = marker_at + len(self._PROMPT_MARKER)
            end = content.find('"""', start)
            # Without a closing quote, find() returns -1. Slicing [start:-1]
            # would silently drop the final character, so take the rest instead.
            prompt = content[start:end] if end != -1 else content[start:]
        prompt = prompt.strip()
        return prompt or self.get_default_system_prompt()

    def get_default_system_prompt(self) -> str:
        """Default system prompt if file not found"""
        return """You are a friendly and helpful AI assistant for Textilindo, a textile company.
Always respond in Indonesian (Bahasa Indonesia).
Keep responses short and direct.
Be friendly and helpful.
Use exact information from the knowledge base.
The company uses yards for sales.
Minimum purchase is 1 roll (67-70 yards)."""

    def generate_response(self, user_message: str) -> str:
        """Generate a reply to ``user_message`` via the Inference API.

        Sends system + user messages to the chat-completion endpoint, so the
        server applies the model's own chat template. Hand-written
        ``<|system|>``/``<|user|>`` markers are not the prompt format of the
        default Llama 3.1 model.

        Returns:
            The assistant's reply, stripped. Falls back to
            ``get_mock_response`` when there is no client, the call fails, or
            the model returns empty text.
        """
        if not self.client:
            return self.get_mock_response(user_message)
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_message},
        ]
        try:
            completion = self.client.chat_completion(
                messages=messages,
                max_tokens=512,
                temperature=0.7,
                top_p=0.9,
            )
            reply = completion.choices[0].message.content or ""
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return self.get_mock_response(user_message)
        reply = reply.strip()
        return reply or self.get_mock_response(user_message)

    def get_mock_response(self, user_message: str) -> str:
        """Return a canned answer for ``user_message`` (testing / no API key).

        Each canned answer is keyed by a sample question. The message is split
        into lowercase words, and the answer whose key shares the most whole
        words with it wins; earlier entries win ties.

        Whole-word matching avoids the substring pitfalls of a first-hit
        search. For example, "ada" no longer matches inside "adalah", and
        "jam berapa ... textilindo" gets the opening hours rather than the
        address.

        Returns:
            The best-matching canned answer, or a generic greeting if no key
            word occurs in the message.
        """
        import re

        mock_responses = {
            "dimana lokasi textilindo": "Textilindo berkantor pusat di Jl. Raya Prancis No.39, Kosambi Tim., Kec. Kosambi, Kabupaten Tangerang, Banten 15213",
            "jam berapa textilindo beroperasional": "Jam operasional Senin-Jumat 08:00-17:00, Sabtu 08:00-12:00.",
            "berapa ketentuan pembelian": "Minimal order 1 roll per jenis kain",
            "bagaimana dengan pembayarannya": "Pembayaran dapat dilakukan via transfer bank atau cash on delivery",
            "apa ada gratis ongkir": "Gratis ongkir untuk order minimal 5 roll.",
            "apa bisa dikirimkan sample": "hallo kak untuk sampel kita bisa kirimkan gratis ya kak 😊"
        }
        user_words = set(re.findall(r"\w+", user_message.lower()))
        best_reply, best_score = None, 0
        for key, reply in mock_responses.items():
            score = len(user_words.intersection(key.split()))
            if score > best_score:  # strict '>' keeps the earlier entry on ties
                best_reply, best_score = reply, score
        if best_reply is not None:
            return best_reply
        return "Halo! Saya adalah asisten AI Textilindo. Bagaimana saya bisa membantu Anda hari ini? 😊"
# Initialize AI assistant
# Single module-level instance used by the request handlers below. It is built
# at import time, so HUGGINGFACE_API_KEY / DEFAULT_MODEL and the system-prompt
# file are read once, when this module loads.
ai_assistant = TextilindoAI()
| # Routes | |
async def root():
    """Serve the main chat interface.

    Returns ``templates/chat.html`` when it exists. Otherwise returns a
    minimal built-in chat page that POSTs ``{"message": ...}`` to ``/chat``.

    NOTE(review): no route decorator is visible for this handler in this view;
    confirm it is registered on ``app``.
    """
    try:
        return HTMLResponse(content=Path("templates/chat.html").read_text(encoding="utf-8"))
    except FileNotFoundError:
        pass
    # Fallback page. The JS checks response.ok because error replies (HTTP
    # 422/500) carry "detail" rather than "response"; without the check the
    # chat would display the literal text "undefined".
    return HTMLResponse(content="""
<!DOCTYPE html>
<html>
<head>
    <title>Textilindo AI Assistant</title>
    <meta charset="utf-8">
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 20px; margin: 20px 0; }
        .message { margin: 10px 0; padding: 10px; border-radius: 5px; }
        .user { background-color: #e3f2fd; text-align: right; }
        .assistant { background-color: #f5f5f5; }
        input[type="text"] { width: 70%; padding: 10px; border: 1px solid #ddd; border-radius: 5px; }
        button { padding: 10px 20px; background-color: #2196f3; color: white; border: none; border-radius: 5px; cursor: pointer; }
    </style>
</head>
<body>
    <h1>🤖 Textilindo AI Assistant</h1>
    <div class="chat-container">
        <div id="chat-messages"></div>
        <div style="margin-top: 20px;">
            <input type="text" id="message-input" placeholder="Tulis pesan Anda..." onkeypress="handleKeyPress(event)">
            <button onclick="sendMessage()">Kirim</button>
        </div>
    </div>
    <script>
        async function sendMessage() {
            const input = document.getElementById('message-input');
            const message = input.value.trim();
            if (!message) return;
            // Add user message
            addMessage(message, 'user');
            input.value = '';
            // Get AI response
            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ message: message })
                });
                const data = await response.json();
                if (!response.ok || typeof data.response !== 'string') {
                    throw new Error(`HTTP ${response.status}`);
                }
                addMessage(data.response, 'assistant');
            } catch (error) {
                addMessage('Maaf, terjadi kesalahan. Silakan coba lagi.', 'assistant');
            }
        }
        function addMessage(text, sender) {
            const messages = document.getElementById('chat-messages');
            const div = document.createElement('div');
            div.className = `message ${sender}`;
            div.textContent = text;
            messages.appendChild(div);
            messages.scrollTop = messages.scrollHeight;
        }
        function handleKeyPress(event) {
            if (event.key === 'Enter') {
                sendMessage();
            }
        }
    </script>
</body>
</html>
""")
async def chat(request: ChatRequest):
    """Answer a single chat message.

    ``generate_response`` makes a blocking HTTP call to the Inference API, so
    it runs in FastAPI's worker thread pool. Called directly on the event loop,
    it would stall every other request until the model answered.

    NOTE(review): no route decorator is visible for this handler in this view;
    the built-in chat page POSTs to ``/chat``.

    Returns:
        ``ChatResponse`` with the reply and the request's conversation id, or
        "default" when the request has none.

    Raises:
        HTTPException: 500 if producing the reply fails unexpectedly.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        response = await run_in_threadpool(ai_assistant.generate_response, request.message)
        return ChatResponse(
            response=response,
            conversation_id=request.conversation_id or "default",
            status="success",
        )
    except Exception as e:
        # logger.exception keeps the traceback; "from e" preserves the cause.
        logger.exception(f"Error in chat endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
async def health_check():
    """Liveness probe: always reports "healthy" along with the app version."""
    payload = {
        "status": "healthy",
        "message": "Textilindo AI Assistant is running",
        "version": "1.0.0",
    }
    return HealthResponse(**payload)
async def get_info():
    """Describe the deployment.

    Reports the app name and version, the configured model id, and whether an
    API key and an inference client are present.
    """
    assistant = ai_assistant
    info = {"name": "Textilindo AI Assistant", "version": "1.0.0"}
    info["model"] = assistant.model
    info["has_api_key"] = bool(assistant.api_key)
    info["client_initialized"] = bool(assistant.client)
    return info
# Training API: request/response models, shared status dict, background trainer
# and helpers, all provided by the project-local ``training_api`` module.
# NOTE(review): the legacy ``async def training_status()`` handler later in this
# file rebinds the module-level name ``training_status``, shadowing the dict
# imported here. Code that needs the dict should import it locally.
from training_api import (
    TrainingRequest,
    TrainingResponse,
    training_status,
    train_model_async,
    load_training_config,
    load_training_data,
    check_gpu_availability,
)
# Training API endpoints
async def start_training_api(request: TrainingRequest, background_tasks: BackgroundTasks):
    """Validate a training request and queue ``train_model_async`` as a background task.

    Returns:
        ``TrainingResponse`` with a timestamp-based ``training_id`` and
        status "started".

    Raises:
        HTTPException: 400 if a run is already in progress; 404 if
            ``request.dataset_path`` or ``request.config_path`` does not exist.
    """
    # Read the shared status dict straight from training_api. In this module the
    # global name ``training_status`` is rebound by the legacy
    # ``training_status()`` handler, so ``training_status["is_training"]`` would
    # subscript a function and raise TypeError.
    from training_api import training_status as training_state

    if training_state["is_training"]:
        raise HTTPException(status_code=400, detail="Training already in progress")
    # Validate inputs
    if not Path(request.dataset_path).exists():
        raise HTTPException(status_code=404, detail=f"Dataset not found: {request.dataset_path}")
    if not Path(request.config_path).exists():
        raise HTTPException(status_code=404, detail=f"Config not found: {request.config_path}")
    # NOTE(review): training_id is only returned to the caller; it is not passed
    # to train_model_async.
    training_id = f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    background_tasks.add_task(
        train_model_async,
        request.model_name,
        request.dataset_path,
        request.config_path,
        request.max_samples,
        request.epochs,
        request.batch_size,
        request.learning_rate,
    )
    return TrainingResponse(
        success=True,
        message="Training started successfully",
        training_id=training_id,
        status="started",
    )
async def get_training_status_api():
    """Return the shared training-status dict maintained by ``training_api``.

    The dict is imported locally because this module's global name
    ``training_status`` is rebound by the legacy ``training_status()``
    handler. Returning the global would hand FastAPI a function object
    instead of the status.
    """
    from training_api import training_status as training_state

    return training_state
async def get_training_data_info_api():
    """Summarize the ``*.jsonl`` training files in ``./data``.

    Returns:
        ``{"files": [...], "count": n}``, listed in name order. Each entry
        is ``{"name", "size" (bytes), "lines"}``, or ``{"name", "error"}``
        when the file could not be read. ``{"files": [], "count": 0}`` if
        ``./data`` is not a directory.
    """
    data_dir = Path("data")
    if not data_dir.is_dir():
        return {"files": [], "count": 0}
    jsonl_files = sorted(data_dir.glob("*.jsonl"))
    files_info = []
    for file in jsonl_files:
        try:
            # Stream the file to count lines instead of loading it all into memory.
            with open(file, "r", encoding="utf-8") as f:
                line_count = sum(1 for _ in f)
            files_info.append({
                "name": file.name,
                "size": file.stat().st_size,
                "lines": line_count,
            })
        except Exception as e:
            # Best effort per file: report the problem and keep listing the rest.
            files_info.append({"name": file.name, "error": str(e)})
    return {"files": files_info, "count": len(jsonl_files)}
async def get_gpu_info_api():
    """Report CUDA availability and, when present, details of device 0.

    Any failure, including torch not being installed, is returned as
    ``{"error": message}`` rather than raised.
    """
    try:
        import torch

        cuda = torch.cuda
        if not cuda.is_available():
            return {"available": False}
        device_count = cuda.device_count()
        device_name = cuda.get_device_name(0)
        total_gib = cuda.get_device_properties(0).total_memory / (1024**3)
        return {
            "available": True,
            "count": device_count,
            "name": device_name,
            "memory_gb": round(total_gib, 2),
        }
    except Exception as e:
        return {"error": str(e)}
async def test_trained_model_api():
    """Smoke-test the fine-tuned model saved under ``./models/textilindo-trained``.

    Loads the tokenizer and model, samples up to 30 new tokens for a fixed
    prompt, and returns the decoded text (prompt included). Loading and
    generation are heavy blocking calls, so they run in FastAPI's thread pool
    instead of on the event loop.

    Returns:
        ``{"success": True, "test_prompt", "response", "model_path"}`` on
        success, otherwise ``{"error": message}``.
    """
    model_path = "./models/textilindo-trained"
    if not Path(model_path).exists():
        return {"error": "No trained model found"}
    test_prompt = "Question: dimana lokasi textilindo? Answer:"

    def _generate() -> str:
        # Runs in a worker thread: load the model and sample a short continuation.
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        inputs = tokenizer(test_prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=30,  # same as max_length = prompt length + 30
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    try:
        from fastapi.concurrency import run_in_threadpool

        response = await run_in_threadpool(_generate)
    except Exception as e:
        return {"error": str(e)}
    return {
        "success": True,
        "test_prompt": test_prompt,
        "response": response,
        "model_path": model_path,
    }
# Legacy training endpoints (for backward compatibility)
async def training_interface():
    """Serve the training UI.

    Returns ``templates/training.html`` when it exists; otherwise a built-in
    page. The built-in page appends log lines to a ``.log`` div. That div uses
    ``white-space: pre-wrap`` so each newline-separated entry stays on its own
    line; previously the newlines collapsed and entries ran together. It also
    uses ``textContent`` so server-supplied text is shown literally, not
    parsed as HTML.
    """
    try:
        with open("templates/training.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        return HTMLResponse(content="""
<!DOCTYPE html>
<html>
<head>
    <title>Textilindo AI Training</title>
    <meta charset="utf-8">
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .container { background: #f5f5f5; padding: 20px; border-radius: 10px; margin: 20px 0; }
        button { background: #2196f3; color: white; border: none; padding: 10px 20px; border-radius: 5px; cursor: pointer; }
        button:hover { background: #1976d2; }
        .log { background: #000; color: #0f0; padding: 10px; border-radius: 5px; font-family: monospace; height: 300px; overflow-y: auto; white-space: pre-wrap; }
    </style>
</head>
<body>
    <h1>🤖 Textilindo AI Training Interface</h1>
    <div class="container">
        <h2>Training Options</h2>
        <p>Choose your training method:</p>
        <button onclick="startLightweightTraining()">Start Lightweight Training</button>
        <button onclick="checkResources()">Check Resources</button>
        <button onclick="viewData()">View Training Data</button>
    </div>
    <div class="container">
        <h2>Training Log</h2>
        <div id="log" class="log">Ready to start training...</div>
    </div>
    <script>
        function addLog(message) {
            const log = document.getElementById('log');
            const timestamp = new Date().toLocaleTimeString();
            log.textContent += `\\n[${timestamp}] ${message}`;
            log.scrollTop = log.scrollHeight;
        }
        async function startLightweightTraining() {
            addLog('Starting lightweight training...');
            try {
                const response = await fetch('/train/start', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' }
                });
                const result = await response.json();
                addLog(`Training result: ${result.message}`);
            } catch (error) {
                addLog(`Error: ${error.message}`);
            }
        }
        async function checkResources() {
            addLog('Checking resources...');
            try {
                const response = await fetch('/train/status');
                const result = await response.json();
                addLog(`Resources: ${JSON.stringify(result, null, 2)}`);
            } catch (error) {
                addLog(`Error: ${error.message}`);
            }
        }
        async function viewData() {
            addLog('Loading training data...');
            try {
                const response = await fetch('/train/data');
                const result = await response.json();
                if (!Array.isArray(result.files)) {
                    throw new Error(result.error || 'Unexpected response');
                }
                addLog(`Data files: ${result.files.join(', ')}`);
            } catch (error) {
                addLog(`Error: ${error.message}`);
            }
        }
    </script>
</body>
</html>
""")
async def start_training():
    """Run ``train_on_space.py`` in a subprocess and report the outcome.

    The script gets at most 5 minutes. ``subprocess.run`` blocks, so it runs
    in the event loop's default executor. Calling it directly in this
    coroutine froze the whole server, including chat, until training ended.

    Returns:
        A dict with ``message``. On exit code 0 it also has ``output``
        (stdout); on a non-zero exit it has ``error`` (stderr).
    """
    import asyncio
    import functools
    import subprocess
    import sys

    run_script = functools.partial(
        subprocess.run,
        [sys.executable, "train_on_space.py"],
        capture_output=True,
        text=True,
        timeout=300,  # 5 minute timeout; subprocess.run kills the child on expiry
    )
    try:
        result = await asyncio.get_running_loop().run_in_executor(None, run_script)
    except subprocess.TimeoutExpired:
        return {"message": "Training timed out (5 minutes limit)"}
    except Exception as e:
        return {"message": f"Training error: {str(e)}"}
    if result.returncode == 0:
        return {"message": "Training completed successfully!", "output": result.stdout}
    return {"message": "Training failed", "error": result.stderr}
async def training_status():
    """Report host resources for the legacy training UI.

    NOTE(review): this definition rebinds the module-level name
    ``training_status``, which is also imported from ``training_api`` as the
    shared status dict. Code that needs that dict should import it locally.

    Returns:
        ``{"status": "ready", "cpu_count", "memory_total_gb",
        "memory_available_gb", "disk_free_gb"}``, with sizes in units of
        1024**3 bytes rounded to 2 decimals. On failure, including psutil
        not being installed, returns ``{"status": "error", "message": ...}``.
    """
    try:
        import psutil

        gib = 1024**3
        # One snapshot, so total and available memory come from the same reading.
        memory = psutil.virtual_memory()
        return {
            "status": "ready",
            "cpu_count": psutil.cpu_count(),
            "memory_total_gb": round(memory.total / gib, 2),
            "memory_available_gb": round(memory.available / gib, 2),
            "disk_free_gb": round(psutil.disk_usage('.').free / gib, 2),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def training_data():
    """List the ``*.jsonl`` file names under ``./data`` for the legacy training UI.

    Returns ``{"files": [names], "count": n}``, or ``{"error": message}`` if
    listing fails.
    """
    try:
        folder = Path("data")
        if not folder.exists():
            return {"files": [], "count": 0}
        names = [entry.name for entry in folder.glob("*.jsonl")]
        return {"files": names, "count": len(names)}
    except Exception as e:
        return {"error": str(e)}
# Mount static files if the directory exists. Use is_dir(), not exists():
# StaticFiles raises at construction when the path is not a directory, so a
# stray file named "static" would otherwise abort startup.
if Path("static").is_dir():
    app.mount("/static", StaticFiles(directory="static"), name="static")
if __name__ == "__main__":
    # Get port from environment variable (Hugging Face Spaces uses 7860)
    port = int(os.getenv("PORT", 7860))
    # Pass the app object, not the "app:app" import string. With the string,
    # uvicorn imports this file a second time as module "app", re-running all
    # module-level setup (including TextilindoAI()). It also fails outright if
    # the file is not named app.py.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )