davidtran999 commited on
Commit
e584168
·
verified ·
1 Parent(s): f1d44e1

Upload backend/switch_llm_provider.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backend/switch_llm_provider.py +288 -0
backend/switch_llm_provider.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script để thay đổi LLM provider linh hoạt.
4
+ Sử dụng: python switch_llm_provider.py [provider] [options]
5
+ """
6
+ import os
7
+ import sys
8
+ import argparse
9
+ from pathlib import Path
10
+
11
+ # Colors for terminal output
12
class Colors:
    """ANSI escape sequences for coloured terminal output."""

    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    RED = '\033[91m'
    RESET = '\033[0m'   # restore default colour/weight
    BOLD = '\033[1m'
19
+
20
def print_colored(text, color=Colors.RESET):
    """Print *text* in the given ANSI colour, then reset the terminal colour."""
    print(color + str(text) + Colors.RESET)
23
+
24
def get_env_file():
    """Locate the .env file to use.

    Checks the script's directory, its parent, and the user's home directory,
    in that order, returning the first path that exists.  When none exists,
    the default location (next to this script) is returned so callers can
    create it.
    """
    here = Path(__file__).parent
    candidates = (
        here / ".env",
        here.parent / ".env",
        Path.home() / ".env",
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    # Nothing found anywhere: fall back to a (not yet created) file
    # alongside this script.
    return candidates[0]
39
+
40
def read_env_file():
    """Parse the .env file into a dict.

    Returns a ``(env_vars, env_file)`` tuple: ``env_vars`` maps KEY -> value
    for every non-comment ``KEY=VALUE`` line, and ``env_file`` is the path
    that was read (or would be written).
    """
    env_file = get_env_file()
    env_vars = {}

    if env_file.exists():
        with open(env_file, 'r', encoding='utf-8') as handle:
            for raw in handle:
                entry = raw.strip()
                # Only KEY=VALUE lines count; blanks and comments are skipped.
                if not entry or entry.startswith('#') or '=' not in entry:
                    continue
                key, _, value = entry.partition('=')
                env_vars[key.strip()] = value.strip()

    return env_vars, env_file
54
+
55
def write_env_file(env_vars, env_file):
    """Rewrite *env_file* with the LLM configuration from *env_vars*.

    Every non-LLM line (including comments and blank lines) is preserved in
    its original order; previously present LLM-related variables are dropped
    and replaced by a fresh block derived from *env_vars*.

    Args:
        env_vars: Mapping of environment variable names to string values.
            Only ``LLM_PROVIDER`` and the provider-specific keys are used.
        env_file: Path of the .env file to write.

    Returns:
        The ``env_file`` path that was written.
    """
    # Variables owned by this script: any existing occurrence is removed and
    # re-emitted below from env_vars.
    llm_related_vars = {
        'LLM_PROVIDER', 'LOCAL_MODEL_PATH', 'LOCAL_MODEL_DEVICE',
        'LOCAL_MODEL_4BIT', 'LOCAL_MODEL_8BIT', 'HF_API_BASE_URL',
        'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'OLLAMA_BASE_URL', 'OLLAMA_MODEL'
    }

    # Read the existing file so comments and ordering are preserved.
    lines = []
    if env_file.exists():
        with open(env_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

    new_lines = []
    llm_provider_set = False

    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            new_lines.append(line)
            continue

        if '=' in stripped:
            key = stripped.split('=', 1)[0].strip()
            if key in llm_related_vars:
                # Skip the stale value; the fresh block below replaces it.
                if key == 'LLM_PROVIDER':
                    llm_provider_set = True
                continue

        new_lines.append(line)

    # Bug fix: if the last preserved line has no trailing newline, the first
    # appended line below would merge onto it and corrupt the file.
    if new_lines and not new_lines[-1].endswith('\n'):
        new_lines[-1] += '\n'

    # Add a section header only the first time the setting is introduced.
    if not llm_provider_set:
        new_lines.append("\n# LLM Provider Configuration\n")

    provider = env_vars.get('LLM_PROVIDER', 'none')
    new_lines.append(f"LLM_PROVIDER={provider}\n")

    # Provider-specific settings, using the same defaults as set_provider.
    if provider == 'local':
        new_lines.append(f"LOCAL_MODEL_PATH={env_vars.get('LOCAL_MODEL_PATH', 'Qwen/Qwen2.5-7B-Instruct')}\n")
        new_lines.append(f"LOCAL_MODEL_DEVICE={env_vars.get('LOCAL_MODEL_DEVICE', 'auto')}\n")
        new_lines.append(f"LOCAL_MODEL_8BIT={env_vars.get('LOCAL_MODEL_8BIT', 'true')}\n")
        new_lines.append(f"LOCAL_MODEL_4BIT={env_vars.get('LOCAL_MODEL_4BIT', 'false')}\n")
    elif provider == 'api':
        new_lines.append(f"HF_API_BASE_URL={env_vars.get('HF_API_BASE_URL', 'https://davidtran999-hue-portal-backend.hf.space/api')}\n")
    elif provider == 'openai':
        if 'OPENAI_API_KEY' in env_vars:
            new_lines.append(f"OPENAI_API_KEY={env_vars['OPENAI_API_KEY']}\n")
    elif provider == 'anthropic':
        if 'ANTHROPIC_API_KEY' in env_vars:
            new_lines.append(f"ANTHROPIC_API_KEY={env_vars['ANTHROPIC_API_KEY']}\n")
    elif provider == 'ollama':
        new_lines.append(f"OLLAMA_BASE_URL={env_vars.get('OLLAMA_BASE_URL', 'http://localhost:11434')}\n")
        new_lines.append(f"OLLAMA_MODEL={env_vars.get('OLLAMA_MODEL', 'qwen2.5:7b')}\n")

    with open(env_file, 'w', encoding='utf-8') as f:
        f.writelines(new_lines)

    return env_file
124
+
125
def set_provider(provider, **kwargs):
    """Persist *provider* (plus any provider-specific options) to .env.

    Keyword options for ``local``: ``model_path``, ``device``, ``use_8bit``,
    ``use_4bit``.  For ``api``: ``api_url``.  Returns the .env path written.
    """
    env_vars, env_file = read_env_file()

    env_vars['LLM_PROVIDER'] = provider

    # Merge provider-specific options, falling back to sensible defaults.
    if provider == 'local':
        env_vars.update({
            'LOCAL_MODEL_PATH': kwargs.get('model_path', 'Qwen/Qwen2.5-7B-Instruct'),
            'LOCAL_MODEL_DEVICE': kwargs.get('device', 'auto'),
            'LOCAL_MODEL_8BIT': kwargs.get('use_8bit', 'true'),
            'LOCAL_MODEL_4BIT': kwargs.get('use_4bit', 'false'),
        })
    elif provider == 'api':
        env_vars['HF_API_BASE_URL'] = kwargs.get('api_url', 'https://davidtran999-hue-portal-backend.hf.space/api')

    write_env_file(env_vars, env_file)

    # Report what was written so the user can eyeball the result.
    print_colored(f"✅ Đã chuyển sang LLM Provider: {provider.upper()}", Colors.GREEN)
    print_colored(f"📝 File: {env_file}", Colors.BLUE)

    if provider == 'local':
        print_colored(f" Model: {env_vars['LOCAL_MODEL_PATH']}", Colors.BLUE)
        print_colored(f" Device: {env_vars['LOCAL_MODEL_DEVICE']}", Colors.BLUE)
        print_colored(f" 8-bit: {env_vars['LOCAL_MODEL_8BIT']}", Colors.BLUE)
        print_colored(f" 4-bit: {env_vars['LOCAL_MODEL_4BIT']}", Colors.BLUE)
    elif provider == 'api':
        print_colored(f" API URL: {env_vars['HF_API_BASE_URL']}", Colors.BLUE)

    return env_file
156
+
157
def show_current():
    """Print a human-readable summary of the current LLM provider config."""
    env_vars, env_file = read_env_file()
    provider = env_vars.get('LLM_PROVIDER', 'none')

    print_colored("\n" + "=" * 60, Colors.BOLD)
    print_colored("Current LLM Provider Configuration", Colors.BOLD)
    print_colored("=" * 60, Colors.RESET)
    print_colored(f"Provider: {provider.upper()}", Colors.GREEN)
    print_colored(f"Config file: {env_file}", Colors.BLUE)

    if provider == 'local':
        print_colored("\nLocal Model Settings:", Colors.YELLOW)
        print(f" MODEL_PATH: {env_vars.get('LOCAL_MODEL_PATH', 'Qwen/Qwen2.5-7B-Instruct')}")
        print(f" DEVICE: {env_vars.get('LOCAL_MODEL_DEVICE', 'auto')}")
        print(f" 8BIT: {env_vars.get('LOCAL_MODEL_8BIT', 'true')}")
        print(f" 4BIT: {env_vars.get('LOCAL_MODEL_4BIT', 'false')}")
    elif provider == 'api':
        print_colored("\nAPI Mode Settings:", Colors.YELLOW)
        print(f" API_URL: {env_vars.get('HF_API_BASE_URL', 'https://davidtran999-hue-portal-backend.hf.space/api')}")
    elif provider == 'openai':
        # Truthiness of get() covers both "missing" and "present but empty".
        api_key = env_vars.get('OPENAI_API_KEY')
        print_colored("\nOpenAI Settings:", Colors.YELLOW)
        print(f" API_KEY: {'✅ Set' if api_key else '❌ Not set'}")
    elif provider == 'anthropic':
        api_key = env_vars.get('ANTHROPIC_API_KEY')
        print_colored("\nAnthropic Settings:", Colors.YELLOW)
        print(f" API_KEY: {'✅ Set' if api_key else '❌ Not set'}")
    elif provider == 'ollama':
        print_colored("\nOllama Settings:", Colors.YELLOW)
        print(f" BASE_URL: {env_vars.get('OLLAMA_BASE_URL', 'http://localhost:11434')}")
        print(f" MODEL: {env_vars.get('OLLAMA_MODEL', 'qwen2.5:7b')}")
    elif provider == 'none':
        print_colored("\n⚠️ No LLM provider configured. Using template-based generation.", Colors.YELLOW)

    print_colored("=" * 60 + "\n", Colors.RESET)
194
+
195
def main():
    """CLI entry point: parse arguments and switch the LLM provider.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description='Switch LLM provider linh hoạt',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Switch to local model
  python switch_llm_provider.py local

  # Switch to local with custom model
  python switch_llm_provider.py local --model Qwen/Qwen2.5-14B-Instruct --device cuda --8bit

  # Switch to API mode
  python switch_llm_provider.py api

  # Switch to API with custom URL
  python switch_llm_provider.py api --url https://custom-api.hf.space/api

  # Switch to OpenAI
  python switch_llm_provider.py openai

  # Switch to Anthropic
  python switch_llm_provider.py anthropic

  # Switch to Ollama
  python switch_llm_provider.py ollama

  # Disable LLM (use templates only)
  python switch_llm_provider.py none

  # Show current configuration
  python switch_llm_provider.py show
