| |
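"""
Patch OpenEnv web_interface.py to add:
- a Chart.js loss/perplexity chart and its updateLossChart() method;
- POST /web/run-baseline and GET /web/current-task routes for baseline comparison.

Usage: python <this script> [path/to/web_interface.py]
(without an argument, the web_interface module of the installed openenv package is patched).
Idempotent: safe to run multiple times.
"""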
| import sys |
| from pathlib import Path |
|
|
|
|
| def _apply_routes_patch(text: str) -> str: |
| """Add /web/run-baseline and /web/current-task routes.""" |
| old_routes = ( |
| ' @app.get("/web/state")\n' |
| " async def web_state():\n" |
| ' """State endpoint for web interface."""\n' |
| " return web_manager.get_state()\n" |
| "\n" |
| " return app" |
| ) |
| new_routes = ( |
| ' @app.get("/web/state")\n' |
| " async def web_state():\n" |
| ' """State endpoint for web interface."""\n' |
| " return web_manager.get_state()\n" |
| "\n" |
| ' @app.get("/web/current-task")\n' |
| " async def web_current_task():\n" |
| ' """Current task spec for baseline comparison (if env supports it)."""\n' |
| " get_spec = getattr(web_manager.env, \"get_current_task_spec\", None)\n" |
| " if get_spec is None:\n" |
| " return {}\n" |
| " return get_spec() or {}\n" |
| "\n" |
| ' @app.post("/web/run-baseline")\n' |
| " async def web_run_baseline():\n" |
| ' """Run baseline optimizer for current task; returns loss_trajectory and steps."""\n' |
| " run_bl = getattr(web_manager.env, \"run_baseline\", None)\n" |
| " if run_bl is None:\n" |
| " return {\"loss_trajectory\": [], \"steps\": [], \"error\": \"Env has no run_baseline\"}\n" |
| " return run_bl()\n" |
| "\n" |
| " return app" |
| ) |
| if "web/run-baseline" not in text and "web/state" in text and "return web_manager.get_state()" in text: |
| text = text.replace(old_routes, new_routes, 1) |
| return text |
|
|
|
|
# HTML/JS snippets spliced into web_interface.py. The ``_OLD_*`` anchors must
# match the target file verbatim. Braces in the JS are doubled because the
# target page template appears to be a Python f-string (assumption: verify
# against web_interface.py).

_CHART_SCRIPT = ' <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>\n'
_OLD_HEAD = "<title>OpenEnv Web Interface</title>\n <style>"

_BASELINE_BUTTON = (
    '<button type="button" id="run-baseline-btn" class="btn btn-secondary"'
    ' style="margin-top:8px;">Run baseline (AdamW)</button>'
)

_OLD_SECTION = """ </div>
</div>

<!-- Action Logs -->
<div class="logs-container">"""

_NEW_SECTION = """ </div>
</div>

<!-- Loss chart -->
<div class="state-display">
<h3>Loss / Perplexity</h3>
<div id="loss-chart-container" style="height:200px;"><canvas id="loss-chart"></canvas></div>
<button type="button" id="run-baseline-btn" class="btn btn-secondary" style="margin-top:8px;">Run baseline (AdamW)</button>
</div>

<!-- Action Logs -->
<div class="logs-container">"""

_LOSS_CHART_JS = """updateLossChart(episodeState) {{
const container = document.getElementById('loss-chart-container');
if (!container) return;
const steps = [];
const losses = [];
const perplexities = [];
// Keep all three arrays the same length: a point without perplexity gets null.
const addPoint = (step, o) => {{
steps.push(step);
losses.push(o.loss);
perplexities.push(typeof o.perplexity === 'number' ? o.perplexity : null);
}};
// Logged steps first (assumes action_logs is chronological) so the x-axis is in step order.
(episodeState.action_logs || []).forEach(log => {{
if (log.observation && typeof log.observation.loss === 'number') {{
addPoint(log.observation.step_count != null ? log.observation.step_count : log.step_count, log.observation);
}}
}});
// The current observation may duplicate the last logged step; add it only if its step is new.
const cur = episodeState.current_observation;
if (cur && typeof cur.loss === 'number') {{
const curStep = cur.step_count != null ? cur.step_count : 0;
if (steps.length === 0 || steps[steps.length - 1] !== curStep) addPoint(curStep, cur);
}}
if (steps.length === 0) return;
const ctx = document.getElementById('loss-chart');
if (!ctx) return;
if (this._lossChart) this._lossChart.destroy();
const datasets = [{{ label: 'Loss', data: losses, borderColor: '#007bff', tension: 0.2, fill: false }}];
if (perplexities.some(p => p !== null)) {{
datasets.push({{ label: 'Perplexity', data: perplexities, borderColor: '#28a745', tension: 0.2, fill: false }});
}}
this._lossChart = new Chart(ctx, {{
type: 'line',
data: {{ labels: steps, datasets: datasets }},
options: {{ responsive: true, maintainAspectRatio: false, scales: {{ x: {{ title: {{ display: true, text: 'Step' }} }} }} }}
}});
// Redraw a previously fetched baseline, which would otherwise vanish with the old chart.
this._addBaselineDataset();
}}
"""

_BASELINE_JS = """async runBaseline() {{
const btn = document.getElementById('run-baseline-btn');
if (btn) btn.disabled = true;
try {{
const r = await fetch('/web/run-baseline', {{ method: 'POST' }});
if (!r.ok) return;
const data = await r.json();
if (data.error || !Array.isArray(data.loss_trajectory)) return;
// Keep the trajectory so the chart can be redrawn with it after a rebuild.
this._baselineLoss = data.loss_trajectory.slice();
this._addBaselineDataset();
}} finally {{
if (btn) btn.disabled = false;
}}
}}

_addBaselineDataset() {{
const chart = this._lossChart;
if (!chart || !this._baselineLoss) return;
// Replace (rather than stack) any earlier baseline curve.
chart.data.datasets = chart.data.datasets.filter(ds => ds.label !== 'Baseline (AdamW)');
// Curves are aligned by index, so the x-axis is relabelled 0..n-1.
const newLen = Math.max(chart.data.labels.length, this._baselineLoss.length);
chart.data.labels = Array.from({{ length: newLen }}, (_, i) => i);
chart.data.datasets.forEach(ds => {{
while (ds.data.length < newLen) ds.data.push(null);
}});
const baselineData = this._baselineLoss.slice();
while (baselineData.length < newLen) baselineData.push(null);
chart.data.datasets.push({{ label: 'Baseline (AdamW)', data: baselineData, borderColor: '#dc3545', tension: 0.2, fill: false }});
chart.update();
}}
"""

_OLD_UPDATE = """ }}
}}

updateChatInterface(episodeState) {{"""

_NEW_UPDATE = (
    " }}\nthis.updateLossChart(episodeState);\n}}\n\n"
    + _LOSS_CHART_JS
    + "\n"
    + _BASELINE_JS
    + "\nupdateChatInterface(episodeState) {{"
)

# End of updateLossChart() in a page that has the chart but not runBaseline().
_OLD_UPDATE_FALLBACK = " }});\n }}\n\n updateChatInterface(episodeState) {{"

_OLD_LISTENER = """ // State button
document.getElementById('state-btn').addEventListener('click', () => {{
this.getState();
}});
}}"""

_NEW_LISTENER = """ // State button
document.getElementById('state-btn').addEventListener('click', () => {{
this.getState();
}});

const runBaselineBtn = document.getElementById('run-baseline-btn');
if (runBaselineBtn) runBaselineBtn.addEventListener('click', () => this.runBaseline());
}}"""


def _apply_head_patch(text: str) -> str:
    """Load Chart.js by inserting its CDN <script> tag between <title> and <style>."""
    if _CHART_SCRIPT in text or _OLD_HEAD not in text:
        return text
    new_head = "<title>OpenEnv Web Interface</title>\n" + _CHART_SCRIPT + " <style>"
    return text.replace(_OLD_HEAD, new_head, 1)


def _apply_section_patch(text: str) -> str:
    """Add the loss-chart panel and its baseline button above the action logs."""
    if "loss-chart-container" not in text and _OLD_SECTION in text:
        text = text.replace(_OLD_SECTION, _NEW_SECTION, 1)
    # Handles a page that already has the chart but no baseline button
    # (presumably produced by an earlier version of this patch).
    if "loss-chart-container" in text and "run-baseline-btn" not in text:
        text = text.replace(
            '<canvas id="loss-chart"></canvas></div>\n </div>',
            '<canvas id="loss-chart"></canvas></div>\n ' + _BASELINE_BUTTON + "\n </div>",
            1,
        )
    return text


def _apply_js_patch(text: str) -> str:
    """Add the updateLossChart()/runBaseline() methods and hook the former into the UI update."""
    if "updateLossChart(episodeState)" not in text and _OLD_UPDATE in text:
        text = text.replace(_OLD_UPDATE, _NEW_UPDATE, 1)
    return text


def _apply_listener_patch(text: str) -> str:
    """Wire the "Run baseline" button's click handler to runBaseline()."""
    if "runBaselineBtn.addEventListener" not in text and _OLD_LISTENER in text:
        text = text.replace(_OLD_LISTENER, _NEW_LISTENER, 1)
    return text


def _apply_baseline_fallback_patch(text: str) -> str:
    """Add runBaseline() to a page that already has updateLossChart() but lacks it."""
    if "updateLossChart(episodeState)" in text and "async runBaseline()" not in text:
        text = text.replace(
            _OLD_UPDATE_FALLBACK,
            " }});\n }}\n\n" + _BASELINE_JS + "\n updateChatInterface(episodeState) {{",
            1,
        )
    return text


def main() -> None:
    """Patch OpenEnv's web_interface.py in place.

    The target is ``sys.argv[1]`` when given, otherwise the ``web_interface``
    module of the installed ``openenv`` package. Each patch step is skipped if
    it is already applied or its anchor is missing, so re-running is safe. The
    file is only rewritten when at least one step changed it.

    Exits with status 1 if the target cannot be imported or is not a file.
    """
    if len(sys.argv) < 2:
        try:
            import openenv.core.env_server.web_interface as m
        except ImportError as err:
            print(f"Cannot import openenv web_interface: {err}", file=sys.stderr)
            sys.exit(1)
        path = Path(m.__file__).resolve()
    else:
        path = Path(sys.argv[1]).resolve()

    if not path.is_file():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    # Python sources are UTF-8 (PEP 3120); don't depend on the locale encoding.
    original = path.read_text(encoding="utf-8")
    text = original
    for patch in (
        _apply_head_patch,
        _apply_section_patch,
        _apply_js_patch,
        _apply_listener_patch,
        _apply_baseline_fallback_patch,
        _apply_routes_patch,
    ):
        text = patch(text)

    if text == original:
        print("Nothing to patch (already patched or anchors not found):", path)
        return
    path.write_text(text, encoding="utf-8")
    print("Patched (chart + run-baseline):", path)
|
|
|
|
# Run the patcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|