| """ |
| ai_client.py |
| Thin wrapper around the Anthropic API with chunked processing and streaming. |
| Author: algorembrant |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from typing import Iterator, Optional |
|
|
| import anthropic |
|
|
| from config import DEFAULT_MODEL, MAX_TOKENS, CHUNK_SIZE |
|
|
|
|
| |
| |
| |
# Module-level cache holding the shared Anthropic client.
# Starts as None and is created on first use by _get_client().
_client: Optional[anthropic.Anthropic] = None
|
|
|
|
def _get_client() -> anthropic.Anthropic:
    """Return the shared Anthropic client, constructing it lazily on first call."""
    global _client
    client = _client
    if client is None:
        client = anthropic.Anthropic()
        _client = client
    return client
|
|
|
|
| |
| |
| |
|
|
def complete(
    system: str,
    user: str,
    model: str = DEFAULT_MODEL,
    max_tokens: int = MAX_TOKENS,
    stream: bool = True,
) -> str:
    """
    Run a single completion and return the full response text.

    Args:
        system: System prompt.
        user: User message content.
        model: Anthropic model identifier.
        max_tokens: Maximum output tokens for the call.
        stream: If True, tokens are echoed to stderr as they arrive so the
            user sees progress; the full text is still returned either way.

    Returns:
        The complete response text.
    """
    client = _get_client()

    if stream:
        result_parts: list[str] = []
        with client.messages.stream(
            model=model,
            max_tokens=max_tokens,
            system=system,
            messages=[{"role": "user", "content": user}],
        ) as stream_ctx:
            for text in stream_ctx.text_stream:
                # Echo progress to stderr so stdout stays clean for the result.
                print(text, end="", flush=True, file=sys.stderr)
                result_parts.append(text)
            # Trailing newline after the streamed output.
            print(file=sys.stderr)
        return "".join(result_parts)
    else:
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            system=system,
            messages=[{"role": "user", "content": user}],
        )
        # The Messages API may return more than one content block; join every
        # text block instead of assuming content[0] is the only (text) one.
        return "".join(
            block.text for block in response.content if hasattr(block, "text")
        )
|
|
|
|
| def _split_into_chunks(text: str, chunk_size: int = CHUNK_SIZE) -> list[str]: |
| """ |
| Split text into chunks of at most `chunk_size` characters, |
| breaking on paragraph or sentence boundaries where possible. |
| """ |
| if len(text) <= chunk_size: |
| return [text] |
|
|
| chunks: list[str] = [] |
| start = 0 |
| while start < len(text): |
| end = start + chunk_size |
| if end >= len(text): |
| chunks.append(text[start:]) |
| break |
|
|
| |
| split_at = text.rfind("\n\n", start, end) |
| if split_at == -1: |
| |
| split_at = text.rfind(". ", start, end) |
| if split_at == -1: |
| |
| split_at = text.rfind(" ", start, end) |
| if split_at == -1: |
| split_at = end |
|
|
| chunks.append(text[start : split_at + 1]) |
| start = split_at + 1 |
|
|
| return chunks |
|
|
|
|
def complete_long(
    system: str,
    user_prefix: str,
    text: str,
    user_suffix: str = "",
    model: str = DEFAULT_MODEL,
    max_tokens: int = MAX_TOKENS,
    merge_system: Optional[str] = None,
    stream: bool = True,
) -> str:
    """
    Process a potentially long text by chunking it, running a completion on
    each chunk, and optionally merging the partial results.

    Args:
        system: System prompt used for every per-chunk call.
        user_prefix: Text placed before each chunk in the user message.
        text: The main content to process (split when it is too long).
        user_suffix: Text placed after each chunk in the user message.
        model: Anthropic model identifier.
        max_tokens: Max output tokens per call.
        merge_system: If provided and the text spanned multiple chunks, a
            final merge pass is run with this system prompt.
        stream: Whether to stream tokens to stderr.

    Returns:
        The processed text (merged when the input was multi-chunk).
    """

    def assemble(body: str) -> str:
        # Wrap a chunk body with the caller-supplied prefix/suffix.
        message = f"{user_prefix}\n\n{body}"
        if user_suffix:
            message += f"\n\n{user_suffix}"
        return message

    parts = _split_into_chunks(text)
    total = len(parts)

    # Fast path: the whole text fits in one chunk — a single API call.
    if total == 1:
        return complete(
            system, assemble(parts[0]), model=model, max_tokens=max_tokens, stream=stream
        )

    print(
        f"[info] Text is large ({len(text):,} chars). Processing in {total} chunks.",
        file=sys.stderr,
    )
    outputs: list[str] = []
    for index, part in enumerate(parts, 1):
        print(f"\n[chunk {index}/{total}]", file=sys.stderr)
        message = assemble(f"[Part {index} of {total}]\n\n{part}")
        outputs.append(
            complete(system, message, model=model, max_tokens=max_tokens, stream=stream)
        )

    merged = "\n\n".join(outputs)

    # Optional unification pass over the stitched partial results.
    if merge_system and total > 1:
        print(f"\n[merging {total} chunks into final output]", file=sys.stderr)
        merged = complete(
            merge_system,
            f"Merge and unify the following {total} sections into a single cohesive output:\n\n{merged}",
            model=model,
            max_tokens=max_tokens,
            stream=stream,
        )

    return merged
|
|