#!/usr/bin/env python3
"""
Mock inference server for testing CLI topic extraction without API costs.
This server mimics an OpenAI-compatible /v1/chat/completions endpoint and
returns dummy responses that satisfy the validation requirements (responses
longer than 120 characters that contain a markdown table with | characters).
"""
import json
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional
class MockInferenceServerHandler(BaseHTTPRequestHandler):
"""HTTP request handler for the mock inference server."""
def _generate_mock_response(self, prompt: str, system_prompt: str) -> str:
"""
Generate a mock response that satisfies validation requirements.
The response must:
- Be longer than 120 characters
- Contain a markdown table (with | characters)
        Args:
            prompt: The user prompt (currently ignored; the same canned table is returned for every request)
            system_prompt: The system prompt (currently ignored)
Returns:
A mock markdown table response
"""
# Generate a simple markdown table that satisfies the validation
# This mimics a topic extraction table response
mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|-----------|---------------|-----------|-----------|
| 1 | Test Topic | Test Subtopic | Positive |
| 2 | Another Topic | Another Subtopic | Neutral |
| 3 | Third Topic | Third Subtopic | Negative |

This is a mock response from the test inference server. A real LLM would generate the actual content; for testing, this dummy response lets us verify that the CLI commands work correctly without incurring API costs."""
return mock_table
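    # Sanity-check sketch for the requirements listed in the docstring above
    # (length > 120 characters and at least one "|"); the real validation lives
    # in the application code, not in this mock:
    #
    #   text = MockInferenceServerHandler._generate_mock_response(None, "", "")
    #   assert len(text) > 120 and "|" in text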
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count (rough approximation: ~4 characters per token)."""
return max(1, len(text) // 4)
def do_POST(self):
"""Handle POST requests to /v1/chat/completions."""
print(f"[Mock Server] Received POST request to: {self.path}")
if self.path == "/v1/chat/completions":
try:
# Read request body
content_length = int(self.headers.get("Content-Length", 0))
print(f"[Mock Server] Content-Length: {content_length}")
body = self.rfile.read(content_length)
payload = json.loads(body.decode("utf-8"))
print("[Mock Server] Payload received, processing...")
# Extract messages
messages = payload.get("messages", [])
system_prompt = ""
user_prompt = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
system_prompt = content
elif role == "user":
user_prompt = content
# Generate mock response
response_text = self._generate_mock_response(user_prompt, system_prompt)
# Estimate tokens
input_tokens = self._estimate_tokens(system_prompt + "\n" + user_prompt)
output_tokens = self._estimate_tokens(response_text)
# Check if streaming is requested
stream = payload.get("stream", False)
if stream:
# Handle streaming response
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.send_header("Connection", "keep-alive")
self.end_headers()
                    # Send the response as SSE-style chunks (see the
                    # client-side consumption sketch after this method)
chunk_size = 20 # Characters per chunk
for i in range(0, len(response_text), chunk_size):
chunk = response_text[i : i + chunk_size]
chunk_data = {
"choices": [
{
"delta": {"content": chunk},
"index": 0,
"finish_reason": None,
}
]
}
self.wfile.write(f"data: {json.dumps(chunk_data)}\n\n".encode())
self.wfile.flush()
# Send final done message
self.wfile.write(b"data: [DONE]\n\n")
self.wfile.flush()
else:
# Handle non-streaming response
response_data = {
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": response_text,
},
}
],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
},
}
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(json.dumps(response_data).encode())
except Exception as e:
self.send_response(500)
self.send_header("Content-Type", "application/json")
self.end_headers()
error_response = {"error": {"message": str(e), "type": "server_error"}}
self.wfile.write(json.dumps(error_response).encode())
else:
self.send_response(404)
self.end_headers()
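    # Consuming the streaming branch above from the client side (a minimal
    # sketch, assuming the `requests` package; each SSE line carries a JSON
    # chunk with a "delta" fragment, and the stream ends with "data: [DONE]"):
    #
    #   import json, requests
    #   with requests.post(
    #       "http://localhost:8080/v1/chat/completions",
    #       json={"stream": True, "messages": [{"role": "user", "content": "Hi"}]},
    #       stream=True,
    #   ) as resp:
    #       for line in resp.iter_lines():
    #           if line and line != b"data: [DONE]":
    #               chunk = json.loads(line[len(b"data: "):])
    #               print(chunk["choices"][0]["delta"]["content"], end="")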
def log_message(self, format, *args):
"""Log messages for debugging."""
# Enable logging for debugging
print(f"[Mock Server] {format % args}")
class MockInferenceServer:
"""Mock inference server that can be started and stopped for testing."""
def __init__(self, host: str = "localhost", port: int = 8080):
"""
Initialize the mock server.
Args:
host: Host to bind to (default: localhost)
port: Port to bind to (default: 8080)
"""
self.host = host
self.port = port
self.server: Optional[HTTPServer] = None
self.server_thread: Optional[threading.Thread] = None
self.running = False
def start(self):
"""Start the mock server in a separate thread."""
if self.running:
return
def run_server():
self.server = HTTPServer((self.host, self.port), MockInferenceServerHandler)
self.running = True
self.server.serve_forever()
self.server_thread = threading.Thread(target=run_server, daemon=True)
self.server_thread.start()
# Wait a moment for server to start
time.sleep(0.5)
def stop(self):
"""Stop the mock server."""
if self.server and self.running:
self.server.shutdown()
self.server.server_close()
self.running = False
def get_url(self) -> str:
"""Get the server URL."""
return f"http://{self.host}:{self.port}"
def __enter__(self):
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.stop()
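# Example test usage (a sketch; `run_cli_topic_extraction` is a hypothetical
# stand-in for whatever CLI entry point the test suite drives, and the "/v1"
# suffix assumes the client expects an OpenAI-style base URL):
#
#   with MockInferenceServer(port=8081) as server:
#       run_cli_topic_extraction(api_base=server.get_url() + "/v1")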
if __name__ == "__main__":
# Test the server
print("Starting mock inference server on http://localhost:8080")
print("Press Ctrl+C to stop")
server = MockInferenceServer()
try:
server.start()
print(f"Server running at {server.get_url()}")
# Keep running
while True:
time.sleep(1)
except KeyboardInterrupt:
print("\nStopping server...")
server.stop()
print("Server stopped")