#!/usr/bin/env python3
"""
Mock inference server for testing CLI topic extraction without API costs.

This server mimics an inference-server API endpoint and returns dummy responses
that satisfy the validation requirements (markdown tables with |).
"""

import json
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional


class MockInferenceServerHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the mock inference server."""

    def _generate_mock_response(self, prompt: str, system_prompt: str) -> str:
        """
        Generate a mock response that satisfies validation requirements.

        The response must:
        - Be longer than 120 characters
        - Contain a markdown table (with | characters)

        Args:
            prompt: The user prompt
            system_prompt: The system prompt

        Returns:
            A mock markdown table response
        """
        # Generate a simple markdown table that satisfies the validation.
        # This mimics a topic extraction table response.
        mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|-----------|---------------|-----------|-----------|
| 1 | Test Topic | Test Subtopic | Positive |
| 2 | Another Topic | Another Subtopic | Neutral |
| 3 | Third Topic | Third Subtopic | Negative |

This is a mock response from the test inference server. The actual content
would be generated by a real LLM model, but for testing purposes, this dummy
response allows us to verify that the CLI commands work correctly without
incurring API costs."""
        return mock_table

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation: ~4 characters per token)."""
        return max(1, len(text) // 4)

    def do_POST(self):
        """Handle POST requests to /v1/chat/completions."""
        print(f"[Mock Server] Received POST request to: {self.path}")
        if self.path == "/v1/chat/completions":
            try:
                # Read the request body
                content_length = int(self.headers.get("Content-Length", 0))
                print(f"[Mock Server] Content-Length: {content_length}")
                body = self.rfile.read(content_length)
                payload = json.loads(body.decode("utf-8"))
                print("[Mock Server] Payload received, processing...")

                # Extract the system and user messages
                messages = payload.get("messages", [])
                system_prompt = ""
                user_prompt = ""
                for msg in messages:
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    if role == "system":
                        system_prompt = content
                    elif role == "user":
                        user_prompt = content

                # Generate the mock response
                response_text = self._generate_mock_response(user_prompt, system_prompt)

                # Estimate token counts
                input_tokens = self._estimate_tokens(system_prompt + "\n" + user_prompt)
                output_tokens = self._estimate_tokens(response_text)

                # Check whether streaming was requested
                stream = payload.get("stream", False)

                if stream:
                    # Handle streaming response (server-sent events)
                    self.send_response(200)
                    self.send_header("Content-Type", "text/event-stream")
                    self.send_header("Cache-Control", "no-cache")
                    self.send_header("Connection", "keep-alive")
                    self.end_headers()

                    # Send the response in small chunks
                    chunk_size = 20  # Characters per chunk
                    for i in range(0, len(response_text), chunk_size):
                        chunk = response_text[i : i + chunk_size]
                        chunk_data = {
                            "choices": [
                                {
                                    "delta": {"content": chunk},
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ]
                        }
                        self.wfile.write(f"data: {json.dumps(chunk_data)}\n\n".encode())
                        self.wfile.flush()

                    # Send the final done message
                    self.wfile.write(b"data: [DONE]\n\n")
                    self.wfile.flush()
                else:
                    # Handle non-streaming response
                    response_data = {
                        "choices": [
                            {
                                "index": 0,
                                "finish_reason": "stop",
                                "message": {
                                    "role": "assistant",
                                    "content": response_text,
                                },
                            }
                        ],
                        "usage": {
                            "prompt_tokens": input_tokens,
                            "completion_tokens": output_tokens,
                            "total_tokens": input_tokens + output_tokens,
                        },
                    }

                    self.send_response(200)
                    self.send_header("Content-Type", "application/json")
                    self.end_headers()
                    self.wfile.write(json.dumps(response_data).encode())

            except Exception as e:
                self.send_response(500)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                error_response = {"error": {"message": str(e), "type": "server_error"}}
                self.wfile.write(json.dumps(error_response).encode())
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Log messages for debugging."""
        print(f"[Mock Server] {format % args}")


class MockInferenceServer:
    """Mock inference server that can be started and stopped for testing."""

    def __init__(self, host: str = "localhost", port: int = 8080):
        """
        Initialize the mock server.

        Args:
            host: Host to bind to (default: localhost)
            port: Port to bind to (default: 8080)
        """
        self.host = host
        self.port = port
        self.server: Optional[HTTPServer] = None
        self.server_thread: Optional[threading.Thread] = None
        self.running = False

    def start(self):
        """Start the mock server in a separate thread."""
        if self.running:
            return

        def run_server():
            self.server = HTTPServer((self.host, self.port), MockInferenceServerHandler)
            self.running = True
            self.server.serve_forever()

        self.server_thread = threading.Thread(target=run_server, daemon=True)
        self.server_thread.start()

        # Wait a moment for the server to start
        time.sleep(0.5)

    def stop(self):
        """Stop the mock server."""
        if self.server and self.running:
            self.server.shutdown()
            self.server.server_close()
            self.running = False

    def get_url(self) -> str:
        """Get the server URL."""
        return f"http://{self.host}:{self.port}"

    def __enter__(self):
        """Context manager entry."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.stop()


if __name__ == "__main__":
    # Run the server standalone for manual testing
    print("Starting mock inference server on http://localhost:8080")
    print("Press Ctrl+C to stop")

    server = MockInferenceServer()
    try:
        server.start()
        print(f"Server running at {server.get_url()}")

        # Keep running until interrupted
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nStopping server...")
        server.stop()
        print("Server stopped")
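
# Example client usage (a minimal sketch, not part of the server itself; it
# assumes the third-party `requests` package is installed and that port 8080
# is free). It exercises the non-streaming /v1/chat/completions path:
#
#     import requests
#
#     with MockInferenceServer(port=8080) as mock:
#         resp = requests.post(
#             f"{mock.get_url()}/v1/chat/completions",
#             json={
#                 "messages": [
#                     {"role": "system", "content": "You extract topics."},
#                     {"role": "user", "content": "Some text to analyse."},
#                 ],
#                 "stream": False,
#             },
#             timeout=10,
#         )
#         # The mock always returns the dummy markdown table defined above.
#         print(resp.json()["choices"][0]["message"]["content"])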