#!/usr/bin/env python3
"""
Script for switching the LLM provider flexibly.
Usage: python switch_llm_provider.py [provider] [options]
"""
import sys
import argparse
from pathlib import Path


# Colors for terminal output
class Colors:
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    RED = '\033[91m'
    RESET = '\033[0m'
    BOLD = '\033[1m'


def print_colored(text, color=Colors.RESET):
    """Print colored text."""
    print(f"{color}{text}{Colors.RESET}")

def get_env_file():
    """Get the .env file path."""
    # Try multiple locations
    possible_paths = [
        Path(__file__).parent / ".env",
        Path(__file__).parent.parent / ".env",
        Path.home() / ".env",
    ]
    for path in possible_paths:
        if path.exists():
            return path
    # Return default location
    return Path(__file__).parent / ".env"

def read_env_file():
    """Read the .env file and return its variables as a dict."""
    env_file = get_env_file()
    env_vars = {}
    if env_file.exists():
        with open(env_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#') and '=' in line:
                    key, value = line.split('=', 1)
                    env_vars[key.strip()] = value.strip()
    return env_vars, env_file
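
# Note on the parser above: values are kept verbatim -- surrounding quotes,
# if any, are not stripped, and `export`-style prefixes are not handled.
# A line such as `LLM_PROVIDER=local` yields {'LLM_PROVIDER': 'local'}.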

def write_env_file(env_vars, env_file):
    """Write the .env file from a dict."""
    # Read the existing file to preserve comments and ordering
    lines = []
    if env_file.exists():
        with open(env_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

    # LLM-related vars this script owns; stale copies are dropped here
    # and fresh values are appended below
    llm_related_vars = {
        'LLM_PROVIDER', 'LOCAL_MODEL_PATH', 'LOCAL_MODEL_DEVICE',
        'LOCAL_MODEL_4BIT', 'LOCAL_MODEL_8BIT', 'HF_API_BASE_URL',
        'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'OLLAMA_BASE_URL', 'OLLAMA_MODEL'
    }

    # Process existing lines
    new_lines = []
    llm_provider_set = False
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            new_lines.append(line)
            continue
        if '=' in stripped:
            key = stripped.split('=', 1)[0].strip()
            if key in llm_related_vars:
                # Skip old LLM-related vars; we'll add new ones below
                if key == 'LLM_PROVIDER':
                    llm_provider_set = True
                continue
        new_lines.append(line)

    # Add LLM provider config (header comment only if it wasn't there before)
    if not llm_provider_set:
        new_lines.append("\n# LLM Provider Configuration\n")
    provider = env_vars.get('LLM_PROVIDER', 'none')
    new_lines.append(f"LLM_PROVIDER={provider}\n")

    # Add provider-specific configs
    if provider == 'local':
        new_lines.append(f"LOCAL_MODEL_PATH={env_vars.get('LOCAL_MODEL_PATH', 'Qwen/Qwen2.5-7B-Instruct')}\n")
        new_lines.append(f"LOCAL_MODEL_DEVICE={env_vars.get('LOCAL_MODEL_DEVICE', 'auto')}\n")
        new_lines.append(f"LOCAL_MODEL_8BIT={env_vars.get('LOCAL_MODEL_8BIT', 'true')}\n")
        new_lines.append(f"LOCAL_MODEL_4BIT={env_vars.get('LOCAL_MODEL_4BIT', 'false')}\n")
    elif provider == 'api':
        new_lines.append(f"HF_API_BASE_URL={env_vars.get('HF_API_BASE_URL', 'https://davidtran999-hue-portal-backend.hf.space/api')}\n")
    elif provider == 'openai':
        if 'OPENAI_API_KEY' in env_vars:
            new_lines.append(f"OPENAI_API_KEY={env_vars['OPENAI_API_KEY']}\n")
    elif provider == 'anthropic':
        if 'ANTHROPIC_API_KEY' in env_vars:
            new_lines.append(f"ANTHROPIC_API_KEY={env_vars['ANTHROPIC_API_KEY']}\n")
    elif provider == 'ollama':
        new_lines.append(f"OLLAMA_BASE_URL={env_vars.get('OLLAMA_BASE_URL', 'http://localhost:11434')}\n")
        new_lines.append(f"OLLAMA_MODEL={env_vars.get('OLLAMA_MODEL', 'qwen2.5:7b')}\n")

    # Write to file
    with open(env_file, 'w', encoding='utf-8') as f:
        f.writelines(new_lines)
    return env_file
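
# For reference, switching to the local provider appends a block like the
# following to .env (a sketch with the default values; actual values depend
# on the flags passed):
#
#   # LLM Provider Configuration
#   LLM_PROVIDER=local
#   LOCAL_MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
#   LOCAL_MODEL_DEVICE=auto
#   LOCAL_MODEL_8BIT=true
#   LOCAL_MODEL_4BIT=false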

def set_provider(provider, **kwargs):
    """Set the LLM provider and related config."""
    env_vars, env_file = read_env_file()

    # Update provider
    env_vars['LLM_PROVIDER'] = provider

    # Update provider-specific configs
    if provider == 'local':
        env_vars['LOCAL_MODEL_PATH'] = kwargs.get('model_path', 'Qwen/Qwen2.5-7B-Instruct')
        env_vars['LOCAL_MODEL_DEVICE'] = kwargs.get('device', 'auto')
        env_vars['LOCAL_MODEL_8BIT'] = kwargs.get('use_8bit', 'true')
        env_vars['LOCAL_MODEL_4BIT'] = kwargs.get('use_4bit', 'false')
    elif provider == 'api':
        env_vars['HF_API_BASE_URL'] = kwargs.get('api_url', 'https://davidtran999-hue-portal-backend.hf.space/api')

    # Write to file
    write_env_file(env_vars, env_file)
    print_colored(f"✅ Switched LLM provider to: {provider.upper()}", Colors.GREEN)
    print_colored(f"📝 File: {env_file}", Colors.BLUE)
    if provider == 'local':
        print_colored(f"   Model: {env_vars['LOCAL_MODEL_PATH']}", Colors.BLUE)
        print_colored(f"   Device: {env_vars['LOCAL_MODEL_DEVICE']}", Colors.BLUE)
        print_colored(f"   8-bit: {env_vars['LOCAL_MODEL_8BIT']}", Colors.BLUE)
        print_colored(f"   4-bit: {env_vars['LOCAL_MODEL_4BIT']}", Colors.BLUE)
    elif provider == 'api':
        print_colored(f"   API URL: {env_vars['HF_API_BASE_URL']}", Colors.BLUE)
    return env_file
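
# set_provider can also be called directly from Python, e.g.:
#   set_provider('local', model_path='Qwen/Qwen2.5-14B-Instruct',
#                device='cuda', use_8bit='true', use_4bit='false')
# Note that the quantization flags are the strings 'true'/'false', not
# booleans, since they are written verbatim into .env.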

def show_current():
    """Show the current LLM provider configuration."""
    env_vars, env_file = read_env_file()
    provider = env_vars.get('LLM_PROVIDER', 'none')

    print_colored("\n" + "="*60, Colors.BOLD)
    print_colored("Current LLM Provider Configuration", Colors.BOLD)
    print_colored("="*60, Colors.RESET)
    print_colored(f"Provider: {provider.upper()}", Colors.GREEN)
    print_colored(f"Config file: {env_file}", Colors.BLUE)

    if provider == 'local':
        print_colored("\nLocal Model Settings:", Colors.YELLOW)
        print(f"  MODEL_PATH: {env_vars.get('LOCAL_MODEL_PATH', 'Qwen/Qwen2.5-7B-Instruct')}")
        print(f"  DEVICE: {env_vars.get('LOCAL_MODEL_DEVICE', 'auto')}")
        print(f"  8BIT: {env_vars.get('LOCAL_MODEL_8BIT', 'true')}")
        print(f"  4BIT: {env_vars.get('LOCAL_MODEL_4BIT', 'false')}")
    elif provider == 'api':
        print_colored("\nAPI Mode Settings:", Colors.YELLOW)
        print(f"  API_URL: {env_vars.get('HF_API_BASE_URL', 'https://davidtran999-hue-portal-backend.hf.space/api')}")
    elif provider == 'openai':
        has_key = 'OPENAI_API_KEY' in env_vars and env_vars['OPENAI_API_KEY']
        print_colored("\nOpenAI Settings:", Colors.YELLOW)
        print(f"  API_KEY: {'✅ Set' if has_key else '❌ Not set'}")
    elif provider == 'anthropic':
        has_key = 'ANTHROPIC_API_KEY' in env_vars and env_vars['ANTHROPIC_API_KEY']
        print_colored("\nAnthropic Settings:", Colors.YELLOW)
        print(f"  API_KEY: {'✅ Set' if has_key else '❌ Not set'}")
    elif provider == 'ollama':
        print_colored("\nOllama Settings:", Colors.YELLOW)
        print(f"  BASE_URL: {env_vars.get('OLLAMA_BASE_URL', 'http://localhost:11434')}")
        print(f"  MODEL: {env_vars.get('OLLAMA_MODEL', 'qwen2.5:7b')}")
    elif provider == 'none':
        print_colored("\n⚠️ No LLM provider configured. Using template-based generation.", Colors.YELLOW)
    print_colored("="*60 + "\n", Colors.RESET)
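
# Example output of `python switch_llm_provider.py show` after switching to
# Ollama (a sketch; exact formatting follows the print calls above):
#
#   ============================================================
#   Current LLM Provider Configuration
#   ============================================================
#   Provider: OLLAMA
#   Config file: /path/to/.env
#
#   Ollama Settings:
#     BASE_URL: http://localhost:11434
#     MODEL: qwen2.5:7b
#   ============================================================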

def main():
    """Main function."""
    parser = argparse.ArgumentParser(
        description='Switch the LLM provider flexibly',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Switch to local model
  python switch_llm_provider.py local

  # Switch to local with custom model
  python switch_llm_provider.py local --model Qwen/Qwen2.5-14B-Instruct --device cuda --8bit

  # Switch to API mode
  python switch_llm_provider.py api

  # Switch to API with custom URL
  python switch_llm_provider.py api --url https://custom-api.hf.space/api

  # Switch to OpenAI
  python switch_llm_provider.py openai

  # Switch to Anthropic
  python switch_llm_provider.py anthropic

  # Switch to Ollama
  python switch_llm_provider.py ollama

  # Disable LLM (use templates only)
  python switch_llm_provider.py none

  # Show current configuration
  python switch_llm_provider.py show
"""
    )
    parser.add_argument(
        'provider',
        choices=['local', 'api', 'openai', 'anthropic', 'ollama', 'none', 'show'],
        help='LLM provider to use'
    )
    # Local model options
    parser.add_argument('--model', '--model-path', dest='model_path',
                        help='Model path for local provider (e.g., Qwen/Qwen2.5-7B-Instruct)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda'],
                        help='Device for local model (auto, cpu, cuda)')
    parser.add_argument('--8bit', action='store_true',
                        help='Use 8-bit quantization for local model')
    parser.add_argument('--4bit', action='store_true',
                        help='Use 4-bit quantization for local model')
    # API mode options
    parser.add_argument('--url', '--api-url', dest='api_url',
                        help='API URL for API mode')

    args = parser.parse_args()

    if args.provider == 'show':
        show_current()
        return 0

    # Prepare kwargs
    kwargs = {}
    if args.provider == 'local':
        if args.model_path:
            kwargs['model_path'] = args.model_path
        if args.device:
            kwargs['device'] = args.device
        # '8bit' and '4bit' are not valid Python identifiers, so argparse
        # stores them where only getattr can reach them
        if getattr(args, '8bit', False):
            kwargs['use_8bit'] = 'true'
            kwargs['use_4bit'] = 'false'
        elif getattr(args, '4bit', False):
            kwargs['use_4bit'] = 'true'
            kwargs['use_8bit'] = 'false'
    elif args.provider == 'api':
        if args.api_url:
            kwargs['api_url'] = args.api_url

    # Set provider
    try:
        set_provider(args.provider, **kwargs)
        print_colored("\n💡 Tip: Restart your Django server to apply the changes!", Colors.YELLOW)
        return 0
    except Exception as e:
        print_colored(f"❌ Error: {e}", Colors.RED)
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())