"""
    )

    parser.add_argument(
        'provider',
        choices=['local', 'api', 'openai', 'anthropic', 'ollama', 'none', 'show'],
        help='LLM provider to use'
    )

    # Local model options.
    parser.add_argument('--model', '--model-path', dest='model_path',
                        help='Model path for local provider (e.g., Qwen/Qwen2.5-7B-Instruct)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda'],
                        help='Device for local model (auto, cpu, cuda)')
    # Fix: the implicit dests for '--8bit'/'--4bit' would be '8bit'/'4bit',
    # which are not valid identifiers and forced args.__dict__.get('8bit')
    # lookups.  Explicit dests keep the CLI identical but give clean attrs.
    parser.add_argument('--8bit', dest='use_8bit', action='store_true',
                        help='Use 8-bit quantization for local model')
    parser.add_argument('--4bit', dest='use_4bit', action='store_true',
                        help='Use 4-bit quantization for local model')

    # API mode options.
    parser.add_argument('--url', '--api-url', dest='api_url',
                        help='API URL for API mode')

    args = parser.parse_args()

    if args.provider == 'show':
        show_current()
        return 0

    # Collect provider-specific keyword options for set_provider().
    kwargs = {}

    if args.provider == 'local':
        if args.model_path:
            kwargs['model_path'] = args.model_path
        if args.device:
            kwargs['device'] = args.device
        # --8bit takes precedence when both flags are given (same as before).
        if args.use_8bit:
            kwargs['use_8bit'] = 'true'
            kwargs['use_4bit'] = 'false'
        elif args.use_4bit:
            kwargs['use_4bit'] = 'true'
            kwargs['use_8bit'] = 'false'
    elif args.provider == 'api':
        if args.api_url:
            kwargs['api_url'] = args.api_url

    try:
        set_provider(args.provider, **kwargs)
        print_colored("\n💡 Tip: Restart your Django server để áp dụng thay đổi!", Colors.YELLOW)
        return 0
    except Exception as e:
        print_colored(f"❌ Error: {e}", Colors.RED)
        import traceback
        traceback.print_exc()
        return 1
285
+
286
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
288
+