Commit 3572ba0
Parent: ae5413a

fix: address SPEC-21 code quality issues
- Remove dead ContextVar code from token_tracking.py (was never set)
- Replace orphaned get_token_stats() with instance method get_stats()
- Add proper exports to workflows/__init__.py (SubIterationMiddleware, etc.)
- Clean up imports in huggingface.py (use package import)
- Add type: ignore for mypy list invariance issue
All 311 tests pass.
- src/clients/huggingface.py +4 -5
- src/middleware/token_tracking.py +11 -11
- src/workflows/__init__.py +13 -1
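
The "type: ignore" bullet above refers to mypy's list invariance: a list parameterized with a subclass is not accepted where a list of the base class is expected, because the callee could legally append other base-class instances. A minimal, self-contained sketch of that rule (Base, Derived, and use_all are illustrative names, not code from this repository):

# Sketch of mypy list invariance (hypothetical names, not repo code).
class Base: ...
class Derived(Base): ...

def use_all(items: list[Base]) -> None:
    # Holding list[Base], the callee may legally append any Base subtype...
    items.append(Base())

derived_only: list[Derived] = [Derived()]
# ...which would corrupt a caller's list[Derived], so mypy rejects the call
# unless it is silenced, as this commit does with `# type: ignore[arg-type]`.
use_all(derived_only)  # type: ignore[arg-type]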
src/clients/huggingface.py CHANGED

@@ -27,8 +27,7 @@ from agent_framework._types import FunctionCallContent, FunctionResultContent
 from agent_framework.observability import use_observability
 from huggingface_hub import InferenceClient
 
-from src.middleware
-from src.middleware.token_tracking import TokenTrackingMiddleware
+from src.middleware import RetryMiddleware, TokenTrackingMiddleware
 from src.utils.config import settings
 
 logger = structlog.get_logger()
@@ -53,13 +52,13 @@ class HuggingFaceChatClient(BaseChatClient):  # type: ignore[misc]
             api_key: HF_TOKEN (optional, defaults to env var).
             **kwargs: Additional arguments passed to BaseChatClient.
         """
-        # Create middleware instances
-        middleware
+        # Create middleware instances for retry and token tracking
+        middleware = [
             RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
             TokenTrackingMiddleware(),
         ]
 
-        super().__init__(middleware=middleware, **kwargs)
+        super().__init__(middleware=middleware, **kwargs)  # type: ignore[arg-type]
         # FIX: Use 7B model to stay on HuggingFace native infrastructure (avoid Novita 500s)
         self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
         self.api_key = api_key or settings.hf_token
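
For orientation, a rough usage sketch of the client after this change; the class name, constructor parameters, and defaults are taken from the diff above, while the surrounding call site is assumed:

from src.clients.huggingface import HuggingFaceChatClient

# Middleware (retry + token tracking) is assembled inside __init__, so a
# caller only supplies the model and credentials; both fall back to
# settings.huggingface_model / settings.hf_token, then to the 7B default.
client = HuggingFaceChatClient(
    model_id="Qwen/Qwen2.5-7B-Instruct",  # optional, shown with its default
    api_key=None,  # optional, defaults to the HF_TOKEN setting
)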
src/middleware/token_tracking.py CHANGED

@@ -1,16 +1,12 @@
 """Token tracking middleware for monitoring API usage."""
 
 from collections.abc import Awaitable, Callable
-from contextvars import ContextVar
 
 import structlog
 from agent_framework._middleware import ChatContext, ChatMiddleware
 
 logger = structlog.get_logger()
 
-# ContextVar for per-request token tracking
-_request_tokens: ContextVar[dict[str, int]] = ContextVar("request_tokens")
-
 
 class TokenTrackingMiddleware(ChatMiddleware):
     """Tracks token usage across chat requests.
@@ -64,10 +60,14 @@ class TokenTrackingMiddleware(ChatMiddleware):
             total_requests=self.request_count,
         )
 
-
-
-
-
-
-
-    return {
+    def get_stats(self) -> dict[str, int]:
+        """Get cumulative token usage statistics.
+
+        Returns:
+            Dictionary with total_input, total_output, and request_count.
+        """
+        return {
+            "total_input": self.total_input_tokens,
+            "total_output": self.total_output_tokens,
+            "request_count": self.request_count,
+        }
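
A short sketch of the replacement API: get_stats() is now an instance method on the middleware, so callers read cumulative counters from the object they registered rather than from module-level state (the import and key names come from the diffs above; the surrounding wiring is assumed):

from src.middleware import TokenTrackingMiddleware

tracker = TokenTrackingMiddleware()
# ... after the middleware has processed some chat requests ...
stats = tracker.get_stats()
print(stats["total_input"], stats["total_output"], stats["request_count"])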
src/workflows/__init__.py CHANGED

@@ -1 +1,13 @@
-"""
+"""Workflow components for orchestration.
+
+These are workflow patterns (e.g., team→judge loops), NOT interceptor middleware.
+For interceptor middleware, see src/middleware/.
+"""
+
+from src.workflows.sub_iteration import (
+    SubIterationJudge,
+    SubIterationMiddleware,
+    SubIterationTeam,
+)
+
+__all__ = ["SubIterationJudge", "SubIterationMiddleware", "SubIterationTeam"]
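
With the new re-exports, downstream code can import the workflow classes from the package root instead of reaching into src.workflows.sub_iteration; a sketch (constructors are not shown in this commit, so only the imports are illustrated):

# Package-root imports enabled by the new __init__.py exports.
from src.workflows import (
    SubIterationJudge,
    SubIterationMiddleware,
    SubIterationTeam,
)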