File size: 11,031 Bytes
4d2821f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python3
"""
Patch OpenEnv web_interface.py to add:
- Loss/perplexity chart and updateLossChart()
- POST /web/run-baseline and GET /web/current-task for baseline comparison
Idempotent: safe to run multiple times.
"""
import sys
from pathlib import Path


def _apply_routes_patch(text: str) -> str:
    """Add /web/run-baseline and /web/current-task routes."""
    old_routes = (
        '    @app.get("/web/state")\n'
        "    async def web_state():\n"
        '        """State endpoint for web interface."""\n'
        "        return web_manager.get_state()\n"
        "\n"
        "    return app"
    )
    new_routes = (
        '    @app.get("/web/state")\n'
        "    async def web_state():\n"
        '        """State endpoint for web interface."""\n'
        "        return web_manager.get_state()\n"
        "\n"
        '    @app.get("/web/current-task")\n'
        "    async def web_current_task():\n"
        '        """Current task spec for baseline comparison (if env supports it)."""\n'
        "        get_spec = getattr(web_manager.env, \"get_current_task_spec\", None)\n"
        "        if get_spec is None:\n"
        "            return {}\n"
        "        return get_spec() or {}\n"
        "\n"
        '    @app.post("/web/run-baseline")\n'
        "    async def web_run_baseline():\n"
        '        """Run baseline optimizer for current task; returns loss_trajectory and steps."""\n'
        "        run_bl = getattr(web_manager.env, \"run_baseline\", None)\n"
        "        if run_bl is None:\n"
        "            return {\"loss_trajectory\": [], \"steps\": [], \"error\": \"Env has no run_baseline\"}\n"
        "        return run_bl()\n"
        "\n"
        "    return app"
    )
    if "web/run-baseline" not in text and "web/state" in text and "return web_manager.get_state()" in text:
        text = text.replace(old_routes, new_routes, 1)
    return text


# runBaseline() front-end method, shared by the fresh-install path and the
# "updateLossChart exists but runBaseline is missing" upgrade path. The braces
# are doubled because the target HTML lives inside a Python f-string.
_RUN_BASELINE_JS = """            async runBaseline() {{
                const btn = document.getElementById('run-baseline-btn');
                if (btn) btn.disabled = true;
                try {{
                    const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
                    if (!r.ok) {{
                        console.error('Run baseline failed: HTTP ' + r.status);
                        return;
                    }}
                    const data = await r.json();
                    if (data.error) {{
                        console.error('Run baseline failed:', data.error);
                        return;
                    }}
                    if (!Array.isArray(data.loss_trajectory) || !this._lossChart) return;
                    const chart = this._lossChart;
                    // Replace a baseline from an earlier click instead of stacking duplicates.
                    chart.data.datasets = chart.data.datasets.filter(ds => ds.label !== 'Baseline (AdamW)');
                    const newLen = Math.max(chart.data.labels.length, data.loss_trajectory.length);
                    chart.data.labels = Array.from({{ length: newLen }}, (_, i) => i);
                    chart.data.datasets.forEach(ds => {{
                        while (ds.data.length < newLen) ds.data.push(null);
                    }});
                    const baselineData = data.loss_trajectory.slice();
                    while (baselineData.length < newLen) baselineData.push(null);
                    chart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
                    chart.update();
                }} catch (err) {{
                    console.error('Run baseline failed:', err);
                }} finally {{
                    if (btn) btn.disabled = false;
                }}
            }}"""


def _patch_head(text: str) -> str:
    """Load Chart.js from the jsDelivr CDN right after the page <title>."""
    chart_script = '    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>\n'
    if chart_script in text:
        return text
    old_head = "<title>OpenEnv Web Interface</title>\n    <style>"
    new_head = "<title>OpenEnv Web Interface</title>\n" + chart_script + "    <style>"
    return text.replace(old_head, new_head, 1)


def _patch_chart_panel(text: str) -> str:
    """Insert the loss-chart panel (canvas + baseline button) before Action Logs."""
    old_section = """                </div>
                </div>

                <!-- Action Logs -->
                <div class="logs-container">"""
    new_section = """                </div>
                </div>

                <!-- Loss chart -->
                <div class="state-display">
                    <h3>Loss / Perplexity</h3>
                    <div id="loss-chart-container" style="height:200px;"><canvas id="loss-chart"></canvas></div>
                    <button type="button" id="run-baseline-btn" class="btn btn-secondary" style="margin-top:8px;">Run baseline (AdamW)</button>
                </div>

                <!-- Action Logs -->
                <div class="logs-container">"""
    if "loss-chart-container" not in text:
        text = text.replace(old_section, new_section, 1)
    # Upgrade path: a panel from an older patch without the baseline button.
    if "loss-chart-container" in text and "run-baseline-btn" not in text:
        text = text.replace(
            '<canvas id="loss-chart"></canvas></div>\n                </div>',
            '<canvas id="loss-chart"></canvas></div>\n'
            '                    <button type="button" id="run-baseline-btn" class="btn btn-secondary" style="margin-top:8px;">Run baseline (AdamW)</button>\n'
            "                </div>",
            1,
        )
    return text


