Spaces:
Running
Running
| """SSE event builder for Anthropic-format streaming responses.""" | |
| import hashlib | |
| import json | |
| from collections.abc import Iterator | |
| from dataclasses import dataclass, field | |
| from loguru import logger | |
| try: | |
| import tiktoken | |
| ENCODER = tiktoken.get_encoding("cl100k_base") | |
| except Exception: | |
| ENCODER = None | |
| # Standard headers for Anthropic-style ``text/event-stream`` responses from this proxy. | |
| ANTHROPIC_SSE_RESPONSE_HEADERS: dict[str, str] = { | |
| "X-Accel-Buffering": "no", | |
| "Cache-Control": "no-cache", | |
| "Connection": "keep-alive", | |
| } | |
| STOP_REASON_MAP = { | |
| "stop": "end_turn", | |
| "length": "max_tokens", | |
| "tool_calls": "tool_use", | |
| "content_filter": "end_turn", | |
| } | |
| def map_stop_reason(openai_reason: str | None) -> str: | |
| """Map OpenAI finish_reason to Anthropic stop_reason.""" | |
| return ( | |
| STOP_REASON_MAP.get(openai_reason, "end_turn") if openai_reason else "end_turn" | |
| ) | |
| def _safe_usage_int(value: object) -> int: | |
| """Coerce streamed usage counters to int; non-integers become 0.""" | |
| return value if isinstance(value, int) else 0 | |
| def format_sse_event(event_type: str, data: dict) -> str: | |
| """Format one Anthropic-style SSE event (no logging).""" | |
| return f"event: {event_type}\ndata: {json.dumps(data)}\n\n" | |
| class ToolCallState: | |
| """State for a single streaming tool call.""" | |
| block_index: int | |
| tool_id: str | |
| name: str | |
| contents: list[str] = field(default_factory=list) | |
| started: bool = False | |
| task_arg_buffer: str = "" | |
| task_args_emitted: bool = False | |
| pre_start_args: str = "" | |
| class ContentBlockManager: | |
| """Manage content block indices and state.""" | |
| next_index: int = 0 | |
| thinking_index: int = -1 | |
| text_index: int = -1 | |
| thinking_started: bool = False | |
| text_started: bool = False | |
| tool_states: dict[int, ToolCallState] = field(default_factory=dict) | |
| def allocate_index(self) -> int: | |
| idx = self.next_index | |
| self.next_index += 1 | |
| return idx | |
| def ensure_tool_state(self, index: int) -> ToolCallState: | |
| """Create tool stream state for ``index`` when the first tool delta arrives.""" | |
| if index not in self.tool_states: | |
| self.tool_states[index] = ToolCallState(block_index=-1, tool_id="", name="") | |
| return self.tool_states[index] | |
| def set_stream_tool_id(self, index: int, tool_id: str | None) -> None: | |
| """Record OpenAI tool call id before ``content_block_start`` (split-stream providers).""" | |
| if not tool_id: | |
| return | |
| state = self.ensure_tool_state(index) | |
| state.tool_id = str(tool_id) | |
| def register_tool_name(self, index: int, name: str) -> None: | |
| """Record tool name fragments as they arrive from chunked OpenAI streams. | |
| Names may be split across deltas; later chunks can extend (``ab`` + ``c``) | |
| or repeat prefixes, so we merge conservatively. | |
| """ | |
| if index not in self.tool_states: | |
| self.tool_states[index] = ToolCallState( | |
| block_index=-1, tool_id="", name=name | |
| ) | |
| return | |
| state = self.tool_states[index] | |
| prev = state.name | |
| if not prev or name.startswith(prev): | |
| state.name = name | |
| elif not prev.startswith(name): | |
| state.name = prev + name | |
| def buffer_task_args(self, index: int, args: str) -> dict | None: | |
| state = self.tool_states.get(index) | |
| if state is None or state.task_args_emitted: | |
| return None | |
| state.task_arg_buffer += args | |
| try: | |
| args_json = json.loads(state.task_arg_buffer) | |
| except Exception: | |
| return None | |
| _normalize_task_run_in_background(args_json) | |
| state.task_args_emitted = True | |
| state.task_arg_buffer = "" | |
| return args_json | |
| def has_emitted_tool_block(self) -> bool: | |
| """True when native OpenAI tool streaming has started a ``tool_use`` block.""" | |
| return any(s.started for s in self.tool_states.values()) | |
| def flush_task_arg_buffers(self) -> list[tuple[int, str]]: | |
| results: list[tuple[int, str]] = [] | |
| for tool_index, state in list(self.tool_states.items()): | |
| if not state.task_arg_buffer or state.task_args_emitted: | |
| continue | |
| out = "{}" | |
| try: | |
| args_json = json.loads(state.task_arg_buffer) | |
| _normalize_task_run_in_background(args_json) | |
| out = json.dumps(args_json) | |
| except (json.JSONDecodeError, TypeError, ValueError) as e: | |
| digest = hashlib.sha256( | |
| state.task_arg_buffer.encode("utf-8", errors="replace") | |
| ).hexdigest()[:16] | |
| logger.warning( | |
| "Task args invalid JSON (id={} len={} buffer_sha256_prefix={}): {}", | |
| state.tool_id or "unknown", | |
| len(state.task_arg_buffer), | |
| digest, | |
| e, | |
| ) | |
| state.task_args_emitted = True | |
| state.task_arg_buffer = "" | |
| results.append((tool_index, out)) | |
| return results | |
| def _normalize_task_run_in_background(args_json: dict) -> None: | |
| """Force Claude Code Task subagents to run in foreground (single shared rule).""" | |
| if args_json.get("run_in_background") is not False: | |
| args_json["run_in_background"] = False | |
| class SSEBuilder: | |
| """Builder for Anthropic SSE streaming events.""" | |
| def __init__( | |
| self, | |
| message_id: str, | |
| model: str, | |
| input_tokens: int = 0, | |
| *, | |
| log_raw_events: bool = False, | |
| ): | |
| self.message_id = message_id | |
| self.model = model | |
| self.input_tokens = input_tokens | |
| self._log_raw_events = log_raw_events | |
| self.blocks = ContentBlockManager() | |
| self._accumulated_text_parts: list[str] = [] | |
| self._accumulated_reasoning_parts: list[str] = [] | |
| def _format_event(self, event_type: str, data: dict) -> str: | |
| event_str = format_sse_event(event_type, data) | |
| if self._log_raw_events: | |
| logger.debug("SSE_EVENT: {} - {}", event_type, event_str.strip()) | |
| else: | |
| logger.debug( | |
| "SSE_EVENT: event_type={} serialized_bytes={}", | |
| event_type, | |
| len(event_str.encode("utf-8")), | |
| ) | |
| return event_str | |
| def message_start(self) -> str: | |
| safe_input = _safe_usage_int(self.input_tokens) | |
| usage = {"input_tokens": safe_input, "output_tokens": 1} | |
| return self._format_event( | |
| "message_start", | |
| { | |
| "type": "message_start", | |
| "message": { | |
| "id": self.message_id, | |
| "type": "message", | |
| "role": "assistant", | |
| "content": [], | |
| "model": self.model, | |
| "stop_reason": None, | |
| "stop_sequence": None, | |
| "usage": usage, | |
| }, | |
| }, | |
| ) | |
| def message_delta(self, stop_reason: str, output_tokens: int | None) -> str: | |
| safe_in = _safe_usage_int(self.input_tokens) | |
| safe_out = output_tokens if isinstance(output_tokens, int) else 0 | |
| return self._format_event( | |
| "message_delta", | |
| { | |
| "type": "message_delta", | |
| "delta": {"stop_reason": stop_reason, "stop_sequence": None}, | |
| "usage": { | |
| "input_tokens": safe_in, | |
| "output_tokens": safe_out, | |
| }, | |
| }, | |
| ) | |
| def message_stop(self) -> str: | |
| return self._format_event("message_stop", {"type": "message_stop"}) | |
| def content_block_start(self, index: int, block_type: str, **kwargs) -> str: | |
| content_block: dict = {"type": block_type} | |
| if block_type == "thinking": | |
| content_block["thinking"] = kwargs.get("thinking", "") | |
| elif block_type == "text": | |
| content_block["text"] = kwargs.get("text", "") | |
| elif block_type == "tool_use": | |
| content_block["id"] = kwargs.get("id", "") | |
| content_block["name"] = kwargs.get("name", "") | |
| content_block["input"] = kwargs.get("input", {}) | |
| return self._format_event( | |
| "content_block_start", | |
| { | |
| "type": "content_block_start", | |
| "index": index, | |
| "content_block": content_block, | |
| }, | |
| ) | |
| def content_block_delta(self, index: int, delta_type: str, content: str) -> str: | |
| delta: dict = {"type": delta_type} | |
| if delta_type == "thinking_delta": | |
| delta["thinking"] = content | |
| elif delta_type == "text_delta": | |
| delta["text"] = content | |
| elif delta_type == "input_json_delta": | |
| delta["partial_json"] = content | |
| return self._format_event( | |
| "content_block_delta", | |
| { | |
| "type": "content_block_delta", | |
| "index": index, | |
| "delta": delta, | |
| }, | |
| ) | |
| def content_block_stop(self, index: int) -> str: | |
| return self._format_event( | |
| "content_block_stop", | |
| { | |
| "type": "content_block_stop", | |
| "index": index, | |
| }, | |
| ) | |
| def start_thinking_block(self) -> str: | |
| self.blocks.thinking_index = self.blocks.allocate_index() | |
| self.blocks.thinking_started = True | |
| return self.content_block_start(self.blocks.thinking_index, "thinking") | |
| def emit_thinking_delta(self, content: str) -> str: | |
| self._accumulated_reasoning_parts.append(content) | |
| return self.content_block_delta( | |
| self.blocks.thinking_index, "thinking_delta", content | |
| ) | |
| def stop_thinking_block(self) -> str: | |
| self.blocks.thinking_started = False | |
| return self.content_block_stop(self.blocks.thinking_index) | |
| def start_text_block(self) -> str: | |
| self.blocks.text_index = self.blocks.allocate_index() | |
| self.blocks.text_started = True | |
| return self.content_block_start(self.blocks.text_index, "text") | |
| def emit_text_delta(self, content: str) -> str: | |
| self._accumulated_text_parts.append(content) | |
| return self.content_block_delta(self.blocks.text_index, "text_delta", content) | |
| def stop_text_block(self) -> str: | |
| self.blocks.text_started = False | |
| return self.content_block_stop(self.blocks.text_index) | |
| def start_tool_block(self, tool_index: int, tool_id: str, name: str) -> str: | |
| block_idx = self.blocks.allocate_index() | |
| if tool_index in self.blocks.tool_states: | |
| state = self.blocks.tool_states[tool_index] | |
| state.block_index = block_idx | |
| state.tool_id = tool_id | |
| state.started = True | |
| else: | |
| self.blocks.tool_states[tool_index] = ToolCallState( | |
| block_index=block_idx, | |
| tool_id=tool_id, | |
| name=name, | |
| started=True, | |
| ) | |
| return self.content_block_start(block_idx, "tool_use", id=tool_id, name=name) | |
| def emit_tool_delta(self, tool_index: int, partial_json: str) -> str: | |
| state = self.blocks.tool_states[tool_index] | |
| state.contents.append(partial_json) | |
| return self.content_block_delta( | |
| state.block_index, "input_json_delta", partial_json | |
| ) | |
| def stop_tool_block(self, tool_index: int) -> str: | |
| block_idx = self.blocks.tool_states[tool_index].block_index | |
| return self.content_block_stop(block_idx) | |
| def ensure_thinking_block(self) -> Iterator[str]: | |
| if self.blocks.text_started: | |
| yield self.stop_text_block() | |
| if not self.blocks.thinking_started: | |
| yield self.start_thinking_block() | |
| def ensure_text_block(self) -> Iterator[str]: | |
| if self.blocks.thinking_started: | |
| yield self.stop_thinking_block() | |
| if not self.blocks.text_started: | |
| yield self.start_text_block() | |
| def close_content_blocks(self) -> Iterator[str]: | |
| if self.blocks.thinking_started: | |
| yield self.stop_thinking_block() | |
| if self.blocks.text_started: | |
| yield self.stop_text_block() | |
| def close_all_blocks(self) -> Iterator[str]: | |
| yield from self.close_content_blocks() | |
| for tool_index, state in list(self.blocks.tool_states.items()): | |
| if state.started: | |
| yield self.stop_tool_block(tool_index) | |
| def emit_error(self, error_message: str) -> Iterator[str]: | |
| error_index = self.blocks.allocate_index() | |
| yield self.content_block_start(error_index, "text") | |
| yield self.content_block_delta(error_index, "text_delta", error_message) | |
| yield self.content_block_stop(error_index) | |
| def emit_top_level_error(self, error_message: str) -> str: | |
| """Emit a top-level ``event: error`` (not assistant text) for transport failures.""" | |
| return self._format_event( | |
| "error", | |
| { | |
| "type": "error", | |
| "error": { | |
| "type": "api_error", | |
| "message": error_message, | |
| }, | |
| }, | |
| ) | |
| def accumulated_text(self) -> str: | |
| return "".join(self._accumulated_text_parts) | |
| def accumulated_reasoning(self) -> str: | |
| return "".join(self._accumulated_reasoning_parts) | |
| def estimate_output_tokens(self) -> int: | |
| accumulated_text = self.accumulated_text | |
| accumulated_reasoning = self.accumulated_reasoning | |
| if ENCODER: | |
| text_tokens = len(ENCODER.encode(accumulated_text)) | |
| reasoning_tokens = len(ENCODER.encode(accumulated_reasoning)) | |
| tool_tokens = 0 | |
| started_tool_count = 0 | |
| for state in self.blocks.tool_states.values(): | |
| tool_tokens += len(ENCODER.encode(state.name)) | |
| tool_tokens += len(ENCODER.encode("".join(state.contents))) | |
| tool_tokens += 15 | |
| if state.started: | |
| started_tool_count += 1 | |
| block_count = ( | |
| (1 if accumulated_reasoning else 0) | |
| + (1 if accumulated_text else 0) | |
| + started_tool_count | |
| ) | |
| return text_tokens + reasoning_tokens + tool_tokens + (block_count * 4) | |
| text_tokens = len(accumulated_text) // 4 | |
| reasoning_tokens = len(accumulated_reasoning) // 4 | |
| tool_tokens = ( | |
| sum(1 for state in self.blocks.tool_states.values() if state.started) * 50 | |
| ) | |
| return text_tokens + reasoning_tokens + tool_tokens | |