#!/usr/bin/env python3 """ run_official_inference.py 最小化测试脚本:读取一个窗口 JSON 文件 -> 调用 WearableAnomalyDetector -> 打印模型输出及格式化文本。 使用方式: python run_official_inference.py \ --window-file test_data/example_window.json \ --model-dir checkpoints/phase2/exp_factor_balanced """ from __future__ import annotations import argparse import json from pathlib import Path from typing import List, Dict, Any import importlib.util from wearable_anomaly_detector import WearableAnomalyDetector def load_formatter(): formatter_path = Path(__file__).parent / "utils" / "formatter.py" spec = importlib.util.spec_from_file_location("formatter", formatter_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module.AnomalyFormatter def load_window(path: Path) -> List[Dict[str, Any]]: if path.suffix == ".jsonl": with open(path, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f if line.strip()] else: with open(path, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, dict): data = data.get("records") or data.get("data") or [data] if not isinstance(data, list) or not data: raise ValueError("窗口文件必须是非空列表") if len(data) < 12: raise ValueError("窗口数据至少需要 12 条记录") return data[-12:] def build_baseline_info(window: List[Dict[str, Any]]) -> Dict[str, float]: # 优先使用输入中的 baseline 字段,否则简单按窗口平均值估算 for point in window: baseline_mean = point["features"].get("baseline_hrv_mean") baseline_std = point["features"].get("baseline_hrv_std") if baseline_mean is not None and baseline_std is not None: current = point["features"].get("hrv_rmssd") deviation = 0.0 if current is not None: deviation = (current - baseline_mean) / baseline_mean * 100 return { "baseline_mean": float(baseline_mean), "baseline_std": float(baseline_std), "current_value": float(current or baseline_mean), "deviation_pct": float(deviation), } avg_hrv = sum(pt["features"].get("hrv_rmssd", 0.0) for pt in window) / len(window) return { "baseline_mean": avg_hrv, "baseline_std": 5.0, "current_value": avg_hrv, "deviation_pct": 0.0, } def main() -> None: parser = argparse.ArgumentParser(description="Run wearable anomaly detector on a JSON window file.") parser.add_argument( "--window-file", type=Path, default=Path("test_data/example_window.json"), help="包含 12 条数据点的 JSON 文件路径", ) parser.add_argument( "--model-dir", type=Path, default=Path("checkpoints/phase2/exp_factor_balanced"), help="Phase2 最佳模型所在目录", ) parser.add_argument( "--device", type=str, default=None, help="可选:cpu / cuda / cuda:0 等", ) args = parser.parse_args() if not args.window_file.exists(): raise FileNotFoundError(f"窗口文件不存在:{args.window_file}") window = load_window(args.window_file) detector = WearableAnomalyDetector(model_dir=args.model_dir, device=args.device) result = detector.detect_realtime(window, update_baseline=False, return_details=True) print("\n=== 模型输出(JSON)===") print(json.dumps(result, ensure_ascii=False, indent=2)) formatter_cls = load_formatter() formatter = formatter_cls() baseline_info = build_baseline_info(window) formatted = formatter.format_for_llm( anomaly_result=result, baseline_info=baseline_info, daily_results=None, ) print("\n=== LLM 文本 ===") print(formatted) if __name__ == "__main__": main()