def _patch_chart_methods(text: str) -> str:
    """Add updateLossChart()/runBaseline() before updateChatInterface().

    Also appends a ``this.updateLossChart(episodeState)`` call at the end of
    the method that precedes ``updateChatInterface``.
    """
    old_update = """                }}
            }}

            updateChatInterface(episodeState) {{"""
    new_update = (
        """                }}
                this.updateLossChart(episodeState);
            }}

            updateLossChart(episodeState) {{
                const container = document.getElementById('loss-chart-container');
                if (!container) return;
                // One point per step: logged observations plus the current one,
                // de-duplicated by step and plotted in ascending step order.
                const byStep = new Map();
                const addPoint = (obs, fallbackStep) => {{
                    if (!obs || typeof obs.loss !== 'number') return;
                    const step = obs.step_count != null ? obs.step_count : (fallbackStep != null ? fallbackStep : 0);
                    if (byStep.has(step)) return;
                    byStep.set(step, {{
                        step: step,
                        loss: obs.loss,
                        // null keeps the perplexity series index-aligned with the labels
                        perplexity: typeof obs.perplexity === 'number' ? obs.perplexity : null,
                    }});
                }};
                (episodeState.action_logs || []).forEach(log => addPoint(log.observation, log.step_count));
                addPoint(episodeState.current_observation, 0);
                if (byStep.size === 0) return;
                const ctx = document.getElementById('loss-chart');
                if (!ctx) return;
                const points = Array.from(byStep.values()).sort((a, b) => a.step - b.step);
                const datasets = [
                    {{ label: 'Loss', data: points.map(p => p.loss), borderColor: '#007bff', tension: 0.2, fill: false }}
                ];
                const perplexities = points.map(p => p.perplexity);
                if (perplexities.some(p => p !== null)) {{
                    datasets.push({{ label: 'Perplexity', data: perplexities, borderColor: '#28a745', tension: 0.2, fill: false }});
                }}
                if (this._lossChart) this._lossChart.destroy();
                this._lossChart = new Chart(ctx, {{
                    type: 'line',
                    data: {{ labels: points.map(p => p.step), datasets: datasets }},
                    options: {{ responsive: true, maintainAspectRatio: false, scales: {{ x: {{ title: {{ display: true, text: 'Step' }} }} }} }}
                }});
            }}

"""
        + _RUN_BASELINE_JS
        + """

            updateChatInterface(episodeState) {{"""
    )
    if "updateLossChart(episodeState)" not in text:
        text = text.replace(old_update, new_update, 1)

    # Upgrade path: updateLossChart from an older patch, but no runBaseline.
    if "updateLossChart(episodeState)" in text and "async runBaseline()" not in text:
        text = text.replace(
            "                }});\n            }}\n\n            updateChatInterface(episodeState) {{",
            "                }});\n            }}\n\n"
            + _RUN_BASELINE_JS
            + "\n\n            updateChatInterface(episodeState) {{",
            1,
        )
    return text


def _patch_baseline_listener(text: str) -> str:
    """Wire the Run-baseline button's click handler next to the State button's."""
    if "runBaselineBtn.addEventListener" in text:
        return text
    old_listener = """                // State button
                document.getElementById('state-btn').addEventListener('click', () => {{
                    this.getState();
                }});
            }}"""
    new_listener = """                // State button
                document.getElementById('state-btn').addEventListener('click', () => {{
                    this.getState();
                }});

                const runBaselineBtn = document.getElementById('run-baseline-btn');
                if (runBaselineBtn) runBaselineBtn.addEventListener('click', () => this.runBaseline());
            }}"""
    return text.replace(old_listener, new_listener, 1)


def _missing_features(text: str) -> list:
    """Return the names of features whose marker text is absent from *text*."""
    markers = {
        "Chart.js <script> tag": '<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>',
        "loss chart panel": 'id="loss-chart-container"',
        "run-baseline button": 'id="run-baseline-btn"',
        "updateLossChart() method": "updateLossChart(episodeState)",
        "runBaseline() method": "async runBaseline()",
        "run-baseline click listener": "runBaselineBtn.addEventListener",
        "/web/run-baseline route": '@app.post("/web/run-baseline")',
        "/web/current-task route": '@app.get("/web/current-task")',
    }
    return [name for name, marker in markers.items() if marker not in text]


def main() -> None:
    """Patch ``web_interface.py`` in place; safe to run repeatedly.

    The target is ``sys.argv[1]`` when given, otherwise the ``web_interface``
    module of the installed ``openenv`` package. Each step is guarded by a
    marker check. The file is rewritten only when something changed, and any
    feature whose anchor text could not be found is reported on stderr
    instead of being skipped silently.

    Exits with status 1 when the target file cannot be located.
    """
    if len(sys.argv) < 2:
        try:
            import openenv.core.env_server.web_interface as m
        except ImportError as exc:
            print(
                f"Cannot import openenv web_interface ({exc}); pass the file path as an argument.",
                file=sys.stderr,
            )
            sys.exit(1)
        path = Path(m.__file__).resolve()
    else:
        path = Path(sys.argv[1]).resolve()

    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    # Python source files are UTF-8 (PEP 3120); don't depend on the locale.
    original = path.read_text(encoding="utf-8")

    # Server routes go in before any JS: the injected JS contains the URL
    # '/web/run-baseline', which a text-based "already patched" check for the
    # routes could otherwise mistake for the routes themselves.
    text = _apply_routes_patch(original)
    text = _patch_head(text)
    text = _patch_chart_panel(text)
    text = _patch_chart_methods(text)
    text = _patch_baseline_listener(text)

    if text != original:
        path.write_text(text, encoding="utf-8")
        print("Patched (chart + run-baseline):", path)
    else:
        print("No changes needed:", path)

    missing = _missing_features(text)
    if missing:
        print(
            "Warning: could not apply (anchor text not found): " + ", ".join(missing),
            file=sys.stderr,
        )


if __name__ == "__main__":
    # Run the patcher when executed as a script; main() returns None, so the
    # process exits with status 0 unless main() itself calls sys.exit().
    raise SystemExit(main())