#!/usr/bin/env python3
"""
Mock inference server for testing CLI topic extraction without API costs.

This server mimics an inference-server API endpoint and returns dummy
responses that satisfy the validation requirements (markdown tables with |).
"""

import json
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional


class MockInferenceServerHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the mock inference server."""

    def _generate_mock_response(self, prompt: str, system_prompt: str) -> str:
        """
        Generate a mock response that satisfies validation requirements.

        The response must:
        - Be longer than 120 characters
        - Contain a markdown table (with | characters)

        Args:
            prompt: The user prompt
            system_prompt: The system prompt

        Returns:
            A mock markdown table response
        """
        # The prompt arguments are accepted for interface compatibility but
        # ignored: the mock always returns the same canned topic-extraction
        # table, which exceeds 120 characters and contains "|" characters.
        mock_table = """| Reference | General Topic | Sub-topic | Sentiment |
|-----------|---------------|-----------|-----------|
| 1 | Test Topic | Test Subtopic | Positive |
| 2 | Another Topic | Another Subtopic | Neutral |
| 3 | Third Topic | Third Subtopic | Negative |

This is a mock response from the test inference server. The actual content would be generated by a real LLM model, but for testing purposes, this dummy response allows us to verify that the CLI commands work correctly without incurring API costs."""

        return mock_table

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count (rough approximation: ~4 characters per token)."""
        return max(1, len(text) // 4)

    def do_POST(self):
        """Handle POST requests to /v1/chat/completions."""
        print(f"[Mock Server] Received POST request to: {self.path}")
        if self.path == "/v1/chat/completions":
            try:
                # Read request body
                content_length = int(self.headers.get("Content-Length", 0))
                print(f"[Mock Server] Content-Length: {content_length}")
                body = self.rfile.read(content_length)
                payload = json.loads(body.decode("utf-8"))
                print("[Mock Server] Payload received, processing...")

                # Extract messages
                messages = payload.get("messages", [])
                system_prompt = ""
                user_prompt = ""

                for msg in messages:
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    if role == "system":
                        system_prompt = content
                    elif role == "user":
                        user_prompt = content

                # Generate mock response
                response_text = self._generate_mock_response(user_prompt, system_prompt)

                # Estimate tokens
                input_tokens = self._estimate_tokens(system_prompt + "\n" + user_prompt)
                output_tokens = self._estimate_tokens(response_text)

                # Check if streaming is requested
                stream = payload.get("stream", False)

                if stream:
                    # Handle streaming response
                    self.send_response(200)
                    self.send_header("Content-Type", "text/event-stream")
                    self.send_header("Cache-Control", "no-cache")
                    self.send_header("Connection", "keep-alive")
                    self.end_headers()

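                    # Each chunk is one server-sent event: a "data:" line with
                    # a JSON delta, followed by a blank line, e.g.
                    #   data: {"choices": [{"delta": {"content": "| Ref"}, ...}]}
                    # and the stream is terminated by "data: [DONE]".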
                    # Send streaming chunks
                    chunk_size = 20  # Characters per chunk
                    for i in range(0, len(response_text), chunk_size):
                        chunk = response_text[i : i + chunk_size]
                        chunk_data = {
                            "choices": [
                                {
                                    "delta": {"content": chunk},
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ]
                        }
                        self.wfile.write(f"data: {json.dumps(chunk_data)}\n\n".encode())
                        self.wfile.flush()

                    # Send final done message
                    self.wfile.write(b"data: [DONE]\n\n")
                    self.wfile.flush()
                else:
                    # Handle non-streaming response
                    response_data = {
                        "choices": [
                            {
                                "index": 0,
                                "finish_reason": "stop",
                                "message": {
                                    "role": "assistant",
                                    "content": response_text,
                                },
                            }
                        ],
                        "usage": {
                            "prompt_tokens": input_tokens,
                            "completion_tokens": output_tokens,
                            "total_tokens": input_tokens + output_tokens,
                        },
                    }

                    self.send_response(200)
                    self.send_header("Content-Type", "application/json")
                    self.end_headers()
                    self.wfile.write(json.dumps(response_data).encode())

            except Exception as e:
                self.send_response(500)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                error_response = {"error": {"message": str(e), "type": "server_error"}}
                self.wfile.write(json.dumps(error_response).encode())
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Redirect the default handler logging to tagged stdout prints."""
        print(f"[Mock Server] {format % args}")


class MockInferenceServer:
    """Mock inference server that can be started and stopped for testing."""

    def __init__(self, host: str = "localhost", port: int = 8080):
        """
        Initialize the mock server.

        Args:
            host: Host to bind to (default: localhost)
            port: Port to bind to (default: 8080)
        """
        self.host = host
        self.port = port
        self.server: Optional[HTTPServer] = None
        self.server_thread: Optional[threading.Thread] = None
        self.running = False

    def start(self):
        """Start the mock server in a background thread."""
        if self.running:
            return

        # Create the server in the calling thread so that bind errors
        # (e.g. port already in use) surface immediately instead of failing
        # silently inside the worker thread. HTTPServer binds and listens on
        # construction, so the port accepts connections as soon as this
        # returns and no sleep is needed.
        self.server = HTTPServer((self.host, self.port), MockInferenceServerHandler)
        self.running = True
        self.server_thread = threading.Thread(
            target=self.server.serve_forever, daemon=True
        )
        self.server_thread.start()

    def stop(self):
        """Stop the mock server and wait briefly for its thread to exit."""
        if self.server and self.running:
            self.server.shutdown()
            self.server.server_close()
            self.running = False
            if self.server_thread is not None:
                self.server_thread.join(timeout=2)

    def get_url(self) -> str:
        """Get the server URL."""
        return f"http://{self.host}:{self.port}"

    def __enter__(self):
        """Context manager entry."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.stop()

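
# A minimal self-test sketch using only the standard library. The helper
# name `_demo_request` and port 8081 are illustrative choices, not part of
# the module's API; the request/response shapes match what do_POST above
# implements.
def _demo_request() -> None:
    """Start a mock server, send one non-streaming request, print the reply."""
    import urllib.request

    with MockInferenceServer(port=8081) as server:
        body = json.dumps(
            {"messages": [{"role": "user", "content": "Extract topics."}]}
        ).encode("utf-8")
        request = urllib.request.Request(
            f"{server.get_url()}/v1/chat/completions",
            data=body,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(request) as response:
            data = json.loads(response.read().decode("utf-8"))
        # The mock always replies with the canned markdown table.
        print(data["choices"][0]["message"]["content"][:120])
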

if __name__ == "__main__":
    # Run the server standalone for manual testing.
    print("Starting mock inference server on http://localhost:8080")
    print("Press Ctrl+C to stop")

    server = MockInferenceServer()
    try:
        server.start()
        print(f"Server running at {server.get_url()}")
        # Keep the main thread alive; the server runs in a daemon thread.
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\nStopping server...")
        server.stop()
        print("Server stopped")