oscarzhang commited on
Commit
8713446
·
verified ·
1 Parent(s): 11080d1

Upload gradio_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_app.py +163 -0
gradio_app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from pathlib import Path
4
+ import importlib.util
5
+ from wearable_anomaly_detector import WearableAnomalyDetector
6
+
7
+ BASE_DIR = Path(__file__).parent
8
+ MODEL_DIR = BASE_DIR / "checkpoints" / "phase2" / "exp_factor_balanced"
9
+ LLM_INPUT_DIR = BASE_DIR / "demo_llm_inputs"
10
+ LLM_MANIFEST = LLM_INPUT_DIR / "manifest.json"
11
+ PATCHAD_CASE_DIR = BASE_DIR / "demo_patchad_cases"
12
+ PATCHAD_CASE_MANIFEST = PATCHAD_CASE_DIR / "manifest.json"
13
+
14
+ SAMPLES = {
15
+ "示例: 正常": BASE_DIR / "test_data" / "example_window.json",
16
+ "示例: 短期异常": BASE_DIR / "data_storage" / "users" / "demo_anomaly.jsonl",
17
+ "示例: 长期异常": BASE_DIR / "data_storage" / "users" / "demo_pattern.jsonl",
18
+ "示例: 缺失数据": BASE_DIR / "data_storage" / "users" / "demo_missing.jsonl",
19
+ }
20
+
21
+ LLM_CASES = {}
22
+ if LLM_MANIFEST.exists():
23
+ manifest = json.loads(LLM_MANIFEST.read_text(encoding="utf-8"))
24
+ for item in manifest:
25
+ display = item.get("title") or item.get("case_id")
26
+ file_name = item.get("file")
27
+ if display and file_name:
28
+ LLM_CASES[display] = LLM_INPUT_DIR / file_name
29
+
30
+ PATCHAD_CASES = {}
31
+ if PATCHAD_CASE_MANIFEST.exists():
32
+ manifest = json.loads(PATCHAD_CASE_MANIFEST.read_text(encoding="utf-8"))
33
+ for item in manifest:
34
+ display = item.get("title")
35
+ file_name = item.get("file")
36
+ if display and file_name:
37
+ PATCHAD_CASES[display] = PATCHAD_CASE_DIR / file_name
38
+
39
+ formatter_spec = importlib.util.spec_from_file_location(
40
+ "formatter", BASE_DIR / "utils" / "formatter.py"
41
+ )
42
+ formatter_module = importlib.util.module_from_spec(formatter_spec)
43
+ formatter_spec.loader.exec_module(formatter_module)
44
+ AnomalyFormatter = formatter_module.AnomalyFormatter
45
+
46
+ detector = WearableAnomalyDetector(model_dir=MODEL_DIR, device="cpu")
47
+ formatter = AnomalyFormatter()
48
+
49
+
50
+ def load_sample(path: Path):
51
+ if path.suffix == ".jsonl":
52
+ data = [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
53
+ else:
54
+ data = json.loads(path.read_text(encoding="utf-8"))
55
+ if isinstance(data, dict):
56
+ data = data.get("records") or data.get("data") or [data]
57
+ if not isinstance(data, list):
58
+ raise ValueError("样例文件需包含数组")
59
+ return data
60
+
61
+
62
+ def infer(sample_name: str):
63
+ window = load_sample(SAMPLES[sample_name])
64
+ realtime = detector.detect_realtime(window, update_baseline=False, return_details=True)
65
+ baseline_info = {
66
+ "baseline_mean": window[0]["features"].get("baseline_hrv_mean", 75.0),
67
+ "baseline_std": window[0]["features"].get("baseline_hrv_std", 5.0),
68
+ "current_value": sum(pt["features"].get("hrv_rmssd", 0) for pt in window) / len(window),
69
+ "deviation_pct": 0.0,
70
+ }
71
+ text = formatter.format_for_llm(realtime, baseline_info=baseline_info)
72
+ return json.dumps(realtime, ensure_ascii=False, indent=2), text
73
+
74
+
75
+ def show_llm_input(case_name: str):
76
+ if not LLM_CASES:
77
+ return {}, "当前未提供LLM输入示例。"
78
+ path = LLM_CASES[case_name]
79
+ data = json.loads(path.read_text(encoding="utf-8"))
80
+ messages = data.get("messages", [])
81
+ system_text = ""
82
+ user_text = ""
83
+ if messages:
84
+ system_msg = next((m for m in messages if m.get("role") == "system"), {})
85
+ user_msg = next((m for m in messages if m.get("role") == "user"), {})
86
+ system_text = system_msg.get("content", "")
87
+ user_text = user_msg.get("content", "")
88
+ display_text = "## 系统提示\n"
89
+ display_text += (system_text or "(无)")
90
+ display_text += "\n\n---\n\n## 用户输入(Markdown)\n"
91
+ display_text += (user_text or "(无)")
92
+ return data, display_text
93
+
94
+
95
+ def show_patchad_case(case_name: str):
96
+ if not PATCHAD_CASES:
97
+ return {}, {}, "当前未提供 PatchTrAD 示例。"
98
+ path = PATCHAD_CASES[case_name]
99
+ data = json.loads(path.read_text(encoding="utf-8"))
100
+ bundle = data.get("case_bundle", {})
101
+ summary = {
102
+ "sample": data.get("sample"),
103
+ "mode": data.get("mode"),
104
+ "precheck": data.get("precheck"),
105
+ "validation": bundle.get("validation"),
106
+ }
107
+ case_json = bundle.get("case", {})
108
+ llm_text = bundle.get("llm_input", "(无)")
109
+ return summary, case_json, llm_text
110
+
111
+
112
+ realtime_demo = gr.Interface(
113
+ fn=infer,
114
+ inputs=gr.Dropdown(choices=list(SAMPLES.keys()), value="示例: 正常", label="选择测试数据"),
115
+ outputs=[
116
+ gr.JSON(label="模型输出"),
117
+ gr.Markdown(label="LLM 文本"),
118
+ ],
119
+ title="Wearable Anomaly Detector Demo",
120
+ description="选择预置数据(正常/短期异常/长期异常/缺失数据)即可查看时序模型输出及格式化LLM文本。",
121
+ )
122
+
123
+ llm_input_demo = gr.Interface(
124
+ fn=show_llm_input,
125
+ inputs=gr.Dropdown(
126
+ choices=list(LLM_CASES.keys()) or ["暂无示例"],
127
+ value=list(LLM_CASES.keys())[0] if LLM_CASES else "暂无示例",
128
+ label="选择LLM输入示例",
129
+ ),
130
+ outputs=[
131
+ gr.JSON(label="messages JSON"),
132
+ gr.Markdown(label="完整用户输入"),
133
+ ],
134
+ title="标准化LLM输入示例",
135
+ description="直接展示系统提示+用户输入的完整Markdown,便于在Hugging Face页面选择典型案例查看。",
136
+ )
137
+
138
+ demo = gr.TabbedInterface(
139
+ [
140
+ realtime_demo,
141
+ llm_input_demo,
142
+ gr.Interface(
143
+ fn=show_patchad_case,
144
+ inputs=gr.Dropdown(
145
+ choices=list(PATCHAD_CASES.keys()) or ["暂无示例"],
146
+ value=list(PATCHAD_CASES.keys())[0] if PATCHAD_CASES else "暂无示例",
147
+ label="选择 PatchTrAD 案例",
148
+ ),
149
+ outputs=[
150
+ gr.JSON(label="摘要 / 预筛信息"),
151
+ gr.JSON(label="Case JSON"),
152
+ gr.Markdown(label="LLM 输入"),
153
+ ],
154
+ title="PatchTrAD + build_case 案例",
155
+ description="从预置示例中选择模式A/B与不同数据样本,查看完整 case 与 LLM 输入。",
156
+ ),
157
+ ],
158
+ ["实时窗口检测", "LLM输入示例", "PatchTrAD案例"],
159
+ )
160
+
161
+ if __name__ == "__main__":
162
+ demo.launch()
163
+