|
|
import json
|
|
|
import os
|
|
|
import sys
|
|
|
import urllib.error
|
|
|
import urllib.request
|
|
|
|
|
|
|
|
|
def _http_json(method: str, url: str, body: dict | None = None, headers: dict | None = None, timeout: int = 60):
|
|
|
data = None
|
|
|
if body is not None:
|
|
|
data = json.dumps(body).encode("utf-8")
|
|
|
|
|
|
req = urllib.request.Request(url=url, data=data, method=method)
|
|
|
req.add_header("Content-Type", "application/json")
|
|
|
|
|
|
if headers:
|
|
|
for k, v in headers.items():
|
|
|
if v is not None and v != "":
|
|
|
req.add_header(k, v)
|
|
|
|
|
|
try:
|
|
|
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
|
|
raw = resp.read().decode("utf-8")
|
|
|
if not raw:
|
|
|
return resp.status, None
|
|
|
return resp.status, json.loads(raw)
|
|
|
except urllib.error.HTTPError as e:
|
|
|
raw = e.read().decode("utf-8") if e.fp else ""
|
|
|
try:
|
|
|
parsed = json.loads(raw) if raw else {"detail": raw}
|
|
|
except Exception:
|
|
|
parsed = {"detail": raw}
|
|
|
return e.code, parsed
|
|
|
|
|
|
|
|
|
def _assert(cond: bool, msg: str):
|
|
|
if not cond:
|
|
|
raise AssertionError(msg)
|
|
|
|
|
|
|
|
|
def _validate_visualization(vis: dict):
    """Validate the structure of a ``visualization`` payload.

    Required: ``type`` (one of bar/line/area/pie/table/report) and ``title``,
    both strings. Optional fields are checked only when present and not
    ``None``: ``data`` (list), ``xAxisKey`` (str), ``series`` (list of objects
    with string ``dataKey``/``name``/``color``), ``columns`` (list of objects
    with string ``key``/``label``) and ``insights`` (list of strings).

    Raises:
        AssertionError: on the first rule that is violated.
    """
    allowed_types = {"bar", "line", "area", "pie", "table", "report"}
    series_fields = (
        ("dataKey", "series.dataKey must be a string"),
        ("name", "series.name must be a string"),
        ("color", "series.color must be a string"),
    )
    column_fields = (
        ("key", "column.key must be a string"),
        ("label", "column.label must be a string"),
    )

    _assert(isinstance(vis, dict), "visualization must be an object")

    # Required fields; a missing key yields None, which fails the str check.
    kind = vis.get("type")
    _assert(isinstance(kind, str), "visualization.type must be a string")
    _assert(kind in allowed_types, "visualization.type invalid")
    _assert(isinstance(vis.get("title"), str), "visualization.title must be a string")

    # Optional fields: absent and explicit null are treated the same.
    rows = vis.get("data")
    if rows is not None:
        _assert(isinstance(rows, list), "visualization.data must be a list")

    x_axis_key = vis.get("xAxisKey")
    if x_axis_key is not None:
        _assert(isinstance(x_axis_key, str), "visualization.xAxisKey must be a string")

    series = vis.get("series")
    if series is not None:
        _assert(isinstance(series, list), "visualization.series must be a list")
        for entry in series:
            _assert(isinstance(entry, dict), "series item must be an object")
            for field, message in series_fields:
                _assert(isinstance(entry.get(field), str), message)

    columns = vis.get("columns")
    if columns is not None:
        _assert(isinstance(columns, list), "visualization.columns must be a list")
        for column in columns:
            _assert(isinstance(column, dict), "column item must be an object")
            for field, message in column_fields:
                _assert(isinstance(column.get(field), str), message)

    insights = vis.get("insights")
    if insights is not None:
        _assert(isinstance(insights, list), "visualization.insights must be a list")
        _assert(
            all(isinstance(item, str) for item in insights),
            "insights items must be strings",
        )
|
|
|
|
|
|
|
|
|
def _validate_response(resp: dict):
    """Validate the top-level shape of an ``/api/ai/query`` response.

    The response must be an object with a string ``answer`` and a
    ``visualization`` key. A ``null`` visualization is accepted; any other
    value is checked by ``_validate_visualization``.

    Raises:
        AssertionError: on the first rule that is violated.
    """
    _assert(isinstance(resp, dict), "Response must be a JSON object")
    _assert(isinstance(resp.get("answer"), str), "Response.answer must be a string")
    _assert("visualization" in resp, "Response.visualization must exist (can be null)")

    visualization = resp["visualization"]
    if visualization is None:
        return
    _validate_visualization(visualization)
|
|
|
|
|
|
|
|
|
def main():
    """Run a smoke test against a deployed backend and print the results.

    The base URL comes from the ``BASE_URL`` environment variable or, when
    that is unset or blank, from the first command-line argument. If neither
    gives a value, usage help is printed and the process exits with status 2.

    When ``BACKEND_API_KEY`` is set it is sent as the ``X-API-Key`` header.
    The script requests ``GET /health`` and then posts each sample prompt to
    ``/api/ai/query``. Responses with status 200 are checked with
    ``_validate_response``; other statuses are reported as skipped.

    Raises:
        AssertionError: propagated from ``_validate_response`` when a 200
            response violates the expected contract.
    """
    # Strip BOTH sources. Previously ``.strip()`` bound only to the argv
    # fallback, so a padded or whitespace-only BASE_URL was used verbatim.
    base_url = os.environ.get("BASE_URL", "").strip()
    if not base_url and len(sys.argv) > 1:
        base_url = sys.argv[1].strip()

    if not base_url:
        print("Usage: python test.py <BASE_URL> OR set BASE_URL env var")
        print("Example: python test.py https://your-space-name.hf.space")
        sys.exit(2)

    # Avoid a double slash when joining endpoint paths below.
    base_url = base_url.rstrip("/")

    api_key = os.environ.get("BACKEND_API_KEY", "")
    headers = {}
    if api_key:
        headers["X-API-Key"] = api_key

    print(f"BASE_URL = {base_url}")

    status, health = _http_json("GET", f"{base_url}/health", headers=headers)
    print(f"/health -> {status}")
    print(json.dumps(health, indent=2, ensure_ascii=False))

    prompts = [
        "Show me the water level trend for Koyna in May 2025",
        "List top 10 dams by live_storage",
        "Give me a pie distribution of dams by storage bands",
    ]

    for idx, prompt in enumerate(prompts, start=1):
        payload = {"prompt": prompt, "context": {"page": "reports"}}
        status, resp = _http_json("POST", f"{base_url}/api/ai/query", body=payload, headers=headers)

        print("\n" + "=" * 80)
        print(f"Test #{idx}: {prompt}")
        print(f"/api/ai/query -> {status}")
        print(json.dumps(resp, indent=2, ensure_ascii=False))

        if status == 200:
            _validate_response(resp)
            print("Contract: OK")
        else:
            print("Contract: SKIPPED (non-200)")

    print("\nDone.")
|
|
|
|
|
|
|
|
|
# Run the smoke test only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|