diff --git a/.claude/commands/stream.md b/.claude/commands/stream.md
new file mode 100644
index 0000000..a883a86
--- /dev/null
+++ b/.claude/commands/stream.md
@@ -0,0 +1,48 @@
+Read `.claude/plans/STREAMS.md` and help the user manage work streams.
+
+**Behavior based on args:**
+
+- **No args** (`/stream`): Show the roadmap table from STREAMS.md. Ask which stream to attach to, or offer to create a new one.
+
+- **Stream ID** (`/stream fix-error-handling`): Attach to that stream. Read its section, show the plan + last log entry, and start working on the next unchecked item.
+
+- **`new`** (`/stream new`): Ask for: ID, type, priority, summary, goal. Add a section + table row to STREAMS.md.
+
+- **`branch <name>`** (`/stream branch feature-x`): Import an existing git branch as a stream.
+  1. Run `git log --oneline main..` (or `master..`) to see commits
+  2. Run `git diff --stat main..` to see affected files
+  3. Propose a stream ID, type, and summary based on the branch content
+  4. Ask the user to confirm or adjust
+  5. Create the stream section in STREAMS.md with the branch noted, committed work as checked items, and uncommitted/remaining work as unchecked
+  6. If the branch has uncommitted changes (check `git status`), note those in the plan too
+
+- **`status`** (`/stream status`): Show just the table with current statuses.
+
+- **`done`** (`/stream done`): Mark the current stream as `done`. Add a dated log entry.
+
+- **`release`** (`/stream release`): Set the current stream to `planned` (pausing). Log what's left.
+
+**When attaching to a stream:**
+1. Set its status to `active` in the table row
+2. Note the current git branch in the Branch column
+3. Read the stream's plan section
+4. Tell the user what's done and what's next
+5. Start working on the first unchecked `- [ ]` item
+
+**When finishing work (session ending or switching streams):**
+1. Update checkboxes for completed items
+2. Append a dated log line summarizing what was done and what's next
+3. Update `STREAMS.md` — this is the only file that matters
+
+**New stream section format:**
+```
+---
+
+## <id>
+**Goal:** <goal>
+**Branch:** <branch> (if from `/stream branch`)
+**Files:** <files>
+- [x] Already done thing (from git log)
+- [ ] Remaining thing
+**Log:** <date> — Created from branch <name>.
+```
diff --git a/.claude/plans/STREAMS.md b/.claude/plans/STREAMS.md
new file mode 100644
index 0000000..409d69a
--- /dev/null
+++ b/.claude/plans/STREAMS.md
@@ -0,0 +1,114 @@
+# Streams
+
+> **Protocol:** claim a stream (`status → active`, note your branch), work,
+> update checkboxes + log, release when done (`→ done`) or paused (`→ planned`).
+> New stream: add a section + row to the table. Use `/stream` to browse/attach.
+> Use `/stream branch <name>` to import an existing git branch as a stream.
+ +| ID | Type | Pri | Status | Branch | Summary | +|----|------|-----|--------|--------|---------| +| v2-stabilization | chore | p0 | active | claude_skill | Post-2.0 stabilization: maintenance fixes, demo restructure, dashboard UX, docs, coordination | +| fix-error-handling | fix | p2 | planned | — | Replace bare `except: pass` with specific types + logging | +| fix-api-consistency | refactor | p3 | planned | — | Standardize REST responses to `{status, data?, error?}` | +| fix-adapter-imports | fix | p3 | planned | — | Gate lightning/hf imports with friendly ImportError | +| feature-test-coverage | test | p1 | planned | — | Integration tests for demos, launcher, dashboard E2E | +| docs-examples-refresh | docs | p2 | planned | — | Verify examples, add notebooks | +| chore-release-prep | chore | p1 | planned | — | sdist/wheel validation, changelog, PyPI publish | + +Dependencies: `chore-release-prep` blocks on `v2-stabilization` + `feature-test-coverage` + +--- + +## v2-stabilization +**Goal:** Full post-2.0 stabilization pass — audit, fix, restructure, document. This is the `claude_skill` branch. +**Branch:** `claude_skill` (diverged from `master`, 1 commit + large uncommitted working tree) +**Scope:** +- MAINTENANCE.md P0-P2 audit and fixes (28 items fixed) +- Demo restructuring to HotKernel integration path +- Dashboard UX (controls hydration, stale data, run dir backup) +- Docs cleanup (INTEGRATION.md, concepts.md, custom_training_configs.md) +- Claude Code skill (`.claude/skills/hotcb-autopilot/`) +- Multi-agent coordination system (`.claude/plans/`, `/stream` command) + +**Done:** +- [x] MAINTENANCE.md audit — all P0 user-reported (1.1-1.6) +- [x] Frontend P1 fixes (API error handling, WS backoff, listener leaks, Three.js, tooltip colors, forecast polling, intervals) +- [x] Backend P1 fixes (malformed JSON, JSONL locking, FreezeState validation, duplicate imports, deps) +- [x] Packaging P2 (MANIFEST.in, py.typed) +- [x] Accessibility P2 (focus styles, ARIA labels) +- [x] Demo restructuring — 3 demos rewritten to HotKernel + MetricsCollector + actuators +- [x] Docs — fixed mc.log() refs, updated framework examples, deleted legacy examples/ +- [x] NaN/inf, [object Object], chart waiting, staged knob highlights, sys import +- [x] Claude Code skill (SKILL.md with 5-phase autopilot protocol) +- [x] Dashboard slider sync from WS metrics, `_slidersInitialized` +- [x] Launcher run dir backup (`_backup_run_dir_if_needed`) +- [x] Launcher JSONL truncation on start (was skip-if-exists) +- [x] `weight_decay` added to demo metrics +- [x] Multi-agent coordination (STREAMS.md + /stream command) + +**Remaining:** +- [ ] Manual verify: sliders sync on each demo config +- [ ] Manual verify: no stale timeline on restart +- [ ] Manual verify: backup dir created on re-run +- [ ] Commit all working tree changes +- [ ] PR to main + +**Log:** +- 2026-03-12: MAINTENANCE.md audit, P0-P2 fixes, demo restructuring, docs cleanup +- 2026-03-13: Dashboard UX fixes (slider sync, stale data, backup). Plans system created. 754 tests pass. + +--- + +## fix-error-handling +**Goal:** Replace silent `except Exception: pass` with specific types + `log.warning()`. +**Files:** `cli.py`, `kernel.py`, `recipe.py` (audit for others) +- [ ] Audit all bare-except sites +- [ ] Replace with specific exceptions + logging +- [ ] Verify tests pass + +--- + +## fix-api-consistency +**Goal:** Unify REST responses to `{status, data?, error?}`. Breaking change — needs frontend + SKILL.md updates. 
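+
+A hypothetical pair of envelope instances, just to pin down the target shape — the
+field names come from the goal above, the values are invented (the real schema is
+decided in "Design envelope schema" below):
+
+```python
+# Illustrative only — not the final schema.
+success = {"status": "ok", "data": {"lr": 3e-4}}         # data present on success
+failure = {"status": "error", "error": "unknown_param"}  # error present on failure
+```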
+**Files:** `api.py`, `utils.js`, `controls.js`, `init.js`, `panels.js`, `SKILL.md`, `INTEGRATION.md` +- [ ] Catalog current response shapes +- [ ] Design envelope schema +- [ ] Update backend + frontend + docs + +--- + +## fix-adapter-imports +**Goal:** Friendly error when `pytorch_lightning` / `transformers` not installed. +**Files:** `adapters/lightning.py`, `adapters/hf.py` +- [ ] Wrap imports in try/except with install instructions +- [ ] Add test for the friendly error message + +--- + +## feature-test-coverage +**Goal:** Integration tests for demo→launcher→dashboard→stop cycle. Currently 754 unit tests, zero integration. +**Files:** `src/hotcb/tests/` +- [ ] Test demo functions: run 10 steps, check metrics JSONL fields +- [ ] Test launcher lifecycle: start → status → stop → reset +- [ ] Test run dir backup: existing data → start → verify backup +- [ ] Test `/api/state/controls` returns correct values +- [ ] Test WS initial data burst + +--- + +## docs-examples-refresh +**Goal:** Verify `docs/examples/` match v2.0, add Jupyter notebooks. +**Files:** `docs/examples/*.py` +- [ ] Verify 3 existing examples run +- [ ] Create notebook using `launch()` API +- [ ] Create notebook for autopilot comparison + +--- + +## chore-release-prep +**Goal:** PyPI 2.0.0 publish. Depends on `v2-stabilization` + `feature-test-coverage`. +- [ ] Merge fix branches to main +- [ ] Build sdist + wheel, verify static files included +- [ ] Test install in fresh venv, run `hotcb demo` +- [ ] Write CHANGELOG.md +- [ ] Tag + publish diff --git a/.claude/plans/dapper-baking-spark.md b/.claude/plans/dapper-baking-spark.md new file mode 100644 index 0000000..5e43fbb --- /dev/null +++ b/.claude/plans/dapper-baking-spark.md @@ -0,0 +1,617 @@ +# Holistic Dashboard Stabilization — MutableState Redesign + +## Context + +The dashboard has drifted from hotcb's core principle: **the training framework is +sacred — hook onto it, never alter it.** The split into opt/loss/cb modules creates +hardcoded assumptions at every layer. External projects can't add custom controls +without fitting into one of 3 buckets. Ramps, terms, and custom signals get filtered +out at `kernel.py:263` because they don't match any module. + +**The fix:** Unify all mutable params into a single `MutableState` container with +per-param `HotcbActuator` instances. Each actuator has a user-provided `apply_fn`, +a metrics link for verification, and a state machine (INIT→UNTOUCHED→UNVERIFIED→ +VERIFIED→DISABLED). The user declares this on their model; the kernel discovers it +at first step; the dashboard generates controls from it. + +**Key decisions:** +- Existing opt/loss/cb modules become **thin wrappers** delegating to MutableState (backward compat) +- Verification window is **configurable per-actuator** (default 5 steps) +- Built-in `apply_fn` helpers provided for common cases (optimizer lr, wd, etc.) + +--- + +## Phase A: Foundation Bug Fixes + +Fix known bugs before any architectural changes. 
+ +### A1: FileCursor deduplication +- `src/hotcb/modules/cb/util.py:12-35` — stale copy missing `last_size`/`truncated` +- Replace with re-exports from `hotcb.util` + +### A2: `_applied_cache` stale after reset +- `app.py:659` — add size-0 early return + `_clear_applied_cache()` function +- `launcher.py` reset endpoint calls cache clear + +### A3: `_resolve_active_run_dir` unreliable mtime +- `app.py:727` — sort by `hotcb.metrics.jsonl` mtime, not dir mtime + +### A4: Forecast gated on pinned metrics +- `charts.js:1040` — auto-select first 3 metrics if none pinned + +**Tests first** (`src/hotcb/tests/test_server_stability.py` — NEW): +``` +test_cb_util_reexports_canonical_filecursor +test_read_new_jsonl_detects_truncation +test_applied_summary_empty_file_returns_empty +test_applied_summary_cache_invalidated +test_resolve_uses_metrics_file_mtime +test_resolve_skips_backup_dirs +test_resolve_direct_metrics_returns_root +test_forecast_endpoint_returns_data +``` + +**Files:** `modules/cb/util.py`, `server/app.py`, `server/launcher.py`, `static/js/charts.js` + +--- + +## Phase B: `MutableState` + `HotcbActuator` Core + +The heart of the redesign. Defines the unified control plane data model. + +### `HotcbActuator` — 1 param ↔ 1 actuator + +```python +# src/hotcb/mutable_state.py (NEW) + +class ActuatorState(str, Enum): + INIT = "init" # declared but kernel hasn't seen it + UNTOUCHED = "untouched" # kernel discovered, populated current_value + UNVERIFIED = "unverified" # mutation applied, waiting for metrics verification + VERIFIED = "verified" # metrics confirmed change took effect + DISABLED = "disabled" # verification failed — grayed out in UI + +@dataclass +class Mutation: + step: int + old_value: Any + new_value: Any + verified: bool = False + verify_deadline: int = 0 # step by which verification must happen + +@dataclass +class HotcbActuator: + param_key: str # unique name: "lr", "weight_a", "temperature" + type: str # "float", "int", "bool", "choice" + apply_fn: Callable # user-provided: actually mutates the value + metrics_dict_name: str # key in env["metrics"] for verification + current_value: Any = None # INIT or last known + default_value: Any = None # initial/reset value + last_changed: int = -1 # step of last verified change (-1 = never) + state: ActuatorState = ActuatorState.INIT + mutations: List[Mutation] = field(default_factory=list) + verification_window: int = 5 # steps to wait before DISABLED + bounds: tuple = None # (min, max) for float/int + choices: list = None # for "choice" type + scale: str = "linear" # "linear" or "log10" for UI slider + label: str = "" # display label (defaults to param_key) + group: str = "" # visual grouping hint + + def to_spec(self) -> dict: + """Serialize for /api/controls — dashboard reads this.""" + return { + "param_key": self.param_key, + "label": self.label or self.param_key, + "type": self.type, + "scale": self.scale, + "bounds": list(self.bounds) if self.bounds else None, + "choices": self.choices, + "current_value": self.current_value, + "default_value": self.default_value, + "state": self.state.value, + "group": self.group, + } +``` + +### `MutableState` container + +```python +class MutableState: + actuators: Dict[str, HotcbActuator] # keyed by param_key + + def init_datastructures(self, metrics: dict): + """Called by kernel at first step. 
Verify metrics links, populate values."""
+        for act in self.actuators.values():
+            if act.metrics_dict_name in metrics:
+                act.current_value = metrics[act.metrics_dict_name]
+                act.state = ActuatorState.UNTOUCHED
+            # If metrics_dict_name not found, stay INIT (warn but don't disable yet)
+
+    def change(self, param_key: str, new_value: Any, at_step: int):
+        """Apply mutation via actuator's apply_fn. Mark UNVERIFIED."""
+        act = self.actuators[param_key]
+        if act.state == ActuatorState.DISABLED:
+            raise ValueError(f"Actuator {param_key} is disabled")
+        old = act.current_value
+        act.apply_fn(new_value)
+        act.mutations.append(Mutation(
+            step=at_step, old_value=old, new_value=new_value,
+            verify_deadline=at_step + act.verification_window
+        ))
+        act.current_value = new_value
+        act.state = ActuatorState.UNVERIFIED
+
+    def verify_pending(self, metrics: dict, current_step: int):
+        """Called each step by kernel. Check metrics for verification."""
+        for act in self.actuators.values():
+            if act.state != ActuatorState.UNVERIFIED:
+                continue
+            pending = [m for m in act.mutations if not m.verified]
+            if not pending:
+                continue
+            latest = pending[-1]
+            metric_val = metrics.get(act.metrics_dict_name)
+            if metric_val is not None and metric_val != latest.old_value:
+                latest.verified = True
+                act.state = ActuatorState.VERIFIED
+                act.last_changed = current_step
+            elif current_step >= latest.verify_deadline:
+                act.state = ActuatorState.DISABLED
+
+    def snapshot(self) -> dict:
+        """For rollback / recipe export."""
+        return {k: act.current_value for k, act in self.actuators.items()}
+
+    def get_specs(self) -> list:
+        """For dashboard /api/controls."""
+        return [act.to_spec() for act in self.actuators.values()]
+```
+
+### Built-in `apply_fn` helpers
+
+```python
+# src/hotcb/mutable_state.py — convenience functions
+
+def apply_optimizer_lr(optimizer, param_group_idx=None):
+    """Returns an apply_fn for optimizer learning rate. Targets one param
+    group by index, or every param group when param_group_idx is None."""
+    def _apply(value):
+        groups = (optimizer.param_groups if param_group_idx is None
+                  else [optimizer.param_groups[param_group_idx]])
+        for pg in groups:
+            pg["lr"] = value
+    return _apply
+
+def apply_optimizer_wd(optimizer):
+    def _apply(value):
+        for pg in optimizer.param_groups:
+            pg["weight_decay"] = value
+    return _apply
+
+def apply_dict_key(target_dict, key):
+    """Returns an apply_fn for a dict key (e.g., loss weights)."""
+    def _apply(value):
+        target_dict[key] = value
+    return _apply
+
+def apply_attr(obj, attr_name):
+    """Returns an apply_fn for an object attribute."""
+    def _apply(value):
+        setattr(obj, attr_name, value)
+    return _apply
+```
+
+### User-facing factory
+
+```python
+def mutable_state(actuators: List[HotcbActuator]) -> MutableState:
+    """Create a MutableState from a list of actuators."""
+    ms = MutableState()
+    ms.actuators = {a.param_key: a for a in actuators}
+    return ms
+```
+
+### User integration example
+
+```python
+class MyModel(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        # ... model layers ...
+        self.mutable_state = hotcb.mutable_state([
+            HotcbActuator(
+                param_key="lr", type="float", scale="log10",
+                apply_fn=hotcb.apply_optimizer_lr(self.optimizer),
+                metrics_dict_name="lr",
+                bounds=(1e-7, 1.0), default_value=3e-4,
+            ),
+            HotcbActuator(
+                param_key="temperature", type="float",
+                apply_fn=hotcb.apply_attr(self, "temperature"),
+                metrics_dict_name="temperature",
+                bounds=(0.1, 10.0), default_value=1.0,
+                verification_window=10,  # slow metric, give more time
+            ),
+        ])
+```
+
+**Tests first** (`src/hotcb/tests/test_mutable_state.py` — NEW):
+```
+test_actuator_state_machine_init_to_untouched
+test_actuator_state_machine_untouched_to_unverified
+test_actuator_state_machine_verified_on_metric_change
+test_actuator_state_machine_disabled_on_timeout
+test_mutable_state_init_from_metrics
+test_mutable_state_change_calls_apply_fn
+test_mutable_state_verify_pending_marks_verified
+test_mutable_state_verify_pending_marks_disabled
+test_mutable_state_snapshot
+test_mutable_state_get_specs
+test_builtin_apply_optimizer_lr
+test_builtin_apply_optimizer_wd
+test_builtin_apply_dict_key
+test_builtin_apply_attr
+test_mutable_state_disabled_rejects_change
+test_actuator_custom_verification_window
+test_mutable_state_multiple_mutations_same_param
+```
+
+**Files:** `src/hotcb/mutable_state.py` (NEW)
+
+---
+
+## Phase C: Kernel Integration + Module Wrappers
+
+Wire `MutableState` into the kernel and make opt/loss/cb thin wrappers.
+
+### Kernel changes (`kernel.py`)
+
+1. **Discovery at first step:** In `apply()`, after the first `_should_poll()`, check
+   `env` for a `mutable_state` attribute on model/trainer objects. If found, store a
+   reference and call `mutable_state.init_datastructures(env.get("metrics", {}))`.
+
+2. **Default module for MutableState commands:** At `kernel.py:263`, instead of the
+   `unknown_module` error, route to MutableState if a param key matches a known actuator:
+   ```python
+   # After existing module lookup fails:
+   if self._mutable_state and op.params:
+       routed = False
+       for key in op.params:
+           if key in self._mutable_state.actuators:
+               # Route through MutableState — apply every matching param,
+               # not just the first (one command may carry several params)
+               self._mutable_state.change(key, op.params[key], step)
+               routed = True
+       if routed:
+           self._write_ledger(op, event, step, decision="applied", ...)
+           return
+   ```
+
+3. **Verification each step:** After applying ops, call
+   `self._mutable_state.verify_pending(env.get("metrics", {}), current_step)`
+
+4. **Controls endpoint data source:** Add `kernel.get_control_specs()` that returns
+   `self._mutable_state.get_specs()` for the dashboard.
+
+### Thin wrappers for backward compat
+
+**opt module** (`modules/opt.py`):
+- `apply_op()` still works for `{"module": "opt", "op": "set_params", "params": {"lr": 0.001}}`
+- Internally delegates: if MutableState has an actuator for "lr", calls
+  `mutable_state.change("lr", 0.001, step)` instead of direct optimizer mutation
+  (see the sketch after these bullets)
+- If no MutableState, falls back to the current direct optimizer mutation (legacy path)
+
+**loss module** (`modules/loss.py`):
+- Same pattern: delegates to MutableState for known keys, falls back to direct
+  mutable-state dict mutation for legacy projects
+
+**cb module** (`modules/cb/`):
+- Unchanged — callbacks are a different beast (load/unload/enable/disable).
+  But `set_params` can optionally delegate to MutableState if the callback
+  registered its params there.
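+
+A minimal sketch of the delegation shim described above, assuming the kernel
+injects the discovered `MutableState` into each module — the `_mutable_state`
+attribute and `_apply_direct` fallback are illustrative, not final API:
+
+```python
+class HotOptController:
+    """Thin wrapper: route known params through MutableState, else legacy path."""
+
+    def __init__(self):
+        self._mutable_state = None  # set by the kernel after discovery (assumed hook)
+
+    def apply_op(self, op, env, step):
+        for key, value in op.params.items():
+            ms = self._mutable_state
+            if ms is not None and key in ms.actuators:
+                ms.change(key, value, at_step=step)               # unified path
+            else:
+                self._apply_direct(env["optimizer"], key, value)  # legacy path
+
+    def _apply_direct(self, optimizer, key, value):
+        # Legacy fallback: mutate all optimizer param groups directly.
+        for pg in optimizer.param_groups:
+            pg[key] = value
+```
+
+The loss wrapper follows the same shape, with the legacy branch mutating the raw
+mutable-state dict instead of optimizer param groups.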
+ +### Command format + +Dashboard sends: `POST /api/controls/apply {"changes": {"lr": 0.001, "temperature": 0.5}}` + +Server writes to `hotcb.commands.jsonl`: +```json +{"module": "mutable", "op": "set_params", "params": {"lr": 0.001, "temperature": 0.5}} +``` + +Kernel routes to MutableState default module handler. + +**Tests first** (`src/hotcb/tests/test_kernel_mutable.py` — NEW): +``` +test_kernel_discovers_mutable_state_at_first_step +test_kernel_routes_mutable_command +test_kernel_verifies_pending_each_step +test_kernel_opt_wrapper_delegates_to_mutable_state +test_kernel_loss_wrapper_delegates_to_mutable_state +test_kernel_legacy_opt_without_mutable_state +test_kernel_legacy_loss_without_mutable_state +test_kernel_unknown_param_key_fails_gracefully +test_kernel_disabled_actuator_rejects_command +test_kernel_get_control_specs +``` + +**Files:** `kernel.py`, `modules/opt.py`, `modules/loss.py` + +--- + +## Phase D: DashboardConfig + `/api/config` + `/api/controls` + +### DashboardConfig (`src/hotcb/server/config.py` — NEW) + +```python +@dataclass +class ServerConfig: + host: str = "0.0.0.0" + port: int = 8421 + poll_interval: float = 0.5 + ws_initial_burst: int = 500 + +@dataclass +class ChartConfig: + max_render_points: int = 2000 + line_tension: float = 0.15 + +@dataclass +class UIConfig: + state_save_interval: int = 5000 + alert_poll_interval: int = 15000 + forecast_poll_interval: int = 10000 + forecast_step_cadence: int = 20 + staged_change_threshold: float = 0.005 + +@dataclass +class DashboardConfig: + server: ServerConfig + chart: ChartConfig + ui: UIConfig + run_dir: str = "" + demo_mode: bool = False +``` + +### API endpoints + +- `GET /api/config` — returns `DashboardConfig` as JSON (frontend stores in `S.config`) +- `GET /api/controls` — returns `mutable_state.get_specs()` (actuator list with types, bounds, state) +- `POST /api/controls/apply` — accepts `{"changes": {key: value}}`, writes to commands.jsonl + +**Tests first** (`src/hotcb/tests/test_dashboard_config.py` — NEW): +``` +test_dashboard_config_defaults +test_config_endpoint_returns_full +test_controls_endpoint_returns_actuator_specs +test_controls_endpoint_without_mutable_state (returns empty/defaults) +test_controls_apply_writes_command +test_controls_apply_diff_only (0.5% threshold) +``` + +**Files:** `server/config.py` (NEW), `server/app.py`, `server/api.py` + +--- + +## Phase E: Dynamic Control Generation (Frontend) + +Replace hardcoded knob HTML with type-based templates generated from `/api/controls`. + +### Templates per actuator type + +| Actuator type | UI Template | +|---------------|-------------| +| `float + log10` | Range slider with log transform + exponential display | +| `float + linear` | Range slider with direct value + decimal display | +| `int` | Range slider with integer step | +| `bool` | Toggle switch | +| `choice` | Select dropdown | + +### Changes + +**`index.html`:** +- Remove hardcoded knob rows (~lines 360-408) +- Keep: `
<div id="knobPanel">` + Apply/Schedule buttons
+- Remove CSS classes: `single-loss-only`, `multitask-only`, `finetune-only`
+
+**`controls.js`:**
+- Add `buildControlsFromSpecs(specs)` — iterates specs, calls `buildKnobRow(spec)`
+- Add `buildKnobRow(spec)` — creates slider/toggle/dropdown based on `spec.type`
+- Each row shows actuator state: normal=VERIFIED, spinner=UNVERIFIED, grayed=DISABLED
+- Apply handler: read all `[data-param]` inputs, diff against applied values,
+  POST to `/api/controls/apply` with changed params only
+- Remove `_trainConfigDefaults`, `_updateConfigControls()`, all module-specific routing
+- `demo_mode === false` → hide Training config dropdown row only (not entire card)
+
+**`init.js`:**
+- Fetch `S.config = await api('GET', '/api/config')` at startup
+- Fetch controls: `var specs = await api('GET', '/api/controls'); buildControlsFromSpecs(specs);`
+
+**`state.js`:**
+- Add `S.config = null`
+
+**Actuator state in UI:**
+- INIT: not shown (controls buffered until discovered)
+- UNTOUCHED: normal slider, label shows "ready"
+- UNVERIFIED: subtle pulse/spinner on the control row
+- VERIFIED: solid, normal appearance
+- DISABLED: grayed out, tooltip "Verification failed — param may not be mutable"
+
+**Tests (manual verification checklist):**
+```
+hotcb demo → lr, wd, loss_w generated dynamically, all VERIFIED
+hotcb demo --golden → lr, wd, weight_a, weight_b generated
+hotcb serve --dir <run_dir> → controls from MutableState (or defaults)
+Apply lr change → slider shows UNVERIFIED → VERIFIED after metric confirms
+Apply broken param → shows UNVERIFIED → DISABLED after timeout
+```
+
+**Files:** `static/index.html`, `static/js/controls.js`, `static/js/init.js`,
+`static/js/state.js`, `static/css/dashboard.css`
+
+---
+
+## Phase F: Launcher Simplification
+
+**Principle:** Train function accepts minimal args. `max_steps`/`step_delay` are demo concerns.
+
+Use `inspect.signature` to detect arity:
+- 0 args: `fn()` — fully external
+- 1 arg: `fn(stop_event)` — respects stop signal
+- 2 args: `fn(run_dir, stop_event)` — needs IPC dir
+- 4 args: `fn(run_dir, max_steps, step_delay, stop_event)` — demo contract
+
+**Tests first:**
+```
+test_launch_zero_arg_fn
+test_launch_one_arg_fn
+test_launch_two_arg_fn
+test_launch_four_arg_backward_compat
+```
+
+**Files:** `launch.py`, `server/launcher.py`
+
+---
+
+## Phase G: Metric UI Fixes
+
+### G1: Tooltip = non-interactive color sphere
+Chart tooltip shows filled circle with `pointer-events: none`. Read-only color ref.
+
+### G2: Explicit pin button in metrics dropdown
+Each dropdown item: `[dot] [name] [...] [pin-btn]`
+- Dot: click toggles visibility (filled/hollow)
+- Pin button: always visible, outline pushpin when unpinned, filled when pinned
+- Click pin → `toggleMetricCard(name)`
+
+### G3: Persistent end-of-run summary pane
+- Add "Summary" tab to left column tabs (index.html)
+- On run complete: auto-save summary to `hotcb.run.summary.json`, populate Summary tab
+- Tab persists — always accessible, not a dismissible popup
+- `GET /api/run/summary` endpoint
+
+**Tests first:**
+```
+test_run_summary_auto_saved
+test_run_summary_endpoint
+```
+
+**Files:** `charts.js`, `panels.js`, `index.html`, `dashboard.css`, `app.py`, `launcher.py`
+
+---
+
+## Phase H: External Golden Demo + Comprehensive Tests
+
+### External golden demo (`src/hotcb/tests/external_golden_demo.py` — NEW)
+
+Self-contained script that behaves exactly like an external project:
+1. Isolated temp dir
+2. Creates model with `mutable_state = hotcb.mutable_state([...])`
+3.
Registers custom actuators: `"alpha"`, `"beta"`, `"temperature"` (not standard names) +4. Uses built-in `apply_fn` helpers where applicable +5. Runs 50 steps with HotKernel +6. Returns run_dir for test verification + +### Integration tests (`src/hotcb/tests/test_external_integration.py` — NEW) + +``` +# Discovery +test_external_controls_discovered — /api/controls returns custom actuators +test_external_controls_have_bounds — specs have correct bounds +test_external_controls_show_state — UNTOUCHED initially, VERIFIED after change + +# Command flow +test_external_apply_custom_control — POST apply with alpha=0.5 +test_external_apply_triggers_verification — state goes UNVERIFIED → VERIFIED +test_external_apply_broken_param_disables — bad metrics_dict_name → DISABLED + +# Dashboard +test_external_demo_mode_false — Training config hidden, controls visible +test_external_capabilities_detected — capabilities endpoint works + +# Randomized +test_controls_random_float_names — random names +test_controls_random_types — mix of float, int, bool, choice +test_controls_random_bounds — random bound ranges + +# Both demo variants +test_builtin_demo_controls_match — built-in demo has expected controls +test_external_demo_controls_match — external demo has expected controls +test_both_demos_same_kernel_behavior — same apply/verify flow +``` + +--- + +## Phase I: Magic Number Centralization + +Replace all hardcoded constants with `DashboardConfig`/`S.config` references. + +**Frontend → `S.config.*`:** +- `_maxRenderPoints`, `tension`, dash patterns, poll intervals, thresholds, batch sizes + +**Backend → `config.*`:** +- History limits, WS burst, poll intervals, record estimates, cadence thresholds + +**Files:** All JS files, `app.py`, `tailer.py`, `ai_engine.py` + +--- + +## Phase Dependencies + +``` +A ─── Foundation fixes +│ +B ─── MutableState + HotcbActuator core (the heart) +│ +C ─── Kernel integration + module wrappers (depends on B) +│ +├── D ─── DashboardConfig + /api/config + /api/controls (depends on C) +│ │ +│ └── E ─── Dynamic frontend (depends on D) +│ +├── F ─── Launcher simplification (parallel with D/E) +│ +├── G ─── Metric UI fixes (parallel with D/E/F) +│ +└── H ─── External golden demo + integration tests (validates A-G) + │ + I ─── Magic number centralization (last) +``` + +## Files Summary + +| File | Phases | Action | +|------|--------|--------| +| `src/hotcb/mutable_state.py` | B | NEW — HotcbActuator, MutableState, apply_fn helpers | +| `src/hotcb/kernel.py` | C | MutableState discovery, routing, verification | +| `src/hotcb/modules/opt.py` | C | Thin wrapper delegating to MutableState | +| `src/hotcb/modules/loss.py` | C | Thin wrapper delegating to MutableState | +| `src/hotcb/modules/cb/util.py` | A | Remove dupe FileCursor | +| `src/hotcb/server/config.py` | D | NEW — DashboardConfig | +| `src/hotcb/server/app.py` | A,D,G | Fix cache/mtime, add /api/config, /api/controls, /api/run/summary | +| `src/hotcb/server/api.py` | D | Add /api/controls/apply | +| `src/hotcb/server/launcher.py` | A,F,G | Cache clear, signature detect, summary save | +| `src/hotcb/launch.py` | F | Signature detection | +| `static/index.html` | E,G | Remove hardcoded knobs, add Summary tab | +| `static/js/controls.js` | E | buildControlsFromSpecs(), type templates | +| `static/js/charts.js` | A,G | Forecast fix, pin button, tooltip sphere | +| `static/js/init.js` | D,E | Fetch config + controls | +| `static/js/state.js` | D | Add S.config | +| `static/js/panels.js` | G | Summary tab, replace overlay | +| 
`static/css/dashboard.css` | E,G | Control templates, pin button, summary tab |
+| `tests/test_server_stability.py` | A | NEW — bug fix tests |
+| `tests/test_mutable_state.py` | B | NEW — state machine + actuator tests |
+| `tests/test_kernel_mutable.py` | C | NEW — kernel integration tests |
+| `tests/test_dashboard_config.py` | D | NEW — config + controls endpoint tests |
+| `tests/external_golden_demo.py` | H | NEW — external project simulation |
+| `tests/test_external_integration.py` | H | NEW — comprehensive integration tests |
+
+## Verification
+
+```bash
+# Every phase:
+pytest src/hotcb/tests/ -x -q --no-cov
+
+# After Phase C:
+# Existing opt/loss tests still pass (backward compat wrappers)
+
+# After Phase E (manual):
+hotcb demo → controls generated from MutableState, state indicators work
+hotcb serve --dir <run_dir> → custom controls discovered, verification visible
+
+# After Phase H:
+pytest src/hotcb/tests/test_external_integration.py -v  # full external flow
+```
diff --git a/.claude/plans/unified-actuator-config.md b/.claude/plans/unified-actuator-config.md
new file mode 100644
index 0000000..7623016
--- /dev/null
+++ b/.claude/plans/unified-actuator-config.md
@@ -0,0 +1,971 @@
+# Unified Actuator Model + Config Refactor
+
+## Context
+
+Two planned efforts — the **MutableState/HotcbActuator redesign** and the
+**principled config refactor** (`dev-notes/principled_config_refactor.md`) — share a
+critical overlap at the controls layer. If done separately, config refactor Phase 4
+(dynamic controls from `describe_space()`) would be immediately rewritten when the
+actuator redesign changes the data source. Merging them eliminates throwaway work and
+produces a cleaner result.
+
+**No backward compatibility needed** — opt/loss module wire format is not in production.
+
+---
+
+## What We're Fixing
+
+### Problem 1: Three-module duplication
+
+`kernel.py:66-75` hard-codes `{cb, opt, loss, tune}`. `register_actuator()` (lines 89-102)
+has `if name == "opt"` / `elif name == "loss"` wiring. `_apply_single()` at line 263
+rejects anything that isn't one of these four with an `unknown_module` error.
+
+`HotOptController` and `HotLossController` duplicate the same pattern:
+resolve target → optional actuator validation → apply params → handle errors →
+enable/disable state. Each has its own `_resolve_*()`, `set_actuator()`,
+`_actuator_patches()` translation layer.
+
+### Problem 2: Filtering / masquerading
+
+`MutableStateActuator.snapshot()` only backs up `weights` — terms and ramps have no
+rollback. `HotLossController._actuator_weight_patches()` explicitly skips `terms.*`
+and `ramps.*`, so they bypass actuator validation entirely. `describe_space()` doesn't
+include them, so tune can't optimize them.
+
+Any external control that isn't a float weight gets shoved into an unvalidated bucket.
+The user's additional ramps/terms are "masqueraded as loss state and getting filtered
+out at some connections."
+
+### Problem 3: Dashboard hardcoded controls
+
+The UI has hardcoded slider names (`knobLr`, `knobWd`, `knobLossW`, etc.) and CSS
+visibility classes (`single-loss-only`, `multitask-only`, `finetune-only`). Meanwhile
+the core already has `TrainingCapabilities` and `describe_space()` — the dashboard
+ignores both and hardcodes its own schema.
+
+### Problem 4: Magic numbers everywhere
+
+Poll intervals, history limits, thresholds, pixel sizes, batch sizes — bare literals
+scattered across JS and Python files. No central source of truth.
+
+---
+
+## Design
+
+### A.
`HotcbActuator` — The Unified Parameter Handle + +Every controllable scalar/toggle/choice becomes one `HotcbActuator`. Replaces the +current split between `OptimizerActuator` (all optimizer params in one), +`MutableStateActuator` (all loss weights in one), and the module controllers +(`HotOptController`, `HotLossController`). + +```python +# src/hotcb/actuators/actuator.py + +class ActuatorType(Enum): + BOOL = "bool" + FLOAT = "float" + INT = "int" + CHOICE = "choice" # discrete set of allowed values + LOG_FLOAT = "log_float" # float on log scale (lr, wd) + TUPLE = "tuple" # e.g. betas + +class ActuatorState(Enum): + INIT = "init" # registered but not yet observed + UNTOUCHED = "untouched" # observed initial value, no mutations applied + UNVERIFIED = "unverified" # mutation applied, not yet confirmed via metrics + VERIFIED = "verified" # mutation confirmed via metrics_dict_name + DISABLED = "disabled" # user-disabled or auto-disabled on error + +@dataclass +class Mutation: + step: int + old_value: Any + new_value: Any + verified: bool = False + +@dataclass +class HotcbActuator: + """Single controllable parameter.""" + param_key: str # unique key, e.g. "lr", "recon_w", "use_augment" + type: ActuatorType # drives UI control type + apply_fn: Callable[[Any, dict], ApplyResult] # (value, env) -> result + metrics_dict_name: str = "" # metric name for verification (empty = no verification) + label: str = "" # display label, defaults to param_key + group: str = "" # UI grouping hint ("optimizer", "loss", "custom") + + # Bounds (for FLOAT, LOG_FLOAT, INT) + min_value: Optional[float] = None + max_value: Optional[float] = None + step_size: Optional[float] = None + log_base: float = 10.0 # for LOG_FLOAT + + # For CHOICE type + choices: Optional[list] = None + + # Mutable runtime state + current_value: Any = field(default=_INIT_SENTINEL) + state: ActuatorState = field(default=ActuatorState.INIT) + last_changed_step: int = -1 + mutations: list[Mutation] = field(default_factory=list) + + def validate(self, value: Any) -> ValidationResult: + """Type-check and bounds-check a proposed value.""" + ... + + def snapshot(self) -> dict: + """Return state for rollback.""" + return {"value": self.current_value, "state": self.state} + + def restore(self, snapshot: dict, env: dict) -> ApplyResult: + """Rollback to snapshot.""" + return self.apply_fn(snapshot["value"], env) + + def describe_space(self) -> dict: + """Return schema for tune search + UI generation.""" + return { + "param_key": self.param_key, + "type": self.type.value, + "label": self.label or self.param_key, + "group": self.group, + "min": self.min_value, + "max": self.max_value, + "step": self.step_size, + "log_base": self.log_base, + "choices": self.choices, + "current": self.current_value, + "state": self.state.value, + } +``` + +### B. `MutableState` — The Container + +```python +# src/hotcb/actuators/state.py + +class MutableState: + """Container of HotcbActuator instances. 
This is the user-facing API.""" + + def __init__(self, actuators: list[HotcbActuator]): + self._actuators: dict[str, HotcbActuator] = {a.param_key: a for a in actuators} + + def get(self, key: str) -> Optional[HotcbActuator]: + return self._actuators.get(key) + + def keys(self) -> list[str]: + return list(self._actuators.keys()) + + def apply(self, key: str, value: Any, env: dict, step: int) -> ApplyResult: + """Validate, apply, record mutation, transition state.""" + act = self._actuators.get(key) + if act is None: + return ApplyResult(success=False, error=f"unknown_param:{key}") + vr = act.validate(value) + if not vr.valid: + return ApplyResult(success=False, error="; ".join(vr.errors)) + old = act.current_value + result = act.apply_fn(value, env) + if result.success: + act.current_value = value + act.mutations.append(Mutation(step=step, old_value=old, new_value=value)) + act.last_changed_step = step + act.state = ActuatorState.UNVERIFIED + return result + + def initialize(self, env: dict) -> None: + """Read current values from live objects at first step. + Transitions all actuators INIT → UNTOUCHED.""" + ... + + def verify(self, key: str, metrics: dict) -> bool: + """Check metrics_dict_name in latest metrics. UNVERIFIED → VERIFIED if match.""" + ... + + def snapshot_all(self) -> dict: + """Snapshot all actuators for rollback.""" + ... + + def describe_all(self) -> list[dict]: + """Return describe_space() for all actuators. Used by config endpoint + tune.""" + ... +``` + +### C. Convenience Constructors + +```python +# src/hotcb/actuators/__init__.py + +def optimizer_actuators(optimizer, lr_bounds=(1e-7, 1.0), wd_bounds=(0, 1.0)) -> list[HotcbActuator]: + """Auto-create actuators for lr, weight_decay, betas, eps from a torch optimizer.""" + ... + +def loss_actuators(loss_weights: dict, global_bounds=(0.0, 100.0)) -> list[HotcbActuator]: + """Auto-create actuators from a dict of loss weight names → values.""" + ... + +def mutable_state(actuators: list[HotcbActuator]) -> MutableState: + """User-facing constructor.""" + return MutableState(actuators) +``` + +### D. Kernel Default Stream + +`kernel.py:263-266` becomes: + +```python +def _apply_single(self, op, env, event, step): + # ... freeze enforcement ... + + if op.module == "core": + ... # freeze/recipe ops + return + + if op.module == "cb": + ... # cb stays special — code lifecycle + return + + if op.module == "tune": + ... # tune stays special — search orchestrator + return + + # DEFAULT STREAM: opt, loss, or any custom param_key + # Route through MutableState + if self._mutable_state is not None: + key = self._resolve_param_key(op) # from op.params.key, op.target, or op.id + result = self._mutable_state.apply(key, op.params.get("value"), env, step) + self._write_ledger(op, event, step, + decision=result.decision, error=result.error, + payload=op.to_dict(), env=env) + else: + self._write_ledger(op, event, step, + decision="failed", error="no_mutable_state", + payload=op.to_dict(), env=env) +``` + +The `op.module` field is preserved in commands/ledger for grouping and UI display, +but it no longer determines which code path executes. Everything goes through the +same `MutableState.apply()` → `HotcbActuator.apply_fn()` pipeline. + +### E. 
`DashboardConfig` — Centralized Configuration
+
+```python
+# src/hotcb/server/config.py
+
+@dataclass(frozen=True)
+class ServerConfig:
+    host: str = "0.0.0.0"
+    port: int = 8421
+    poll_interval: float = 0.5
+    history_limit_metrics: int = 500
+    history_limit_applied: int = 200
+    ws_initial_burst: int = 200
+    ws_max_retries: int = 20
+    ws_retry_base: float = 3.0
+    ws_retry_cap: float = 30.0
+
+@dataclass(frozen=True)
+class ChartConfig:
+    max_render_points: int = 2000
+    line_tension: float = 0.15
+    forecast_dash: tuple = (6, 3)
+    mutation_dash: tuple = (3, 4)
+    annotation_stagger_rows: int = 10
+    annotation_min_distance: int = 70
+
+@dataclass(frozen=True)
+class AutopilotConfig:
+    divergence_threshold: float = 2.0
+    ratio_threshold: float = 0.5
+    ai_min_interval: int = 10
+    ai_max_wait: int = 200
+    ai_default_cadence: int = 50
+
+@dataclass(frozen=True)
+class UIConfig:
+    state_save_interval: int = 5000
+    alert_poll_interval: int = 15000
+    manifold_refresh_interval: int = 10000
+    recipe_refresh_interval: int = 5000
+    forecast_poll_interval: int = 5000
+    forecast_step_cadence: int = 10
+    forecast_batch_size: int = 8
+    staged_change_threshold: float = 0.005
+    health_ema_alpha: float = 0.1
+
+@dataclass
+class DashboardConfig:
+    server: ServerConfig = field(default_factory=ServerConfig)
+    chart: ChartConfig = field(default_factory=ChartConfig)
+    autopilot: AutopilotConfig = field(default_factory=AutopilotConfig)
+    ui: UIConfig = field(default_factory=UIConfig)
+    run_dir: str = ""  # IMMUTABLE after startup
+    controls: list[dict] = field(default_factory=list)  # populated from MutableState.describe_all()
+
+    @classmethod
+    def load(cls, run_dir: str, yaml_path: Optional[str] = None, **cli_overrides) -> "DashboardConfig":
+        """Resolve: defaults → YAML → env vars → CLI → actuator discovery."""
+        ...
+
+    def to_dict(self) -> dict:
+        """Serialize for /api/config endpoint."""
+        ...
+```
+
+The `controls` field is populated at startup from `MutableState.describe_all()` when
+capabilities are available, or left empty for observe-only dashboards.
+
+### F. Immutable `run_dir`
+
+`run_dir` is set once at `create_app()` and never changes. Remove:
+- `app.state.run_dir` mutation in launcher
+- `_rewire_dir()` helper
+- `tailer.rewire()` method
+
+Launcher writes to `config.run_dir` directly. For multi-run, the launcher returns the new
+path and the user restarts `hotcb serve` pointing at it. The Compare tab reads sibling
+dirs read-only via `/api/runs/discover`.
+
+### G. Dynamic Frontend Controls
+
+```javascript
+// controls.js — replaces hardcoded HTML
+function buildControls(controlSpecs) {
+    var panel = $('#knobPanel');
+    panel.innerHTML = '';
+    controlSpecs.forEach(function(spec) {
+        // spec = {param_key, type, label, group, min, max, step, current, state}
+        panel.appendChild(buildKnobRow(spec));
+    });
+}
+
+function buildKnobRow(spec) {
+    // Generate slider/toggle/dropdown based on spec.type
+    // "log_float" → log-scale slider
+    // "float" → linear slider
+    // "bool" → toggle switch
+    // "choice" → dropdown
+    // "int" → integer stepper
+}
+```
+
+No more `single-loss-only`, `multitask-only`, `finetune-only` CSS classes.
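+
+The precedence the Phase 1 tests below pin down is defaults < YAML < env < CLI.
+A minimal sketch of that resolution for a single scalar — `HOTCB_PORT` comes from
+the test list; the helper itself is illustrative:
+
+```python
+import os
+
+def _resolve_port(yaml_cfg: dict, cli_overrides: dict) -> int:
+    """Illustrative resolution for one value: defaults < YAML < env < CLI."""
+    port = 8421                                          # dataclass default
+    port = yaml_cfg.get("server", {}).get("port", port)  # YAML overrides default
+    port = int(os.environ.get("HOTCB_PORT", port))       # env overrides YAML
+    return int(cli_overrides.get("port", port))          # CLI wins over everything
+```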
+ +--- + +## What Gets Deleted + +| File/Code | Reason | +|-----------|--------| +| `src/hotcb/modules/opt.py` | Absorbed into `HotcbActuator` + `optimizer_actuators()` | +| `src/hotcb/modules/loss.py` | Absorbed into `HotcbActuator` + `loss_actuators()` | +| `src/hotcb/actuators/optimizer.py` | Replaced by per-param `HotcbActuator` instances | +| `src/hotcb/actuators/mutable_state.py` | Replaced by per-param `HotcbActuator` instances | +| `kernel.py` hard-coded `opt`/`loss` module init | Default stream handles all | +| `kernel.py` `register_actuator()` `if name == "opt"` wiring | No more module↔actuator coupling | +| `capabilities.py` `validate_mutable_state()` | `MutableState` replaces the raw dict | +| `tailer.py` `rewire()` method | `run_dir` immutable | +| `app.py` `_rewire_dir()` helper | `run_dir` immutable | +| `app.py` `app.state.run_dir` mutation | `run_dir` immutable | +| `index.html` hardcoded knob rows | Dynamic generation from config | +| `controls.js` hardcoded slider names | Dynamic generation from config | +| CSS `single-loss-only`, `multitask-only`, `finetune-only` | Gone entirely | + +--- + +## What Stays + +| Component | Why | +|-----------|-----| +| `modules/cb/` | Code lifecycle management is fundamentally different from scalar params | +| `modules/tune/` | Search orchestrator — consumes actuators, doesn't become one | +| `actuators/base.py` `BaseActuator` Protocol | `HotcbActuator` implements a superset of it | +| `HotOp` and `command_to_hotop()` | Command format is orthogonal; `op.module` becomes metadata | +| `TrainingCapabilities` | Still useful for framework-level info (num_optimizers, has_scheduler, etc.) | +| JSONL filesystem IPC | Untouched | +| Adapter layer | Still populates capabilities; now also creates `MutableState` | +| `modules/result.py` `ModuleResult` | Still used by cb and tune modules | + +--- + +## Implementation Phases + +### Phase 1: `DashboardConfig` Foundation +**Goal:** Config dataclass exists, loads from YAML/env, serves at `/api/config`. +Frontend fetches once at startup. + +**New files:** +- `src/hotcb/server/config.py` +- `src/hotcb/tests/test_dashboard_config.py` + +**Modified files:** +- `src/hotcb/server/app.py` — add `/api/config` endpoint, construct config at startup +- `src/hotcb/server/static/js/state.js` — add `S.config` +- `src/hotcb/server/static/js/init.js` — fetch config before other init + +**Tests (write first):** +``` +test_config_defaults() + ServerConfig(), ChartConfig(), etc. have documented defaults. + DashboardConfig() is valid with all defaults. + +test_config_from_yaml() + Write a YAML file with server.port=9000, chart.line_tension=0.3. + Load → verify overrides applied, other defaults preserved. + +test_config_from_yaml_missing_file() + Load with nonexistent YAML → all defaults, no error. + +test_config_from_env() + Set HOTCB_PORT=9000, HOTCB_POLL_INTERVAL=1.0 in env. + Load → verify env overrides applied. + +test_config_env_overrides_yaml() + YAML sets port=9000, env sets HOTCB_PORT=8000. + Load → port is 8000 (env wins). + +test_config_cli_overrides_all() + YAML + env + cli_overrides={port: 7000}. + Load → port is 7000 (CLI wins). + +test_config_to_dict_roundtrip() + config.to_dict() → JSON serializable. + All nested sub-configs appear. + +test_config_run_dir_in_dict() + config = DashboardConfig(run_dir="/tmp/x") + d = config.to_dict() + assert d["run_dir"] == "/tmp/x" + +test_config_endpoint_returns_full(client) + GET /api/config → 200, body has server, chart, autopilot, ui, run_dir keys. 
+ +test_config_endpoint_reflects_overrides(client) + App created with yaml overriding port. + GET /api/config → server.port matches override. +``` + +**Phase 1 does NOT touch:** controls, actuators, kernel, modules. + +--- + +### Phase 2: `HotcbActuator` + `MutableState` +**Goal:** New actuator types exist, state machine works, convenience constructors +produce correct actuators from optimizer/loss dicts. + +**New files:** +- `src/hotcb/actuators/actuator.py` — `HotcbActuator`, `ActuatorType`, `ActuatorState`, `Mutation` +- `src/hotcb/actuators/state.py` — `MutableState` container +- `src/hotcb/tests/test_actuator_unified.py` + +**Modified files:** +- `src/hotcb/actuators/__init__.py` — export `mutable_state()`, `optimizer_actuators()`, `loss_actuators()` + +**Tests (write first):** +``` +--- ActuatorType & validation --- + +test_float_actuator_validate_in_bounds() + HotcbActuator(type=FLOAT, min=0, max=1). validate(0.5) → valid. + +test_float_actuator_validate_out_of_bounds() + validate(1.5) → invalid, error mentions bounds. + +test_log_float_actuator_validate() + HotcbActuator(type=LOG_FLOAT, min=1e-7, max=1.0). validate(1e-4) → valid. + validate(-1.0) → invalid. + +test_bool_actuator_validate() + HotcbActuator(type=BOOL). validate(True) → valid. validate("yes") → invalid. + +test_int_actuator_validate() + HotcbActuator(type=INT, min=0, max=100). validate(50) → valid. + validate(50.5) → invalid (not int). + +test_choice_actuator_validate() + HotcbActuator(type=CHOICE, choices=["adam", "sgd", "adamw"]). + validate("adam") → valid. validate("rmsprop") → invalid. + +test_tuple_actuator_validate() + HotcbActuator(type=TUPLE). validate((0.9, 0.999)) → valid. + validate("not a tuple") → invalid. + +--- State machine --- + +test_initial_state_is_init() + Fresh actuator → state == INIT. + +test_initialize_transitions_to_untouched() + MutableState with lr actuator. + ms.initialize(env) → lr.state == UNTOUCHED, lr.current_value == actual lr. + +test_apply_transitions_to_unverified() + After initialize, ms.apply("lr", 1e-3, env, step=10). + lr.state == UNVERIFIED. + +test_verify_transitions_to_verified() + After apply, ms.verify("lr", {"lr": 1e-3}). + lr.state == VERIFIED. + +test_apply_after_verified_goes_back_to_unverified() + VERIFIED → apply new value → UNVERIFIED again. + +test_disabled_actuator_rejects_apply() + act.state = DISABLED. ms.apply("lr", ...) → fails with "actuator_disabled". + +test_disable_actuator() + ms.disable("lr") → lr.state == DISABLED. + +--- Mutation tracking --- + +test_mutation_recorded_on_apply() + ms.apply("lr", 1e-3, env, step=10) → lr.mutations has 1 entry. + mutation.step == 10, old_value == original, new_value == 1e-3. + +test_multiple_mutations_accumulated() + 3 applies → 3 mutations in list. + +test_last_changed_step_updated() + ms.apply("lr", ..., step=50) → lr.last_changed_step == 50. + +--- apply_fn --- + +test_apply_fn_receives_value_and_env() + Mock apply_fn. ms.apply("lr", 1e-3, env, step=1). + apply_fn called with (1e-3, env). + +test_apply_fn_failure_does_not_mutate_state() + apply_fn returns ApplyResult(success=False). + current_value unchanged, no mutation recorded, state unchanged. + +test_apply_fn_exception_caught() + apply_fn raises RuntimeError. + ms.apply() returns ApplyResult(success=False, error=...). + State not corrupted. + +--- Snapshot / restore --- + +test_snapshot_all() + MutableState with lr + wd. + snapshot = ms.snapshot_all() + Has entries for both keys with value + state. 
+ +test_restore_from_snapshot() + Apply mutations, snapshot, apply more, restore. + Values back to snapshot state. + +--- describe_space --- + +test_describe_space_includes_all_fields() + act.describe_space() → dict with param_key, type, label, group, + min, max, step, choices, current, state. + +test_describe_all() + MutableState with 3 actuators. + ms.describe_all() → list of 3 dicts. + +--- Convenience constructors --- + +test_optimizer_actuators_from_torch_optimizer() + opt = MockOptimizer(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)). + acts = optimizer_actuators(opt) + → 3 actuators: lr (LOG_FLOAT), weight_decay (LOG_FLOAT), betas (TUPLE). + Each has correct current_value from optimizer. + +test_optimizer_actuators_bounds() + acts = optimizer_actuators(opt, lr_bounds=(1e-6, 0.1)) + → lr actuator has min=1e-6, max=0.1. + +test_optimizer_actuators_apply_fn_sets_param_groups() + acts = optimizer_actuators(opt) + lr_act = [a for a in acts if a.param_key == "lr"][0] + lr_act.apply_fn(5e-4, {"optimizer": opt}) + → all param groups now have lr=5e-4. + +test_optimizer_actuators_apply_fn_coordinates_scheduler() + Optimizer + scheduler with base_lrs. + lr_act.apply_fn(5e-4, {"optimizer": opt, "scheduler": sched}) + → scheduler.base_lrs updated too. + +test_loss_actuators_from_dict() + weights = {"recon": 1.0, "kl": 0.5, "perceptual": 0.3} + acts = loss_actuators(weights) + → 3 actuators, all FLOAT, correct current_values. + +test_loss_actuators_apply_fn_mutates_dict() + acts = loss_actuators(weights) + acts[0].apply_fn(2.0, {}) → weights["recon"] == 2.0. + +test_loss_actuators_bounds() + acts = loss_actuators(weights, global_bounds=(0, 10)) + → all have min=0, max=10. + +test_mutable_state_constructor() + ms = hotcb.mutable_state([a1, a2, a3]) + → isinstance(ms, MutableState), ms.keys() == ["lr", "wd", "recon_w"]. +``` + +**Phase 2 does NOT touch:** kernel, modules, server, frontend. + +--- + +### Phase 3: Kernel Unification (Default Stream) +**Goal:** Kernel routes opt/loss/custom ops through `MutableState`. Delete +`HotOptController`, `HotLossController`, `OptimizerActuator`, `MutableStateActuator`. + +**Deleted files:** +- `src/hotcb/modules/opt.py` +- `src/hotcb/modules/loss.py` +- `src/hotcb/actuators/optimizer.py` +- `src/hotcb/actuators/mutable_state.py` + +**Modified files:** +- `src/hotcb/kernel.py` — accept `MutableState`, remove opt/loss from `self.modules`, + default stream routing +- `src/hotcb/actuators/__init__.py` — remove old exports +- `src/hotcb/modules/__init__.py` — remove opt/loss re-exports if any +- `src/hotcb/ops.py` — `command_to_hotop()` default_module stays "cb" (unchanged) + +**Tests (write first):** + +``` +--- Kernel accepts MutableState --- + +test_kernel_init_with_mutable_state() + k = HotKernel(run_dir=..., mutable_state=ms) + k._mutable_state is ms. + +test_kernel_init_without_mutable_state() + k = HotKernel(run_dir=...) + k._mutable_state is None. No error. + +--- Default stream routing --- + +test_opt_set_params_routes_to_mutable_state() + op = HotOp(module="opt", op="set_params", params={"key": "lr", "value": 1e-3}) + Kernel with MutableState containing lr actuator. + kernel._apply_single(op, env, "train_step", 10) + → Ledger record has decision="applied". + → MutableState lr actuator now has current_value=1e-3. + +test_loss_set_params_routes_to_mutable_state() + op = HotOp(module="loss", op="set_params", params={"key": "recon_w", "value": 2.0}) + → Same routing, applied via MutableState. 
+ +test_custom_module_routes_to_mutable_state() + op = HotOp(module="custom", op="set_params", params={"key": "dropout", "value": 0.3}) + MutableState has "dropout" actuator. + → Routes through default stream, applied successfully. + +test_unknown_param_key_fails_gracefully() + op = HotOp(module="opt", op="set_params", params={"key": "nonexistent", "value": 1.0}) + → Ledger: decision="failed", error="unknown_param:nonexistent". + +test_no_mutable_state_fails_gracefully() + Kernel with mutable_state=None. + op with module="opt" → decision="failed", error="no_mutable_state". + +test_cb_still_routes_to_cb_module() + op = HotOp(module="cb", op="enable", ...) + → Routes to CallbackModule, not MutableState. + +test_tune_still_routes_to_tune_module() + op = HotOp(module="tune", op="observe", ...) + → Routes to HotTuneController. + +test_core_still_routes_to_core() + op = HotOp(module="core", op="freeze", ...) + → Routes to _apply_core_op. + +--- Freeze enforcement still works --- + +test_freeze_blocks_default_stream() + Kernel in freeze mode="prod". + op with module="opt" → decision="ignored_freeze". + +test_freeze_blocks_custom_module() + op with module="custom" → decision="ignored_freeze". + +--- Param key resolution --- + +test_param_key_from_params_key() + op.params = {"key": "lr", "value": 1e-3} + _resolve_param_key(op) → "lr". + +test_param_key_from_legacy_opt_format() + op.module = "opt", op.params = {"lr": 1e-3} (no "key" field) + _resolve_param_key(op) → extracts "lr" from params dict. + +test_param_key_from_legacy_loss_format() + op.module = "loss", op.params = {"recon_w": 2.0} (suffix convention) + _resolve_param_key(op) → "recon". + +--- register_actuator compatibility with tune --- + +test_register_mutable_state_propagates_to_tune() + k = HotKernel(..., mutable_state=ms) + tune module can see all actuators via ms.describe_all(). + +--- Enable/disable via default stream --- + +test_enable_disable_actuator_via_op() + op = HotOp(module="opt", op="disable", params={"key": "lr"}) + → lr actuator state becomes DISABLED. + op = HotOp(module="opt", op="enable", params={"key": "lr"}) + → lr actuator state becomes UNTOUCHED (or last known good state). + +--- Ledger format --- + +test_ledger_preserves_module_field() + op with module="opt" applied via default stream. + Ledger record has module="opt" (not "default" or "mutable_state"). + +test_ledger_records_mutation_detail() + Applied lr change. + Ledger payload includes old_value, new_value, param_key. +``` + +**Migration of existing tests:** +- `test_hotopt.py` (12 tests): Rewrite to test through kernel default stream + instead of `HotOptController` directly. Same behaviors, different entry point. +- `test_hotloss.py` (14 tests): Same — rewrite to test through kernel. +- `test_kernel_core.py` `test_route_to_opt`, `test_route_to_loss`: Update to verify + default stream routing instead of module lookup. +- `test_hottune.py` actuator tests: Update `OptimizerActuator` → `optimizer_actuators()`, + `MutableStateActuator` → `loss_actuators()`. The tune controller's + `register_actuator()` interface may need updating to work with `MutableState`. +- `test_server_api.py` opt/loss endpoint tests: Still work — API writes same JSONL + commands, kernel routes differently. +- `test_new_features.py` opt/loss traceback tests: Rewrite for new error path. + +--- + +### Phase 4: Dynamic Controls from Actuators +**Goal:** `/api/config` `controls` field populated from `MutableState.describe_all()`. +Frontend generates controls dynamically. 
Remove all hardcoded slider HTML. + +**Modified files:** +- `src/hotcb/server/config.py` — `controls` populated from actuator metadata +- `src/hotcb/server/app.py` — wire `MutableState.describe_all()` into config +- `src/hotcb/server/static/index.html` — remove hardcoded knob rows, keep `
<div id="knobPanel">`
+- `src/hotcb/server/static/js/controls.js` — `buildControls(specs)`, `buildKnobRow(spec)`
+- `src/hotcb/server/static/css/dashboard.css` — remove `single-loss-only` etc.
+- `src/hotcb/capabilities.py` — remove `validate_mutable_state()`, keep `TrainingCapabilities`
+  but remove `mutable_state_detected`/`mutable_state_keys` (now in `MutableState`)
+
+**Tests (write first):**
+```
+test_config_controls_from_mutable_state()
+    App with MutableState(lr, wd, recon_w).
+    GET /api/config → controls has 3 entries.
+    Each has param_key, type, label, min, max, current.
+
+test_config_controls_empty_when_no_mutable_state()
+    App without MutableState.
+    GET /api/config → controls is [].
+
+test_config_controls_types_match_actuators()
+    MutableState with LOG_FLOAT lr, FLOAT recon_w, BOOL use_augment.
+    Controls: [{type: "log_float"}, {type: "float"}, {type: "bool"}].
+
+test_config_controls_groups_present()
+    optimizer_actuators → group="optimizer".
+    loss_actuators → group="loss".
+    Custom → group="custom".
+
+test_config_controls_reflect_live_state()
+    Apply mutation to lr.
+    GET /api/config → lr control has updated current + state="unverified".
+
+--- Server API endpoints adapt ---
+
+test_opt_set_endpoint_still_works()
+    POST /api/opt/set with {params: {lr: 1e-3}}
+    → command written to JSONL with module="opt".
+
+test_loss_set_endpoint_still_works()
+    POST /api/loss/set with {params: {recon_w: 2.0}}
+    → command written to JSONL with module="loss".
+
+test_control_state_endpoint_uses_mutable_state()
+    GET /api/state/controls
+    → Returns live values from MutableState, not hardcoded schema.
+```
+
+---
+
+### Phase 5: Immutable `run_dir`
+**Goal:** `run_dir` set once, never mutated. Remove rewire infrastructure.
+
+**Modified files:**
+- `src/hotcb/server/app.py` — remove `app.state.run_dir` mutation, use `config.run_dir`
+- `src/hotcb/server/launcher.py` — write to `config.run_dir` directly, no subdirs
+- `src/hotcb/server/tailer.py` — remove `rewire()` method
+
+**Tests (write first):**
+```
+test_run_dir_set_once()
+    config = DashboardConfig(run_dir="/tmp/x")
+    All endpoints use "/tmp/x". No mutation observed.
+
+test_launcher_writes_to_config_run_dir()
+    Launcher.start() → JSONL files appear in config.run_dir, not subdirs.
+
+test_launcher_truncates_on_restart()
+    Launcher.start() twice → JSONL files truncated (like reset), no subdirs.
+
+test_tailer_no_rewire_method()
+    JsonlTailer has no rewire() attribute.
+
+test_endpoints_use_immutable_run_dir()
+    Start app. Simulate launcher.
+    GET /api/metrics/history → reads from original run_dir.
+
+test_compare_reads_siblings_readonly()
+    GET /api/runs/discover → scans parent dir.
+    Original monitored dir unchanged.
+```
+
+---
+
+### Phase 6: Magic Number Extraction
+**Goal:** All bare literals replaced with config reads.
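+
+One representative extraction, using the tailer poll loop as the example — the
+`tailer.running`/`tailer.drain()` surface is a stand-in, only the config attribute
+path mirrors `ServerConfig` above:
+
+```python
+import time
+
+def poll_loop(tailer, config):
+    # Before: time.sleep(0.5) — a bare literal buried in tailer.py.
+    # After: the interval comes from DashboardConfig (config.server.poll_interval).
+    while tailer.running:
+        tailer.drain()
+        time.sleep(config.server.poll_interval)
+```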
+ +**Modified files:** +- `src/hotcb/server/app.py` — history limits, WS burst from config +- `src/hotcb/server/tailer.py` — poll_interval from config +- `src/hotcb/server/autopilot.py` — thresholds from config +- `src/hotcb/server/ai_engine.py` — cadence/thresholds from config +- `src/hotcb/server/static/js/charts.js` — read from `S.config.chart.*` +- `src/hotcb/server/static/js/controls.js` — threshold from `S.config.ui.*` +- `src/hotcb/server/static/js/websocket.js` — retries from `S.config.server.*` +- `src/hotcb/server/static/js/init.js` — intervals from `S.config.ui.*` +- `src/hotcb/server/static/js/panels.js` — intervals from `S.config.ui.*` + +**Tests (write first):** +``` +test_tailer_uses_config_poll_interval() + Config with poll_interval=2.0. + Tailer constructed with it → internal interval is 2.0. + +test_history_limits_from_config() + Config with history_limit_metrics=100. + GET /api/metrics/history → returns at most 100 records. + +test_ws_burst_from_config() + Config with ws_initial_burst=50. + WS connect → at most 50 records in initial burst. + +test_autopilot_thresholds_from_config() + Config with divergence_threshold=5.0. + Autopilot engine uses 5.0, not hardcoded 2.0. + +test_ai_cadence_from_config() + Config with ai_default_cadence=100. + AI engine uses 100, not hardcoded 50. +``` + +Frontend magic number replacement is verified manually + by existing test suite +(backend serves correct config, frontend reads it). + +--- + +## Adapter Integration + +Adapters (`lightning.py`, `hf.py`) currently populate `TrainingCapabilities` and +put `optimizer`/`mutable_state` in `env`. After this change: + +1. Adapters create `MutableState` from the optimizer + any user-registered actuators +2. Pass `MutableState` to kernel constructor (or register after init) +3. `env` still carries `optimizer`, `scheduler`, etc. for `apply_fn` closures +4. `TrainingCapabilities` still written to `hotcb.capabilities.json` for + framework-level info, but `mutable_state_detected`/`mutable_state_keys` + are replaced by `MutableState.describe_all()` + +--- + +## Demo Updates + +All 3 demos (`demo.py`, `golden_demo.py`, `finetune_demo.py`) currently create +`_OptProxy` + register `OptimizerActuator` and `MutableStateActuator`. 
After this:
+
+```python
+# Instead of:
+k.register_actuator("opt", OptimizerActuator())
+k.register_actuator("loss", MutableStateActuator())
+
+# Becomes:
+from hotcb.actuators import optimizer_actuators, loss_actuators, mutable_state
+ms = mutable_state(
+    optimizer_actuators(opt_proxy) + loss_actuators(loss_weights)
+)
+k = HotKernel(run_dir=..., mutable_state=ms)
+```
+
+---
+
+## Impact on Existing Test Suites
+
+| Test file | Tests | Impact |
+|-----------|-------|--------|
+| `test_hotopt.py` | 12 | **Rewrite**: Test via kernel default stream |
+| `test_hotloss.py` | 14 | **Rewrite**: Test via kernel default stream |
+| `test_kernel_core.py` | 19 | **Update 4**: route_to_opt/loss tests change; rest unchanged |
+| `test_hottune.py` | 104+ | **Update ~30**: Replace OptimizerActuator/MutableStateActuator with new types |
+| `test_server_api.py` | 40+ | **Unchanged**: API writes same JSONL, routing is kernel-internal |
+| `test_server_app.py` | 15 | **Update ~5**: Config endpoint, status endpoint |
+| `test_new_features.py` | 49 | **Update 2**: opt/loss traceback tests |
+| `test_launch.py` | 21 | **Update ~5**: Actuator registration changes |
+| `test_backend_gaps.py` | varies | **Update**: Actuator-related tests |
+
+New test files:
+- `test_dashboard_config.py` — ~15 tests (Phase 1)
+- `test_actuator_unified.py` — ~35 tests (Phase 2)
+- Phase 3 kernel tests integrated into `test_kernel_core.py`
+
+---
+
+## Phase Ordering & Dependencies
+
+```
+Phase 1 (Config) ──────────────────────────────────────→ Phase 5 (Immutable run_dir)
+                                                        → Phase 6 (Magic numbers)
+Phase 2 (Actuator types) → Phase 3 (Kernel) → Phase 4 (Dynamic controls)
+```
+
+Phases 1 and 2 are independent — can be done in parallel.
+Phase 3 depends on Phase 2.
+Phase 4 depends on Phases 1 + 3 (needs config endpoint + actuator metadata).
+Phase 5 depends on Phase 1 (uses config).
+Phase 6 depends on Phases 1 + 5 (config exists + run_dir stable).
+
+**Recommended order:** 1 → 2 → 3 → 4 → 5 → 6
+(Phases 1+2 could run in parallel if two sessions available.)
+
+---
+
+## Verification
+
+After each phase:
+```bash
+pytest src/hotcb/tests/ -x -q --no-cov  # full suite passes
+```
+
+After Phase 3 (the big one):
+```bash
+hotcb demo           # all 3 demo configs work
+hotcb demo --golden  # golden demo metrics flow
+```
+
+After Phase 4:
+```bash
+hotcb serve --dir <run_dir>
+# Dashboard shows 5 loss weight sliders (not 2 hardcoded ones)
+```
+
+After Phase 6:
+```bash
+# Create hotcb.dashboard.yaml with custom intervals
+hotcb serve --dir runs/exp1
+# Verify custom values appear in browser console: S.config
+```
diff --git a/.claude/skills/hotcb-autopilot/SKILL.md b/.claude/skills/hotcb-autopilot/SKILL.md
new file mode 100644
index 0000000..aac7b3b
--- /dev/null
+++ b/.claude/skills/hotcb-autopilot/SKILL.md
@@ -0,0 +1,407 @@
+---
+name: hotcb-autopilot
+description: Launch AI-driven training optimization with hotcb. Checks installation, asks the user for run config, optionally deep-reads training code for context-aware optimization, then runs as the live autopilot — reading metrics, analyzing trends, and issuing hotcb commands to tune hyperparameters, loss weights, and callbacks during training. Use when the user wants to optimize a PyTorch training run with AI assistance.
+---
+
+# hotcb AI Autopilot — Claude Code Skill
+
+You are acting as the AI autopilot for hotcb, a live training control plane for PyTorch.
+Your job: read training metrics, analyze trends, and issue hotcb commands to optimize training — all without restarting.
+
+## Phase 0: Understand the Training Code (optional, user-gated)
+
+**Before anything else**, ask the user one question:
+
+> "Can I read your training code? This lets me understand your model architecture, loss function, optimizer setup, augmentation pipeline, and logging — so I can make smarter decisions during training (and occasionally suggest code changes). This uses extra tokens but gives much better results. (y/n)"
+
+### If the user says yes:
+
+Ask them to point you to the key files (or directories) — e.g. "Which files contain your training loop, model definition, loss, and data pipeline?"
+
+Then read those files and produce a **training context summary** saved to `<run_dir>/hotcb.training_context.md`. This file is reused in future invocations so you don't re-read the codebase every time.
+
+The summary should cover:
+
+```markdown
+# Training Context — <project>
+Generated by hotcb autopilot on <date>
+
+## Model Architecture
+- Type: (e.g. ResNet-50, GPT-2, ViT-B/16, custom UNet)
+- Parameters: (approx count if visible)
+- Key layers / heads: (what matters for loss routing, freezing, etc.)
+
+## Loss Function
+- Type: (CrossEntropy, MSE, multi-task weighted sum, custom)
+- Terms: (list each term, its weight, whether it's toggleable)
+- Key metric relationship: (which loss terms most affect which metrics)
+
+## Optimizer & Scheduler
+- Optimizer: (Adam, AdamW, SGD, etc.) with initial params (lr, wd, momentum)
+- Scheduler: (cosine, step, warmup+decay, none)
+- Known sensitivities: (e.g. "lr > 1e-3 likely unstable for this arch")
+
+## Data Pipeline
+- Dataset: (name/size if visible)
+- Augmentations: (list key transforms)
+- Batch size / accumulation: (if visible)
+
+## Logging & Metrics
+- What metrics are logged: (train_loss, val_loss, val_acc, grad_norm, lr, etc.)
+- Logging frequency: (every step, every N steps, every epoch)
+- Validation frequency: (every N steps, every epoch)
+
+## Notable Code Patterns
+- (Anything relevant: gradient clipping, mixed precision, EMA, custom callbacks, etc.)
+```
+
+**Keep the summary concise** — aim for 50-100 lines. This is context for the autopilot, not full documentation.
+
+If `hotcb.training_context.md` already exists in the run dir, ask: "I found an existing training context from a previous session. Should I reuse it, or re-read the code?" If reuse, just load it. Do not re-scan.
+
+### If the user says no:
+
+Skip entirely. Proceed to Phase 1. You'll operate on metrics alone (still effective, just less context for decisions).
+
+### Using the context during autopilot
+
+When `hotcb.training_context.md` exists:
+- **Inform your decisions**: e.g. if the model uses cosine scheduling, don't fight the scheduler with lr adjustments unless it's clearly broken
+- **Suggest code changes** (Karpathy-style): if you notice something during training that a code change would fix better than a knob turn — e.g. "your augmentation pipeline doesn't include mixup, which could help with the overfitting I'm seeing" or "your warmup is 100 steps but this model typically needs 500" — suggest the edit with a specific diff. Don't force it; suggest and let the user decide.
+- **Loss term awareness**: if you know the loss is a weighted sum of 3 terms, you can make much smarter `loss set_params` calls
+- **Architecture awareness**: know what "grad_norm is spiking" means for this specific model
+
+## Phase 1: Quick Setup
+
+### 1.1 Check hotcb installation
+
+```bash
+python3 -c "import hotcb; print('hotcb OK')" 2>&1 || PYTHONPATH=src python3 -c "import hotcb; print('hotcb OK (dev mode)')" 2>&1
+python3 -c "import fastapi, uvicorn, websockets; print('dashboard deps OK')" 2>&1
+```
+
+If the import only works with `PYTHONPATH=src`, prefix all subsequent `python3` and `hotcb` commands with `PYTHONPATH=src`.
+
+If not installed: tell user to run `pip install "hotcb[dashboard]"` (or `pip install -e ".[dev,all]"` if developing from source).
+
+### 1.2 Check for launch config
+
+Look for `hotcb.launch.json` in the project root (or current directory). This file is created during integration and contains everything needed to launch — **if it exists, skip all user questions and go straight to Phase 2.**
+
+```bash
+cat hotcb.launch.json 2>/dev/null || echo "NOT_FOUND"
+```
+
+The file looks like:
+
+```json
+{
+  "train_fn": "my_project.train:train",
+  "run_dir": "./runs",
+  "key_metric": "val_loss",
+  "max_steps": 5000,
+  "max_time": 300,
+  "autopilot": "ai_suggest",
+  "port": 8421
+}
+```
+
+All fields are optional except `train_fn` (or training must already be running). If `hotcb.launch.json` exists:
+- Use its values as defaults
+- Print: "Found hotcb.launch.json — launching with: train_fn=X, key_metric=Y, max_time=Z"
+- Go directly to Phase 2 (still ask Phase 0 code-reading question if `hotcb.training_context.md` doesn't exist yet)
+
+### 1.3 If no launch config — ask the user
+
+Do NOT scan the repo for integration. Ask the user directly:
+
+1. **Run directory**: "Where is your run directory? (path with `hotcb.metrics.jsonl`, or I'll create one)"
+2. **Training status**: "Is training already running, or should I start it? If I should start it, what's the training function? (e.g. `my_module:train`)"
+3. **Key metric**: "What metric should I optimize? (e.g. `val_loss`, `val_accuracy`)"
+4. **Time/step limit**: "How long should training run? (e.g. `300` seconds, `5000` steps, or both)"
+5. **Integration check**: "Does your training loop write `hotcb.metrics.jsonl` and read `hotcb.commands.jsonl`? (If not, I'll show you the 10-line integration snippet)"
+
+If the user says integration isn't done, show this minimal contract:
+
+```python
+# Add to your training loop — 10 lines total
+import json, os
+
+def write_metrics(run_dir, step, metrics_dict):
+    with open(os.path.join(run_dir, "hotcb.metrics.jsonl"), "a") as f:
+        f.write(json.dumps({"step": step, "metrics": metrics_dict}) + "\n")
+
+def read_commands(run_dir, offset=0):
+    path = os.path.join(run_dir, "hotcb.commands.jsonl")
+    if not os.path.exists(path): return [], offset
+    cmds, consumed = [], offset
+    with open(path) as f:
+        for i, line in enumerate(f):
+            if i < offset: continue
+            consumed = i + 1  # advance past every consumed line, even malformed ones
+            line = line.strip()
+            if line:
+                try: cmds.append(json.loads(line))
+                except json.JSONDecodeError: pass  # skip malformed lines
+    return cmds, consumed
+```
+
+Or for adapter users: `HotCBLightning(kernel)` / `HotCBHFCallback(kernel)` handle this automatically.
+
+For the full integration reference (all options, metrics/command format, loss weight exposure, train_fn contract), read `INTEGRATION.md` in the hotcb repo. That file is designed to be self-contained context for AI agents working on external training projects.
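+
+A minimal sketch of how the contract above slots into a training loop — `dataloader`, `train_step`, and `apply_command` stand in for the user's own code and are not hotcb APIs:
+
+```python
+# Hypothetical wiring — illustrative only.
+run_dir = "./runs/exp1"
+offset = 0
+for step, batch in enumerate(dataloader):
+    loss = train_step(batch)                       # user's training code
+    write_metrics(run_dir, step, {"loss": float(loss)})
+    cmds, offset = read_commands(run_dir, offset)  # offset advances past consumed lines
+    for cmd in cmds:
+        apply_command(cmd)                         # e.g. set optimizer lr from cmd["params"]
+```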
+
+## Phase 2: Launch
+
+Build the launch command from either `hotcb.launch.json` values or user answers.
+
+Use the autopilot mode from the config/user — do NOT override it to `off`. If the user specified `ai_suggest` or `ai_auto`, launch with that. If `hotcb.launch.json` has `"autopilot": "ai_suggest"`, use `--autopilot ai_suggest`. If no autopilot was specified, default to `ai_suggest` (since the user invoked this skill, they want AI optimization).
+
+### 2.1 If training needs to be started:
+
+```bash
+hotcb launch \
+  --train-fn <train_fn> \
+  --run-dir <run_dir> \
+  --max-steps <max_steps> \
+  --max-time <max_time> \
+  --key-metric <key_metric> \
+  --port <port> \
+  --autopilot <mode> &
+LAUNCH_PID=$!
+echo "Dashboard: http://localhost:<port>"
+```
+
+Omit `--max-steps` or `--max-time` if not specified. Include both if both are set (whichever limit hits first wins).
+
+If `hotcb.launch.json` exists, you can also use:
+```bash
+hotcb launch --config-file hotcb.launch.json &
+```
+
+### 2.2 If training is already running (just attach dashboard):
+
+```bash
+hotcb serve --dir <run_dir> --port <port> --autopilot <mode> &
+SERVER_PID=$!
+echo "Dashboard: http://localhost:<port>"
+```
+
+### 2.3 Wait for metrics
+
+Poll until metrics start flowing:
+```bash
+for i in $(seq 1 30); do
+  if [ -s "<run_dir>/hotcb.metrics.jsonl" ]; then
+    echo "Metrics flowing"
+    break
+  fi
+  sleep 1
+done
+```
+
+## Phase 3: Autopilot Loop
+
+**This is where YOU (Claude Code) act as the AI autopilot.**
+
+You replace the external LLM — you read metrics directly, reason about them, and issue hotcb commands via the REST API or CLI. No API key needed.
+
+### 3.1 Read current state
+
+Use the REST API to get current metrics and status:
+
+```bash
+# Get latest metrics
+curl -s http://localhost:8421/api/metrics/history?last_n=20 | python3 -m json.tool
+
+# Get metric names
+curl -s http://localhost:8421/api/metrics/names | python3 -m json.tool
+
+# Get current status
+curl -s http://localhost:8421/api/status | python3 -m json.tool
+
+# Get applied mutations
+curl -s http://localhost:8421/api/applied/history?last_n=10 | python3 -m json.tool
+```
+
+Or read JSONL files directly:
+```bash
+tail -5 <run_dir>/hotcb.metrics.jsonl | python3 -m json.tool
+```
+
+### 3.2 Analyze trends
+
+For each metric, compute:
+- **Direction**: Is it going up, down, or flat?
+- **Rate**: How fast is it changing?
+- **Volatility**: Is it noisy or smooth?
+- **Key events**: Any spikes, plateaus, or trend reversals?
+
+Focus on the **key metric** but monitor all metrics for health.
+
+If `hotcb.training_context.md` exists, cross-reference your analysis with the known model/loss/optimizer setup. For example: if context says cosine LR schedule is active, a slowly decreasing lr is expected, not a problem.
+
+### 3.3 Decide and act
+
+Based on your analysis:
+
+**If training is healthy** (loss decreasing steadily, no divergence):
+→ Do nothing. Report status to user. Check back in 50-100 steps.
+
+**If loss has plateaued** (flat for 20+ steps):
+→ Reduce learning rate by 50%:
+```bash
+curl -s -X POST http://localhost:8421/api/opt/set \
+  -H 'Content-Type: application/json' \
+  -d '{"params": {"lr": <new_lr>}}'
+```
+Or via CLI: `hotcb --dir <run_dir> set lr=<new_lr>`
+
+**If loss is diverging** (increasing sharply):
+→ Reduce learning rate aggressively (10x):
+```bash
+curl -s -X POST http://localhost:8421/api/opt/set \
+  -H 'Content-Type: application/json' \
+  -d '{"params": {"lr": <new_lr>}}'
+```
+
+**If overfitting** (train loss << val loss, or val loss rising while train loss falls):
+→ Increase weight decay:
+```bash
+curl -s -X POST http://localhost:8421/api/opt/set \
+  -H 'Content-Type: application/json' \
+  -d '{"params": {"weight_decay": <new_wd>}}'
+```
+
+**If multi-task and one task is dominating**:
+→ Adjust loss weights:
+```bash
+curl -s -X POST http://localhost:8421/api/loss/set \
+  -H 'Content-Type: application/json' \
+  -d '{"params": {"weight_a": 0.5, "weight_b": 0.5}}'
+```
+
+### 3.4 Suggest code changes (if training context available)
+
+When you have `hotcb.training_context.md`, you can occasionally suggest code-level fixes — things a knob turn can't do:
+
+- "Your augmentation is missing X, which would help with the overfitting pattern I'm seeing"
+- "The warmup schedule is too short for this model size — suggest changing line N in train.py"
+- "Consider adding gradient clipping — grad_norm has been volatile"
+- "The validation frequency is too low to catch divergence early — suggest validating every 50 steps instead of every epoch"
+
+**Present as a suggestion with a concrete diff.** Don't apply without user approval. These are between-check-in suggestions, not every cycle.
+
+### 3.5 Report to user
+
+After each analysis cycle, report:
+1. Current step and key metrics
+2. Trend assessment (healthy/plateau/diverging/overfitting)
+3. Action taken (if any) with reasoning
+4. Next check-in time
+
+Format:
+```
+## Step {N} — {Assessment}
+- train_loss: {val} ({trend})
+- val_loss: {val} ({trend})
+- lr: {val}
+- Action: {what you did and why, or "none — training healthy"}
+- Next check: step {N+50}
+```
+
+### 3.6 Cadence
+
+- **Normal**: Check every 50 steps
+- **After intervention**: Check in 20 steps to see effect
+- **If diverging**: Check every 10 steps
+- **If healthy and improving**: Can extend to every 100 steps
+
+**Wall-clock cap**: Regardless of step-based cadence, NEVER wait more than 60 seconds between checks. If training is slow (e.g., 1 step/sec), the step-based cadence may translate to minutes of silence. Always set a wall-clock timeout:
+
+```bash
+# Example: wait for step-based cadence OR 60s, whichever comes first
+NEXT_CHECK_STEP=<target_step>
+TIMEOUT=60
+START=$(date +%s)
+while true; do
+  LATEST=$(curl -s http://localhost:8421/api/metrics/history?last_n=1 | python3 -c "import sys,json; d=json.load(sys.stdin); print(d['records'][-1]['step'] if d.get('records') else 0)" 2>/dev/null || echo 0)
+  ELAPSED=$(( $(date +%s) - START ))
+  if [ "$LATEST" -ge "$NEXT_CHECK_STEP" ] || [ "$ELAPSED" -ge "$TIMEOUT" ]; then break; fi
+  sleep 3
+done
+```
+
+This ensures responsiveness even with slow training loops.
+
+### 3.7 Graduation principle
+
+Start with small changes:
+1. First intervention: lr *= 0.5
+2. If no improvement after 30 steps: lr *= 0.5 again
+3. If still no improvement: try weight_decay adjustment
+4. Only after 3+ failed interventions: consider declaring run degenerate
+
+### 3.8 Multi-run awareness
+
+If a run is truly degenerate (NaN loss, diverged beyond recovery):
1. Save learnings about what failed
+2. Suggest to user: "This run has diverged. I recommend restarting with lr={suggested_lr}."
+3. Don't keep trying to fix an unfixable run
+
+## Phase 4: Finalization
+
+When training completes (metrics stop flowing or max_steps reached):
+
+1. Summarize the run:
+   - Final key metric value
+   - Total interventions made
+   - What worked and what didn't
+2. Export recipe: `hotcb --dir <run_dir> recipe export`
+3. If training context was read, suggest improvements for next run (code-level)
+4. Suggest next steps
+
+## Key Principles
+
+1. **Conservative**: When in doubt, do nothing. Healthy training doesn't need intervention.
+2. **Graduated**: Small changes first. Never change lr by more than 50% in one step.
+3. **Observable**: Wait enough steps after an intervention to see its effect (20-50 steps minimum).
+4. **Transparent**: Always explain your reasoning to the user.
+5. **Reversible**: Prefer changes that can be undone. Don't disable callbacks unless clearly needed.
+6. **Code-aware**: When you have training context, use it. A code suggestion is sometimes better than a knob turn.
+
+## Available hotcb Commands Reference
+
+```
+# Optimizer
+hotcb --dir <run_dir> set lr=0.001
+hotcb --dir <run_dir> set weight_decay=0.01
+hotcb --dir <run_dir> opt set_params lr=0.001 weight_decay=0.01
+
+# Loss
+hotcb --dir <run_dir> loss set_params weight_a=0.7 weight_b=0.3
+
+# Callbacks
+hotcb --dir <run_dir> cb enable <callback_id>
+hotcb --dir <run_dir> cb disable <callback_id>
+
+# Status
+hotcb --dir <run_dir> status
+
+# Recipe export
+hotcb --dir <run_dir> recipe export
+```
+
+## REST API Reference
+
+```
+GET  /api/metrics/names        — list discovered metric names
+GET  /api/metrics/history      — recent metric records (?last_n=500)
+GET  /api/applied/history      — applied mutations (?last_n=200)
+GET  /api/status               — run status (freeze, files)
+GET  /api/health               — server health
+GET  /api/train/status         — training thread status
+POST /api/opt/set              — set optimizer params {"params": {"lr": 0.001}}
+POST /api/loss/set             — set loss params {"params": {"weight_a": 0.7}}
+GET  /api/autopilot/status     — autopilot state
+GET  /api/autopilot/ai/status  — AI autopilot state (cost, key_metric, etc.)
+GET  /api/autopilot/ai/history — AI decision history with reasoning
+```
diff --git a/.coverage b/.coverage
index 571558e..4f96394 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/.gitignore b/.gitignore
index ac6abfb..e6f946f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,7 +5,8 @@ dist/
 *.egg-info/
 .eggs/
 .idea/
-.claude/
+.claude/worktrees/
+.claude/setting*
 lightning_logs/
 data/*
 # (optional) If you use pip editable installs, also ignore:
diff --git a/CLAUDE.md b/CLAUDE.md
index 733bb82..4fc579d 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -47,21 +47,31 @@ The kernel and training process communicate through the filesystem (JSONL files)
 | Module | Path | Controls |
 |--------|------|----------|
 | **cb** | `modules/cb/` | Callback load/unload/enable/disable/reconfigure. Has its own controller, loader, protocol, adapters |
-| **opt** | `modules/opt.py` | Live optimizer param changes (lr, weight_decay, clip) |
-| **loss** | `modules/loss.py` | Loss weights, term toggles, ramp configs |
 | **tune** | `modules/tune/` | Online constrained HPO via Optuna (optional `hotcb[tune]`) |
+| **opt/loss/custom** | Default stream → `MutableState` | All scalar parameter control (lr, weights, custom knobs) via unified actuator system |
 
 ### Key types
 
 - **`HotOp`** (`ops.py`): Normalized operation dataclass — every command becomes one. Fields: `module`, `op`, `id`, `params`, `target`, etc.
- **`CallbackTarget`** (`ops.py`): Specifies a callback to load (kind, path, symbol). -- **`HotKernel`** (`kernel.py`): Central coordinator. Holds module instances, actuator registry, optional `metrics_collector`. Called via `kernel.apply(env=..., events=...)` each training step. +- **`HotKernel`** (`kernel.py`): Central coordinator. Holds module instances, `MutableState`, optional `metrics_collector`. Called via `kernel.apply(env=..., events=...)` each training step. Ops for `cb`/`tune` route to their modules; all others (opt/loss/custom) go through the default stream to `MutableState`. +- **`HotcbActuator`** (`actuators/actuator.py`): Single controllable parameter — 1:1 mapping (param_key ↔ actuator). Has type (BOOL/FLOAT/INT/CHOICE/LOG_FLOAT/TUPLE), `apply_fn`, bounds, state machine (INIT→UNTOUCHED→UNVERIFIED→VERIFIED→DISABLED). +- **`MutableState`** (`actuators/state.py`): Container of `HotcbActuator` instances. Provides `apply()`, `initialize()`, `verify()`, `describe_all()`. - **`FreezeState`** (`freeze.py`): Freeze mode manager (off/prod/replay/replay_adjusted). - **`RecipePlayer`** (`recipe.py`): Deterministic replay of exported recipes. ### Actuator system (`src/hotcb/actuators/`) -Protocol-based (`BaseActuator`) — optimizer and loss_state actuators register with the kernel and are auto-propagated to the tune controller. +Unified per-parameter actuator model. Convenience constructors: +- `optimizer_actuators(optimizer)` — creates lr, wd, betas actuators from a torch optimizer +- `loss_actuators(weights_dict)` — creates FLOAT actuators that mutate the original dict +- `mutable_state(actuators)` — wraps a list of `HotcbActuator` instances into a `MutableState` + +Adapters auto-discover optimizer actuators from the framework (Lightning/HF). Users register custom actuators via `mutable_state()`. + +### Dashboard config (`src/hotcb/server/config.py`) + +`DashboardConfig` centralizes all tunables (poll intervals, history limits, chart settings, UI timers). Loaded from defaults → YAML → env vars → CLI. Served at `/api/config`, fetched once by frontend into `S.config`. Controls are generated dynamically from `MutableState.describe_all()` — no hardcoded slider HTML. ### Server / Dashboard (`src/hotcb/server/`) @@ -74,6 +84,10 @@ FastAPI app (`app.py`) served via `hotcb serve`. Architecture: - **`launcher.py`**: Training launch/stop/reset from the dashboard - Static frontend: `server/static/` — vanilla JS (charts.js, controls.js, panels.js, websocket.js, state.js, init.js) +### Demos (`src/hotcb/demo.py`, `golden_demo.py`, `finetune_demo.py`) + +Synthetic training loops that use HotKernel + MetricsCollector + actuators — the same integration path as real projects. Demos use a lightweight `_OptProxy` (dict with `param_groups`) instead of a real torch optimizer. Recipe-driven changes are injected as commands to `hotcb.commands.jsonl` at scheduled steps (not freeze/replay mode), so interactive dashboard control works simultaneously. + ### Launch API (`src/hotcb/launch.py`) Programmatic API for starting training + dashboard + autopilot in one call. Returns `LaunchHandle` with methods for metrics access, live commands, and AI state inspection. Used by `hotcb launch` CLI and notebook workflows. @@ -91,6 +105,12 @@ Top-level adapters (`lightning.py`, `hf.py`) wrap HotKernel for PyTorch Lightnin Synthetic benchmarks and CIFAR-10 autopilot evaluation. `tasks.py` defines tasks, `runner.py` runs them, `report.py` generates outputs, `eval_autopilot.py` compares baseline vs autopilot. 
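+
+A minimal wiring sketch of the unified actuator model described above (constructor shapes assumed from the convenience API, not verified against the source):
+
+```python
+import torch
+from hotcb.actuators import optimizer_actuators, loss_actuators, mutable_state
+from hotcb.kernel import HotKernel
+
+# model: your torch.nn.Module, defined elsewhere
+opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
+loss_weights = {"cls": 1.0, "recon": 0.5}  # mutated in place by the loss actuators
+
+ms = mutable_state(optimizer_actuators(opt) + loss_actuators(loss_weights))
+kernel = HotKernel(run_dir="./runs/exp1", mutable_state=ms)
+# Ops with module="opt"/"loss"/"custom" route through the default stream to ms;
+# "cb" and "tune" ops still go to their own modules.
+```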
+## Multi-Agent Coordination + +`.claude/plans/STREAMS.md` is the shared roadmap for parallel Claude Code sessions. +One file, all streams. Use `/stream` to browse, attach, create, or release streams. +Claim a stream (`status → active`), update checkboxes + log as you work, release when done. + ## Conventions - **Filesystem as IPC**: Training ↔ dashboard communication is via JSONL files, not sockets or shared memory. @@ -100,3 +120,6 @@ Synthetic benchmarks and CIFAR-10 autopilot evaluation. `tasks.py` defines tasks - **Source layout**: `src/hotcb/` is the single package. All imports use `hotcb.*`. - **Autopilot modes**: `off`, `suggest`, `auto` (rule-based); `ai_suggest`, `ai_auto` (LLM-driven). Rules act as sensor layer for AI modes. - **AI autopilot uses OpenAI-compatible API**: configured via `HOTCB_AI_KEY` env var and `AIConfig`. Works with OpenAI, ollama, vLLM. + + +Always use skills /python-runtime-patterns /python-project-setup /python-dev-practices when working with this project. \ No newline at end of file diff --git a/INTEGRATION.md b/INTEGRATION.md index e104732..77ba8c5 100644 --- a/INTEGRATION.md +++ b/INTEGRATION.md @@ -80,24 +80,25 @@ for step in range(max_steps): from hotcb.kernel import HotKernel from hotcb.metrics import MetricsCollector -kernel = HotKernel(run_dir="./runs/exp1", debounce_steps=10) -mc = MetricsCollector("./runs/exp1/hotcb.metrics.jsonl") +mc = MetricsCollector(os.path.join("./runs/exp1", "hotcb.metrics.jsonl")) +kernel = HotKernel(run_dir="./runs/exp1", debounce_steps=10, metrics_collector=mc) for step, batch in enumerate(dataloader): loss = train_step(batch) - mc.log(step=step, metrics={"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) - kernel.apply( - env={ - "framework": "torch", - "phase": "train", - "step": step, - "optimizer": optimizer, - "loss_state": model.loss_state, # optional, for loss weight control - "log": print, + env = { + "framework": "torch", + "phase": "train", + "step": step, + "optimizer": optimizer, + "mutable_state": getattr(model, "mutable_state", None), # optional, for loss weight control + "metrics": { + "loss": loss.item(), + "lr": optimizer.param_groups[0]["lr"], }, - events=["train_step_end"], - ) + "log": print, + } + kernel.apply(env, events=["train_step_end"]) ``` ## Option C: Framework Adapters (Lightning / HuggingFace) @@ -159,7 +160,7 @@ If your model has a multi-task or weighted loss, expose it as a mutable dict: ```python # On your model or as a standalone dict -loss_state = { +mutable_state = { "weights": {"cls": 1.0, "recon": 0.5, "reg": 0.1}, "terms": {"cls": True, "recon": True, "reg": True}, # toggleable } @@ -262,15 +263,15 @@ Key points: ### Using HotKernel instead of manual command polling -If you prefer the full kernel (automatic command routing, loss_state support, freeze modes): +If you prefer the full kernel (automatic command routing, mutable_state support, freeze modes): ```python def train(run_dir, max_steps, step_delay, stop_event): from hotcb.kernel import HotKernel from hotcb.metrics import MetricsCollector - kernel = HotKernel(run_dir=run_dir, debounce_steps=10) mc = MetricsCollector(os.path.join(run_dir, "hotcb.metrics.jsonl")) + kernel = HotKernel(run_dir=run_dir, debounce_steps=10, metrics_collector=mc) model = build_model() optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) @@ -284,12 +285,19 @@ def train(run_dir, max_steps, step_delay, stop_event): optimizer.step() optimizer.zero_grad() - mc.log(step=step, metrics={"loss": loss.item(), "lr": 
optimizer.param_groups[0]["lr"]}) - kernel.apply( - env={"framework": "torch", "phase": "train", "step": step, - "optimizer": optimizer, "loss_state": getattr(model, "loss_state", None), "log": print}, - events=["train_step_end"], - ) + env = { + "framework": "torch", + "phase": "train", + "step": step, + "optimizer": optimizer, + "mutable_state": getattr(model, "mutable_state", None), + "metrics": { + "loss": loss.item(), + "lr": optimizer.param_groups[0]["lr"], + }, + "log": print, + } + kernel.apply(env, events=["train_step_end"]) if step_delay > 0: time.sleep(step_delay) diff --git a/MAINTENANCE.md b/MAINTENANCE.md new file mode 100644 index 0000000..03f2db7 --- /dev/null +++ b/MAINTENANCE.md @@ -0,0 +1,233 @@ +# MAINTENANCE.md — hotcb Release Readiness Audit + +Generated: 2026-03-12 + +## Priority Legend +- **P0** — Must fix before release (broken functionality) +- **P1** — Should fix (poor UX, data loss risk) +- **P2** — Nice to fix (code quality, polish) +- **P3** — Backlog (low-impact improvements) + +--- + +## 1. User-Reported Issues (P0) + +### 1.1 Claude Skill PYTHONPATH Issue +**File:** `.claude/skills/hotcb-autopilot/SKILL.md:78-79` +**Problem:** Skill uses `python3 -c "import hotcb; ..."` which fails when hotcb is installed in editable mode or not on `sys.path`. Need `PYTHONPATH=src` prefix or `python3 -m hotcb` pattern. +**Fix:** Update all `python3 -c` invocations to use `PYTHONPATH=src python3 -c` or `python3 -m` equivalents. + +### 1.2 Claude Skill Cadence Cap +**File:** `.claude/skills/hotcb-autopilot/SKILL.md:309-314` +**Problem:** Phase 3.6 cadence section has no wall-clock time limit. If training is slow (1 step/sec), "check every 100 steps" means 100 seconds of silence. Need max 60-second wall-clock cap. +**Fix:** Add wall-clock time limit to cadence rules. + +### 1.3 Tooltip Position — Show on Side, Not Top +**File:** `src/hotcb/server/static/js/charts.js:312-314` +**Problem:** Chart.js tooltip appears above the cursor, blocking view of data points. Should appear to the side. +**Fix:** Register a custom Chart.js tooltip positioner that places tooltip to the right of the cursor. + +### 1.4 Controls Don't Reflect State for Non-Demo Projects +**Files:** `src/hotcb/server/static/js/controls.js:579-627`, `src/hotcb/server/app.py:398-468` +**Problem:** `/api/train/status` returns empty config when training wasn't started via the dashboard's `TrainingLauncher`. External training (via adapters, `hotcb launch`, or direct kernel usage) writes `hotcb.run.json` and `hotcb.applied.jsonl` but the frontend only syncs from launcher status. +**Fix:** On dashboard load, call `/api/state/controls` to hydrate sliders from last applied opt/loss params and run config from `hotcb.run.json`. Add frontend init that reads this endpoint. + +### 1.5 Mutation Capsules Don't Render for Non-Demo Runs +**File:** `src/hotcb/server/static/js/panels.js` (addTimelineItem function) +**Problem:** Applied JSONL records from external training use `payload` field for params, not `params`. The capsule rendering code only checks `rec.params`. Similarly, chart annotation code in `charts.js:76` only checks `rec.params`. +**Fix:** Check both `rec.params` and `rec.payload` in timeline rendering and chart annotations. + +### 1.6 Autopilot Alert Tooltips +**File:** `src/hotcb/server/static/js/controls.js:795-845` +**Problem:** Autopilot action items show truncated `condition_met` (80 chars). No hover tooltip with full detail. 
+**Fix:** Add `title` attribute with full condition text, and show rule parameters on hover. + +--- + +## 2. Frontend Issues (P1) + +### 2.1 API Error Handling — Missing HTTP Status Check +**File:** `src/hotcb/server/static/js/utils.js:8-15` +**Problem:** `api()` function doesn't check `r.ok` before calling `r.json()`. 404/500 responses cause silent failures or JSON parse errors. +**Fix:** Add `if (!r.ok)` check, return structured error. + +### 2.2 WebSocket Reconnection — No Backoff +**File:** `src/hotcb/server/static/js/websocket.js:16-25` +**Problem:** Reconnects every 3s with no exponential backoff, no max retries, no cleanup of old WS instance. +**Fix:** Add exponential backoff (3s → 6s → 12s → 30s cap), max 20 retries, cleanup old `S.ws`. + +### 2.3 Event Listener Leaks in Metric Dropdown +**File:** `src/hotcb/server/static/js/charts.js:648-653` +**Problem:** `document.addEventListener('click', ...)` is re-registered on every `_renderMetricDropdown()` call. Old listeners reference stale `wrapper` closures. +**Fix:** Use a single delegated listener or remove old listener before adding new one. + +### 2.4 Three.js Memory Leaks +**File:** `src/hotcb/server/static/js/manifold3d.js:70-131` +**Problem:** Scene children removed but geometry/material not disposed. GPU memory accumulates. +**Fix:** Call `.geometry.dispose()` and `.material.dispose()` before removing from scene. + +### 2.5 Chart Tooltip Hardcoded Colors +**File:** `src/hotcb/server/static/js/charts.js:312-313` +**Problem:** Tooltip background/border colors are hardcoded midnight theme values. Don't update on theme switch. +**Fix:** Read from CSS variables on tooltip render, or update chart options in `setTheme()`. + +### 2.6 Forecast Polling Floods +**File:** `src/hotcb/server/static/js/charts.js:824-829` +**Problem:** `fetchAllForecasts()` spawns one request per metric name. With 50+ metrics, this is 50+ concurrent requests every 5s. +**Fix:** Batch forecast API or limit concurrent requests to 5-10. + +### 2.7 Interval Accumulation +**Files:** `src/hotcb/server/static/js/init.js:127,204`, `src/hotcb/server/static/js/controls.js:194,317` +**Problem:** `setInterval` calls not stored in variables or cleared on reset. Multiple intervals can accumulate. +**Fix:** Store interval IDs, clear on reset. + +--- + +## 3. Backend Issues (P1) + +### 3.1 Malformed JSON Crash in read_new_jsonl +**File:** `src/hotcb/util.py:58` +**Problem:** `json.loads(s)` without try-catch. Malformed JSONL lines crash the kernel's command loading. +**Fix:** Wrap in try-except, log warning, skip malformed lines. + +### 3.2 JSONL Append Race Condition +**File:** `src/hotcb/util.py:66` +**Problem:** `append_jsonl()` has no file locking. Concurrent writes from dashboard API + training thread can interleave lines. +**Fix:** Add `fcntl.flock()` around write. + +### 3.3 FreezeState Missing Mode Validation +**File:** `src/hotcb/freeze.py:22-40` +**Problem:** `FreezeState.load()` accepts any string for mode without validation. +**Fix:** Validate mode is in `{"off", "prod", "replay", "replay_adjusted"}`. + +### 3.4 Duplicate Imports +**File:** `src/hotcb/modules/cb/controller.py:7,15-16` +**Problem:** `dataclasses` and `util` imported twice. +**Fix:** Remove duplicate imports. + +### 3.5 Incomplete `all` Optional Dependencies +**File:** `pyproject.toml:38-43` +**Problem:** `all` extras missing `slack_sdk>=3.0` (notifications), `matplotlib>=3.5`, `pandas>=1.5` (bench). +**Fix:** Add missing deps to `all`. + +--- + +## 4. 
Packaging Issues (P2)
+
+### 4.1 Missing MANIFEST.in
+**Problem:** No `MANIFEST.in` for sdist. `pyproject.toml` package-data globs may not work with all setuptools versions for source distributions.
+**Fix:** Create `MANIFEST.in` with `recursive-include` for static files, guidelines, prompt YAML.
+
+### 4.2 Missing py.typed Marker
+**Problem:** No PEP 561 `py.typed` marker. Type checkers don't recognize hotcb as typed.
+**Fix:** Create `src/hotcb/py.typed` (empty file).
+
+---
+
+## 5. Accessibility Issues (P2)
+
+### 5.1 Focus Styles
+**File:** `src/hotcb/server/static/css/dashboard.css`
+**Problem:** Multiple inputs use `outline: none` without adequate replacement. Violates WCAG 2.4.7.
+**Fix:** Use `outline: 2px solid var(--accent); outline-offset: 2px;` or `:focus-visible`.
+
+### 5.2 Missing ARIA Labels
+**File:** `src/hotcb/server/static/index.html`
+**Problem:** Icon buttons (pin, close, tour) lack `aria-label`. Inputs lack associated `<label>` elements.
+**Fix:** Add `aria-label` to icon buttons and associate each input with a label.

[The rest of MAINTENANCE.md and the diff to `src/hotcb/server/static/index.html` lost their markup in extraction; only text nodes and hunk headers survive. Recoverable from the index.html diff: it adds a training status bar ("Training: -- | Autopilot: off | Step: --"), a chart empty state ("Waiting for metrics..." / "Start training or connect to a running experiment"), an "Experiment Comparison" panel with a "Select Runs to Compare" modal, and extra controls on the "Training Health", "Command Bar", and "Alerts" ("No alerts") panels; it removes the hardcoded knob rows (hunk `@@ -332,55 +376,7 @@`) for `lr` ("Learning Rate controls how big each step is. Too high = unstable training. Too low = slow progress."), `wd` ("Weight Decay prevents memorization. Higher = more regularization, less overfitting risk."), and `loss_w` ("Loss Weight controls how much this component matters relative to others.") in favor of a dynamically built container; and it updates the demo description to "Two-task training with recipe-driven loss weight shifts."]
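A sketch of the §3.2 fix (JSONL append locking), assuming `append_jsonl` keeps its current one-record-per-call shape — the real signature lives in `src/hotcb/util.py`:

```python
import fcntl  # POSIX-only; a Windows port would need msvcrt.locking instead
import json

def append_jsonl(path, record):
    # Hold an exclusive lock for the whole write so the dashboard API and the
    # training thread cannot interleave partial lines in the same JSONL file.
    line = json.dumps(record) + "\n"
    with open(path, "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            f.write(line)
            f.flush()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
```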
diff --git a/src/hotcb/server/static/js/charts.js b/src/hotcb/server/static/js/charts.js index 8f01f11..34804ca 100644 --- a/src/hotcb/server/static/js/charts.js +++ b/src/hotcb/server/static/js/charts.js @@ -9,6 +9,51 @@ var _forecastCache = {}; // metric -> {forecast, mutation} // Highlighted mutation step (set when user clicks a timeline item) var _highlightedMutationStep = null; +// Step range control: 'all' | 'last200' | 'last500' | {min, max} +var _chartStepRange = 'all'; + +// Y-axis normalization: when enabled, each metric is normalized to [0,1] +// Auto-detected on first data: true if metrics have divergent scales +var _chartNormalize = false; +var _chartNormalizeAuto = true; // auto-detection active until user manually toggles + +// Max points to render per dataset (avoids sluggish charts on very long runs) +var _maxRenderPoints = (S.config && S.config.chart) ? S.config.chart.max_render_points : 2000; + +/** + * LTTB (Largest-Triangle-Three-Buckets) downsampling. + * Takes [{x, y}] and returns a reduced array preserving visual shape. + */ +function _lttbDownsample(data, threshold) { + if (data.length <= threshold) return data; + var sampled = [data[0]]; // always keep first + var bucketSize = (data.length - 2) / (threshold - 2); + var a = 0; // index of previously selected point + for (var i = 0; i < threshold - 2; i++) { + // Calculate bucket range + var bStart = Math.floor((i + 1) * bucketSize) + 1; + var bEnd = Math.min(Math.floor((i + 2) * bucketSize) + 1, data.length - 1); + // Average of next bucket (for triangle area calc) + var avgX = 0, avgY = 0, cnt = 0; + var nbStart = Math.floor((i + 2) * bucketSize) + 1; + var nbEnd = Math.min(Math.floor((i + 3) * bucketSize) + 1, data.length - 1); + if (nbStart > data.length - 1) { nbStart = data.length - 1; nbEnd = data.length - 1; } + for (var j = nbStart; j <= nbEnd; j++) { avgX += data[j].x; avgY += data[j].y; cnt++; } + avgX /= cnt; avgY /= cnt; + // Pick point in current bucket with largest triangle area + var maxArea = -1, maxIdx = bStart; + var ax = data[a].x, ay = data[a].y; + for (var k = bStart; k <= bEnd; k++) { + var area = Math.abs((ax - avgX) * (data[k].y - ay) - (ax - data[k].x) * (avgY - ay)); + if (area > maxArea) { maxArea = area; maxIdx = k; } + } + sampled.push(data[maxIdx]); + a = maxIdx; + } + sampled.push(data[data.length - 1]); // always keep last + return sampled; +} + // ---- Mutation annotation plugin for Chart.js ---- var mutationAnnotationPlugin = { id: 'mutationAnnotations', @@ -21,12 +66,19 @@ var mutationAnnotationPlugin = { var top = yScale.top; var bottom = yScale.bottom; + // Determine actual data range — skip annotations outside metric data + var dataMinStep = _getMinStep(); + var dataMaxStep = _getMaxStep(); + if (dataMaxStep === 0) return; // no metric data yet + // Collect visible annotations with pixel positions for staggering var annotations = []; S.appliedData.forEach(function(rec) { var step = rec.step; if (step === undefined || step === null) return; + // Filter to both visible x-axis range AND actual data range if (step < xScale.min || step > xScale.max) return; + if (step < dataMinStep || step > dataMaxStep) return; var x = xScale.getPixelForValue(step); annotations.push({rec: rec, x: x, step: step}); }); @@ -73,12 +125,16 @@ var mutationAnnotationPlugin = { // Build compact label — split into multiple lines if needed var lines = []; - if (rec.params && typeof rec.params === 'object') { - var keys = Object.keys(rec.params); + var annotParams = (rec.params && typeof rec.params === 
'object') ? rec.params : + (rec.payload && typeof rec.payload === 'object') ? rec.payload : null; + if (annotParams) { + var keys = Object.keys(annotParams); keys.slice(0, 3).forEach(function(k) { - var v = rec.params[k]; + var v = annotParams[k]; if (typeof v === 'number') { v = v < 0.01 || v > 1e4 ? v.toExponential(1) : parseFloat(v.toPrecision(3)); + } else if (typeof v === 'object' && v !== null) { + v = JSON.stringify(v); } lines.push(k + '\u2192' + v); }); @@ -128,6 +184,14 @@ var mutationAnnotationPlugin = { // Register the plugin globally Chart.register(mutationAnnotationPlugin); +// Custom tooltip positioner — show to the right of the cursor +Chart.Tooltip.positioners.rightOfCursor = function(elements, eventPosition) { + return { + x: eventPosition.x + 15, + y: eventPosition.y + }; +}; + // ---- Linear regression slope helper ---- function _linregSlope(points) { // points: [{step, value}, ...] — returns slope (value per step) @@ -285,13 +349,14 @@ function scrollChartToStep(step) { var xScale = S.chartInstance.scales.x; if (!xScale) return; var range = xScale.max - xScale.min; - if (range <= 0) return; + if (range <= 0) range = 200; // Only adjust if the step is outside the visible range if (step >= xScale.min && step <= xScale.max) return; var half = range / 2; - S.chartInstance.options.scales.x.min = step - half; - S.chartInstance.options.scales.x.max = step + half; + _chartStepRange = { min: Math.max(0, step - half), max: step + half }; + _applyChartStepRange(); S.chartInstance.update('none'); + _updateRangeButtons(); } function createMetricsChart() { @@ -309,16 +374,101 @@ function createMetricsChart() { }, plugins: { legend: {display:false}, - tooltip: { backgroundColor:'#121c2b', borderColor:'#2a4060', borderWidth:1, - titleFont:{family:'JetBrains Mono',size:11}, bodyFont:{family:'JetBrains Mono',size:11} } + tooltip: { backgroundColor: getComputedStyle(document.documentElement).getPropertyValue('--bg-card').trim() || '#121c2b', + borderColor: getComputedStyle(document.documentElement).getPropertyValue('--border-bright').trim() || '#2a4060', borderWidth:1, + titleFont:{family:'JetBrains Mono',size:11}, bodyFont:{family:'JetBrains Mono',size:11}, + usePointStyle: false, boxWidth: 12, boxHeight: 2, + intersect: false, mode: 'index', axis: 'x', position: 'rightOfCursor', + filter: function(item) { + // Hide internal datasets (confidence bands, etc.) + var label = item.dataset.label || ''; + return !label.startsWith('_'); + }, + itemSort: function(a, b) { + // Sort tooltip items by Y-pixel distance to cursor (TensorBoard-style) + var chart = a.chart; + var cursorY = 0; + if (chart && chart._lastEvent) cursorY = chart._lastEvent.y; + var ay = a.element ? a.element.y : 0; + var by = b.element ? 
b.element.y : 0; + return Math.abs(ay - cursorY) - Math.abs(by - cursorY); + }, + callbacks: { + label: function(ctx) { + var ds = ctx.dataset; + var label = ds.label || ''; + if (label.startsWith('_')) return null; + var rawY = ctx.parsed.y; + // Color swatch via dataset borderColor + var color = ds.borderColor || '#fff'; + var prefix = ''; + // When normalized, show both normalized and raw values + if (_chartNormalize) { + var metricName = label.replace(/ forecast$/, '').replace(/ post-change$/, ''); + var rawPts = S.metricsData[metricName]; + if (rawPts && rawPts.length > 0) { + var step = ctx.parsed.x; + var rawVal = null; + for (var i = 0; i < rawPts.length; i++) { + if (rawPts[i].step >= step) { rawVal = rawPts[i].value; break; } + rawVal = rawPts[i].value; + } + if (rawVal !== null) { + return ' ' + label + ': ' + fmtNum(rawVal); + } + } + } + return ' ' + label + ': ' + fmtNum(rawY); + } + } + } }, - elements: { point: {radius:0, hoverRadius:3}, line: {borderWidth:1.5, tension:0.3} } + interaction: { mode: 'index', axis: 'x', intersect: false }, + elements: { point: {radius:0, hoverRadius:5, hitRadius:10}, line: {borderWidth:1.5, tension: (S.config && S.config.chart) ? S.config.chart.line_tension : 0.15} } } }); } function updateChart() { if (!S.chartInstance) return; + + // Auto-detect normalization need: if enabled metrics span > 100x range, auto-enable + if (_chartNormalizeAuto) { + var globalMin = Infinity, globalMax = -Infinity; + var metricRanges = []; + S.metricNames.forEach(function(name) { + if (!_metricToggleState[name]) return; + var pts = S.metricsData[name] || []; + if (pts.length < 2) return; + var mn = Infinity, mx = -Infinity; + for (var i = Math.max(0, pts.length - 100); i < pts.length; i++) { + if (pts[i].value < mn) mn = pts[i].value; + if (pts[i].value > mx) mx = pts[i].value; + } + metricRanges.push({min: mn, max: mx}); + if (mn < globalMin) globalMin = mn; + if (mx > globalMax) globalMax = mx; + }); + if (metricRanges.length >= 2) { + // Check if the max range across metrics is > 100x the min range + var spans = metricRanges.map(function(r) { return Math.abs(r.max - r.min) || 1e-10; }); + var maxSpan = Math.max.apply(null, spans); + var minSpan = Math.min.apply(null, spans); + var needsNorm = (maxSpan / minSpan > 50) || (Math.abs(globalMax - globalMin) > 0 && ( + metricRanges.some(function(r) { + var mid = (r.max + r.min) / 2; + var halfRange = (globalMax - globalMin) / 2; + return halfRange > 0 && Math.abs(r.max - r.min) / halfRange < 0.01; + }) + )); + if (needsNorm !== _chartNormalize) { + _chartNormalize = needsNorm; + var normBtn = document.getElementById('btnNormalize'); + if (normBtn) normBtn.classList.toggle('btn-accent', _chartNormalize); + } + } + } + var datasets = []; var enabled = new Set(); S.metricNames.forEach(function(name) { @@ -335,83 +485,223 @@ function updateChart() { var pts = S.metricsData[name] || []; var color = getColor(name); - // Live data line + // Compute per-metric min/max for normalization + var metricMin = Infinity, metricMax = -Infinity; + if (_chartNormalize && pts.length > 0) { + pts.forEach(function(p) { + if (p.value < metricMin) metricMin = p.value; + if (p.value > metricMax) metricMax = p.value; + }); + if (metricMax === metricMin) { metricMin -= 0.5; metricMax += 0.5; } + } + var _normFn = (_chartNormalize && metricMax !== metricMin) + ? 
function(v) { return (v - metricMin) / (metricMax - metricMin); } + : function(v) { return v; }; + + // Live data line — sort by step to prevent backward line connectors + // (validation records can appear between training steps with lower step numbers) + var chartPts = pts.map(function(p) { return {x: p.step, y: _normFn(p.value)}; }); + chartPts.sort(function(a, b) { return a.x - b.x; }); + chartPts = _lttbDownsample(chartPts, _maxRenderPoints); + var _cfgTension = (S.config && S.config.chart) ? S.config.chart.line_tension : 0.15; + var _cfgForecastDash = (S.config && S.config.chart) ? S.config.chart.forecast_dash : [6, 3]; + var _cfgMutationDash = (S.config && S.config.chart) ? S.config.chart.mutation_dash : [3, 4]; + datasets.push({ label: name, - data: pts.map(function(p) { return {x: p.step, y: p.value}; }), + data: chartPts, borderColor: color, backgroundColor: 'transparent', - tension: 0.3, + tension: _cfgTension, }); var lastStep = pts.length ? pts[pts.length - 1].step : 0; var lastVal = pts.length ? pts[pts.length - 1].value : null; + var lastValNorm = lastVal !== null ? _normFn(lastVal) : null; var cache = _forecastCache[name]; + var showOverlays = S.focusMetric === name || (S.pinnedMetrics && S.pinnedMetrics.has(name)); - // Forecast overlay (dotted extension in same color but lighter) - if (cache && cache.forecast && cache.forecast.values && cache.forecast.values.length) { + // Forecast overlay — only shown for pinned or focused metrics to reduce clutter + // Guard: only render if forecast steps are contiguous with current data (avoids stale cross-run connectors) + if (showOverlays && cache && cache.forecast && cache.forecast.values && cache.forecast.values.length + && lastStep > 0 && lastVal !== null + && cache.forecast.steps && cache.forecast.steps[0] <= lastStep + 50) { var fc = cache.forecast; - var fcPts = [{x: lastStep, y: lastVal}]; // Connect to last actual point - fc.values.forEach(function(v, i) { fcPts.push({x: fc.steps[i], y: v}); }); + var fcPts = [{x: lastStep, y: lastValNorm}]; // Connect to last actual point + fc.values.forEach(function(v, i) { fcPts.push({x: fc.steps[i], y: _normFn(v)}); }); datasets.push({ label: name + ' forecast', data: fcPts, borderColor: color, - borderDash: [6, 3], + borderDash: _cfgForecastDash, backgroundColor: 'transparent', - tension: 0.3, + tension: _cfgTension, borderWidth: 1.2, pointRadius: 0, }); // Confidence band if (fc.lower && fc.upper) { - var loPts = [{x: lastStep, y: lastVal}]; - var hiPts = [{x: lastStep, y: lastVal}]; - fc.lower.forEach(function(v, i) { loPts.push({x: fc.steps[i], y: v}); }); - fc.upper.forEach(function(v, i) { hiPts.push({x: fc.steps[i], y: v}); }); + var loPts = [{x: lastStep, y: lastValNorm}]; + var hiPts = [{x: lastStep, y: lastValNorm}]; + fc.lower.forEach(function(v, i) { loPts.push({x: fc.steps[i], y: _normFn(v)}); }); + fc.upper.forEach(function(v, i) { hiPts.push({x: fc.steps[i], y: _normFn(v)}); }); datasets.push({ label: '_' + name + '_lo', data: loPts, borderColor: 'transparent', backgroundColor: hexToRgba(color, 0.06), fill: '+1', - pointRadius: 0, tension: 0.3, + pointRadius: 0, tension: _cfgTension, }); datasets.push({ label: '_' + name + '_hi', data: hiPts, borderColor: 'transparent', backgroundColor: 'transparent', - pointRadius: 0, tension: 0.3, + pointRadius: 0, tension: _cfgTension, }); } } - // Mutation impact overlay (cyan dotted) - if (cache && cache.mutation && cache.mutation.values && cache.mutation.values.length) { + // Mutation impact overlay — only shown for pinned or focused 
metrics + // Guard: only render if fromStep is within current data range (avoids stale cross-run connectors) + if (showOverlays && cache && cache.mutation && cache.mutation.values && cache.mutation.values.length) { var mu = cache.mutation; - var muPts = [{x: mu.fromStep, y: mu.fromVal}]; - mu.values.forEach(function(v, i) { muPts.push({x: mu.steps[i], y: v}); }); - datasets.push({ - label: name + ' post-change', - data: muPts, - borderColor: 'rgba(51,204,221,0.7)', - borderDash: [3, 4], - backgroundColor: 'transparent', - tension: 0.3, - borderWidth: 1.2, - pointRadius: 0, - }); + var inDataRange = pts.length > 0 && mu.fromStep >= pts[0].step && mu.fromStep <= pts[pts.length - 1].step; + if (inDataRange) { + var muPts = [{x: mu.fromStep, y: _normFn(mu.fromVal)}]; + mu.values.forEach(function(v, i) { muPts.push({x: mu.steps[i], y: _normFn(v)}); }); + datasets.push({ + label: name + ' post-change', + data: muPts, + borderColor: 'rgba(51,204,221,0.7)', + borderDash: _cfgMutationDash, + backgroundColor: 'transparent', + tension: _cfgTension, + borderWidth: 1.2, + pointRadius: 0, + }); + } } }); S.chartInstance.data.datasets = datasets; + + // Apply step range to x-axis + _applyChartStepRange(); + + // Update Y-axis title for normalization mode + var yOpts = S.chartInstance.options.scales.y; + if (_chartNormalize) { + yOpts.title = {display: true, text: 'Normalized [0,1]', color: '#7a8fa3', font: {size: 10}}; + } else { + yOpts.title = {display: false}; + } + S.chartInstance.update('none'); // Also update any pinned metric cards updateMetricCards(); } +function _getMaxStep() { + var maxStep = 0; + S.metricNames.forEach(function(name) { + var pts = S.metricsData[name] || []; + if (pts.length) maxStep = Math.max(maxStep, pts[pts.length - 1].step); + }); + return maxStep; +} + +function _getMinStep() { + var minStep = Infinity; + S.metricNames.forEach(function(name) { + var pts = S.metricsData[name] || []; + if (pts.length) minStep = Math.min(minStep, pts[0].step); + }); + return minStep === Infinity ? 0 : minStep; +} + +function _applyChartStepRange() { + if (!S.chartInstance) return; + var xOpts = S.chartInstance.options.scales.x; + if (_chartStepRange === 'all') { + delete xOpts.min; + delete xOpts.max; + } else if (_chartStepRange === 'last200') { + var mx = _getMaxStep(); + xOpts.min = Math.max(0, mx - 200); + delete xOpts.max; + } else if (_chartStepRange === 'last500') { + var mx2 = _getMaxStep(); + xOpts.min = Math.max(0, mx2 - 500); + delete xOpts.max; + } else if (typeof _chartStepRange === 'object' && _chartStepRange !== null) { + xOpts.min = _chartStepRange.min; + xOpts.max = _chartStepRange.max; + } +} + +function setChartStepRange(mode) { + _chartStepRange = mode; + updateChart(); + _updateRangeButtons(); +} + +function _updateRangeButtons() { + var btns = document.querySelectorAll('#stepRangeControls .range-btn[data-range]'); + btns.forEach(function(btn) { + var isActive = (typeof _chartStepRange === 'string' && btn.dataset.range === _chartStepRange); + btn.classList.toggle('active', isActive); + }); + // Update custom inputs if range is a custom object + var minEl = document.getElementById('rangeMin'); + var maxEl = document.getElementById('rangeMax'); + if (minEl && maxEl) { + if (typeof _chartStepRange === 'object' && _chartStepRange !== null) { + minEl.value = _chartStepRange.min != null ? _chartStepRange.min : ''; + maxEl.value = _chartStepRange.max != null ? 
_chartStepRange.max : ''; + } + } +} + +function initStepRangeControls() { + var container = document.getElementById('stepRangeControls'); + if (!container) return; + + // Preset buttons + container.querySelectorAll('.range-btn[data-range]').forEach(function(btn) { + btn.addEventListener('click', function() { + setChartStepRange(btn.dataset.range); + }); + }); + + // Custom range apply + var applyBtn = document.getElementById('rangeApply'); + var minEl = document.getElementById('rangeMin'); + var maxEl = document.getElementById('rangeMax'); + if (applyBtn && minEl && maxEl) { + applyBtn.addEventListener('click', function() { + var mn = minEl.value !== '' ? parseInt(minEl.value, 10) : undefined; + var mx = maxEl.value !== '' ? parseInt(maxEl.value, 10) : undefined; + if (mn === undefined && mx === undefined) { + setChartStepRange('all'); + } else { + setChartStepRange({ + min: mn !== undefined ? mn : 0, + max: mx !== undefined ? mx : undefined + }); + } + }); + // Enter key in inputs triggers apply + [minEl, maxEl].forEach(function(el) { + el.addEventListener('keydown', function(e) { + if (e.key === 'Enter') applyBtn.click(); + }); + }); + } +} + function hexToRgba(hex, alpha) { var r = parseInt(hex.slice(1, 3), 16); var g = parseInt(hex.slice(3, 5), 16); @@ -421,6 +711,7 @@ function hexToRgba(hex, alpha) { var _metricToggleState = {}; // metric name -> boolean (checked) var _metricDropdownShowAll = false; +var _dropdownCloseHandler = null; var _commonMetricPatterns = [ 'train_loss', 'val_loss', 'loss', 'accuracy', 'val_accuracy', 'train_accuracy', 'lr', 'learning_rate', 'grad_norm', 'grad_norm_total', @@ -428,6 +719,9 @@ var _commonMetricPatterns = [ 'precision', 'recall', 'auc', 'val_auc', 'bleu' ]; +// Metrics shown by default on chart — losses and key metric only +var _defaultOnPatterns = ['loss', 'val_loss', 'train_loss']; + function _isCommonMetric(name) { var lower = name.toLowerCase(); for (var i = 0; i < _commonMetricPatterns.length; i++) { @@ -468,7 +762,16 @@ function updateMetricToggles() { S.metricNames.forEach(function(name) { currentCount++; if (!(name in _metricToggleState)) { - _metricToggleState[name] = true; // default checked + // Default: only show losses and key metric; others off + var lower = name.toLowerCase(); + var isDefault = false; + for (var di = 0; di < _defaultOnPatterns.length; di++) { + if (lower === _defaultOnPatterns[di] || lower.indexOf(_defaultOnPatterns[di]) !== -1) { isDefault = true; break; } + } + // Also check AI key metric + var aiKeyMetricEl = document.getElementById('aiKeyMetric'); + if (aiKeyMetricEl && aiKeyMetricEl.value === name) isDefault = true; + _metricToggleState[name] = isDefault; } }); @@ -583,49 +886,70 @@ function _renderMetricDropdown(container) { var visible = _getVisibleMetrics(); visible.forEach(function(name) { var color = getColor(name); - var row = document.createElement('label'); + var row = document.createElement('div'); row.className = 'metric-dropdown-item'; - var cb = document.createElement('input'); - cb.type = 'checkbox'; - cb.checked = !!_metricToggleState[name]; - cb.dataset.metric = name; - cb.addEventListener('change', function(e) { + var isActive = !!_metricToggleState[name]; + var isPinned = S.pinnedMetrics && S.pinnedMetrics.has(name); + + // Filled/hollow dot toggle (replaces checkbox + swatch) + var dot = document.createElement('span'); + dot.className = 'metric-dot' + (isActive ? ' active' : '') + (isPinned ? 
' pinned' : ''); + dot.style.color = color; + dot.style.borderColor = color; + if (isActive) dot.style.background = color; + dot.title = isActive ? 'Hide metric' : 'Show metric'; + dot.addEventListener('click', function(e) { e.stopPropagation(); - _metricToggleState[name] = cb.checked; - // Update the badge count + _metricToggleState[name] = !_metricToggleState[name]; var cnt = 0; S.metricNames.forEach(function(n) { if (_metricToggleState[n]) cnt++; }); var badge = wrapper.querySelector('.metric-count-badge'); if (badge) badge.textContent = cnt + '/' + totalCount; + _renderMetricDropdown(container); updateChart(); }); - - var swatch = document.createElement('span'); - swatch.className = 'swatch'; - swatch.style.background = color; + // Double-click or right-click to toggle pin + dot.addEventListener('dblclick', function(e) { + e.stopPropagation(); e.preventDefault(); + toggleMetricCard(name); + _renderMetricDropdown(container); + }); + dot.addEventListener('contextmenu', function(e) { + e.preventDefault(); e.stopPropagation(); + toggleMetricCard(name); + _renderMetricDropdown(container); + }); var label = document.createElement('span'); label.className = 'metric-dropdown-name'; label.textContent = name; + label.addEventListener('click', function(e) { + e.stopPropagation(); + _metricToggleState[name] = !_metricToggleState[name]; + var cnt = 0; + S.metricNames.forEach(function(n) { if (_metricToggleState[n]) cnt++; }); + var badge = wrapper.querySelector('.metric-count-badge'); + if (badge) badge.textContent = cnt + '/' + totalCount; + _renderMetricDropdown(container); + updateChart(); + }); - var pinBtn = document.createElement('button'); - var isPinned = S.pinnedMetrics && S.pinnedMetrics.has(name); - pinBtn.className = 'pin-btn' + (isPinned ? ' pin-btn-active' : ''); - pinBtn.dataset.metric = name; - pinBtn.title = isPinned ? 'Unpin metric card' : 'Pin metric card'; - pinBtn.innerHTML = '📌'; - pinBtn.addEventListener('click', function(e) { - e.preventDefault(); + row.appendChild(dot); + row.appendChild(label); + + // Pin button — always visible, toggles pinned state + var pinIcon = document.createElement('span'); + pinIcon.className = 'metric-pin-icon' + (isPinned ? ' pinned' : ''); + pinIcon.textContent = '\u{1F4CC}'; + pinIcon.title = isPinned ? 
'Unpin metric' : 'Pin metric (opens mini card)'; + pinIcon.addEventListener('click', function(e) { e.stopPropagation(); toggleMetricCard(name); _renderMetricDropdown(container); }); + row.appendChild(pinIcon); - row.appendChild(cb); - row.appendChild(swatch); - row.appendChild(label); - row.appendChild(pinBtn); list.appendChild(row); }); @@ -642,13 +966,19 @@ function _renderMetricDropdown(container) { panel.appendChild(list); wrapper.appendChild(panel); - // Close dropdown when clicking outside - document.addEventListener('click', function closeDropdown(e) { - if (!wrapper.contains(e.target)) { - var p = wrapper.querySelector('.metric-dropdown-panel'); - if (p) p.classList.remove('open'); - } - }); + // Close dropdown when clicking outside (single delegated listener) + if (!_dropdownCloseHandler) { + _dropdownCloseHandler = function(e) { + var wraps = document.querySelectorAll('.metric-dropdown-wrap'); + wraps.forEach(function(w) { + if (!w.contains(e.target)) { + var p = w.querySelector('.metric-dropdown-panel'); + if (p) p.classList.remove('open'); + } + }); + }; + document.addEventListener('click', _dropdownCloseHandler); + } } // ---- Per-metric pinnable cards ---- @@ -746,12 +1076,18 @@ function updateSingleMetricCard(name) { var color = getColor(name); var pts = S.metricsData[name] || []; + var _cfgTension = (S.config && S.config.chart) ? S.config.chart.line_tension : 0.15; + var _cfgForecastDash = (S.config && S.config.chart) ? S.config.chart.forecast_dash : [6, 3]; + var _cfgMutationDash = (S.config && S.config.chart) ? S.config.chart.mutation_dash : [3, 4]; + + var chartPts = pts.map(function(p) { return {x: p.step, y: p.value}; }); + chartPts = _lttbDownsample(chartPts, 500); // mini cards need fewer points datasets.push({ label: name, - data: pts.map(function(p) { return {x: p.step, y: p.value}; }), + data: chartPts, borderColor: color, backgroundColor: 'transparent', - tension: 0.3, + tension: _cfgTension, }); var lastStep = pts.length ? pts[pts.length - 1].step : 0; @@ -766,8 +1102,8 @@ function updateSingleMetricCard(name) { datasets.push({ label: 'forecast', data: fcPts, - borderColor: color, borderDash: [6, 3], - backgroundColor: 'transparent', tension: 0.3, borderWidth: 1.2, pointRadius: 0, + borderColor: color, borderDash: _cfgForecastDash, + backgroundColor: 'transparent', tension: _cfgTension, borderWidth: 1.2, pointRadius: 0, }); } @@ -779,8 +1115,8 @@ function updateSingleMetricCard(name) { datasets.push({ label: 'post-change', data: muPts, - borderColor: 'rgba(51,204,221,0.7)', borderDash: [3, 4], - backgroundColor: 'transparent', tension: 0.3, borderWidth: 1.2, pointRadius: 0, + borderColor: 'rgba(51,204,221,0.7)', borderDash: _cfgMutationDash, + backgroundColor: 'transparent', tension: _cfgTension, borderWidth: 1.2, pointRadius: 0, }); } @@ -819,17 +1155,31 @@ async function fetchAllForecasts() { return; } _forecastInFlight = true; - var promises = []; - S.metricNames.forEach(function(name) { - promises.push(fetchForecastForMetric(name)); - }); - await Promise.all(promises); + // Only fetch forecasts for pinned or focused metrics — not all metrics + var names = []; + if (S.focusMetric) { + names.push(S.focusMetric); + } + if (S.pinnedMetrics) { + S.pinnedMetrics.forEach(function(name) { + if (names.indexOf(name) === -1) names.push(name); + }); + } + if (names.length === 0) { + _forecastInFlight = false; + return; + } + var batchSize = (S.config && S.config.ui) ? 
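`_lttbDownsample` itself is defined elsewhere in charts.js and doesn't appear in this hunk. For review purposes, a minimal sketch of the largest-triangle-three-buckets algorithm the name refers to, assuming `{x, y}` points and skipping LTTB's usual micro-optimizations:

```js
// Minimal LTTB (largest-triangle-three-buckets) downsampler for {x, y} points.
// Keeps the first and last points; from each bucket keeps the point forming
// the largest triangle with the last kept point and the next bucket's mean.
function lttb(points, threshold) {
  var n = points.length;
  if (threshold < 3 || threshold >= n) return points.slice();
  var bucketSize = (n - 2) / (threshold - 2);
  var sampled = [points[0]];
  var a = 0; // index of the most recently kept point
  for (var i = 0; i < threshold - 2; i++) {
    var start = Math.floor(i * bucketSize) + 1;
    var end = Math.floor((i + 1) * bucketSize) + 1;
    // Mean of the following bucket is the third triangle vertex.
    var nEnd = Math.min(Math.floor((i + 2) * bucketSize) + 1, n);
    var avgX = 0, avgY = 0, cnt = nEnd - end;
    for (var j = end; j < nEnd; j++) { avgX += points[j].x; avgY += points[j].y; }
    avgX /= cnt; avgY /= cnt;
    var best = start, maxArea = -1;
    for (var k = start; k < end; k++) {
      var area = Math.abs((points[a].x - avgX) * (points[k].y - points[a].y) -
                          (points[a].x - points[k].x) * (avgY - points[a].y));
      if (area > maxArea) { maxArea = area; best = k; }
    }
    sampled.push(points[best]);
    a = best;
  }
  sampled.push(points[n - 1]);
  return sampled;
}
```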
S.config.ui.forecast_batch_size : 8; + for (var i = 0; i < names.length; i += batchSize) { + var batch = names.slice(i, i + batchSize); + await Promise.all(batch.map(fetchForecastForMetric)); + } _forecastInFlight = false; updateChart(); // If new data arrived while we were fetching, refresh again if (_forecastPendingRefresh) { _forecastPendingRefresh = false; - setTimeout(fetchAllForecasts, 500); + setTimeout(fetchAllForecasts, 1000); } } @@ -842,17 +1192,31 @@ async function fetchForecastForMetric(name) { function startForecastPolling() { if (_forecastTimer) clearInterval(_forecastTimer); - // Poll every 5s as a baseline, but also triggered on new metrics - _forecastTimer = setInterval(fetchAllForecasts, 5000); - fetchAllForecasts(); + var _forecastPollMs = (S.config && S.config.ui) ? S.config.ui.forecast_poll_interval : 5000; + // Poll at configured interval — new metrics also trigger refresh via onNewMetricsForForecast + _forecastTimer = setInterval(function() { + // Only fetch if we have pinned/focused metrics to avoid wasted work + if ((S.pinnedMetrics && S.pinnedMetrics.size > 0) || S.focusMetric) { + fetchAllForecasts(); + } + }, _forecastPollMs); + // Delay initial fetch to let the dashboard settle + setTimeout(function() { + if ((S.pinnedMetrics && S.pinnedMetrics.size > 0) || S.focusMetric) { + fetchAllForecasts(); + } + }, 3000); } // Called from websocket when new metrics arrive — trigger forecast refresh function onNewMetricsForForecast(maxStep) { - // Only trigger every 10 steps to avoid flooding - if (maxStep - _lastForecastStep >= 10) { + // Only trigger every N steps (from config), and only if there are pinned/focused metrics + var _forecastCadence = (S.config && S.config.ui) ? S.config.ui.forecast_step_cadence : 10; + if (maxStep - _lastForecastStep >= _forecastCadence) { _lastForecastStep = maxStep; - fetchAllForecasts(); + if ((S.pinnedMetrics && S.pinnedMetrics.size > 0) || S.focusMetric) { + fetchAllForecasts(); + } } } diff --git a/src/hotcb/server/static/js/controls.js b/src/hotcb/server/static/js/controls.js index 825ad68..324b27b 100644 --- a/src/hotcb/server/static/js/controls.js +++ b/src/hotcb/server/static/js/controls.js @@ -1,5 +1,8 @@ /** * hotcb dashboard — knobs, freeze, autopilot + * + * Controls are generated dynamically from /api/config controls field + * (populated by MutableState.describe_all() on the server). */ function lrFromSlider(v) { return Math.pow(10, parseFloat(v)); } @@ -10,6 +13,237 @@ var _applyQueue = []; var _lastApplyStep = 0; var _healthEma = null; // EMA-smoothed health score +// Track last-applied slider values for staged-change highlighting +var _appliedKnobs = {}; + +/* ================================================================ */ +/* Dynamic control generation from actuator metadata */ +/* ================================================================ */ + +function buildControls(controlSpecs) { + var panel = document.getElementById('knobPanel'); + if (!panel) return; + if (!controlSpecs || !controlSpecs.length) { + panel.innerHTML = '
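The forecast loop above is the standard chunked `Promise.all` pattern: at most `forecast_batch_size` requests in flight, batches strictly sequential. Extracted as a generic sketch (the helper name is illustrative):

```js
// Run an async worker over items with at most `size` requests in flight at once.
async function inBatches(items, size, worker) {
  for (var i = 0; i < items.length; i += size) {
    await Promise.all(items.slice(i, i + size).map(worker));
  }
}
// Mirrors the patch: await inBatches(names, batchSize, fetchForecastForMetric);
```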
No controls available
'; + return; + } + panel.innerHTML = ''; + + // Group by group field + var groups = {}; + var groupOrder = []; + controlSpecs.forEach(function(spec) { + var g = spec.group || 'other'; + if (!groups[g]) { + groups[g] = []; + groupOrder.push(g); + } + groups[g].push(spec); + }); + + // Render each group + groupOrder.forEach(function(groupName) { + var header = document.createElement('div'); + header.className = 'knob-group-header'; + header.textContent = groupName.charAt(0).toUpperCase() + groupName.slice(1); + panel.appendChild(header); + + groups[groupName].forEach(function(spec) { + panel.appendChild(buildKnobRow(spec)); + }); + }); + + // Wire up input event listeners for staged-change highlighting + _wireKnobInputListeners(); +} + +function buildKnobRow(spec) { + // spec = {param_key, type, label, group, min, max, step, log_base, choices, current, state} + var row = document.createElement('div'); + row.className = 'knob-row'; + row.dataset.param = spec.param_key; + + var label = document.createElement('span'); + label.className = 'knob-label'; + label.textContent = spec.label || spec.param_key; + row.appendChild(label); + + if (spec.type === 'bool') { + // Toggle switch + var toggle = document.createElement('input'); + toggle.type = 'checkbox'; + toggle.className = 'knob-toggle'; + toggle.checked = !!spec.current; + toggle.dataset.param = spec.param_key; + row.appendChild(toggle); + } else if (spec.type === 'choice') { + // Dropdown + var select = document.createElement('select'); + select.className = 'knob-select'; + select.dataset.param = spec.param_key; + (spec.choices || []).forEach(function(c) { + var opt = document.createElement('option'); + opt.value = c; + opt.textContent = c; + if (c === spec.current) opt.selected = true; + select.appendChild(opt); + }); + row.appendChild(select); + } else if (spec.type === 'log_float') { + // Log-scale slider + var logMin = Math.log10(spec.min || 1e-7); + var logMax = Math.log10(spec.max || 1.0); + var logCurrent = (spec.current && spec.current > 0) ? Math.log10(spec.current) : logMin; + + var slider = document.createElement('input'); + slider.type = 'range'; + slider.className = 'knob-slider'; + slider.min = logMin; + slider.max = logMax; + slider.step = spec.step || 0.01; + slider.value = logCurrent; + slider.dataset.param = spec.param_key; + slider.dataset.logScale = 'true'; + row.appendChild(slider); + + var valInput = document.createElement('input'); + valInput.type = 'text'; + valInput.className = 'knob-val'; + valInput.value = spec.current != null ? Number(spec.current).toExponential(2) : ''; + valInput.dataset.param = spec.param_key; + row.appendChild(valInput); + } else { + // Linear slider (float, int) + var slider = document.createElement('input'); + slider.type = 'range'; + slider.className = 'knob-slider'; + slider.min = spec.min != null ? spec.min : 0; + slider.max = spec.max != null ? spec.max : 1; + slider.step = spec.step || 0.01; + slider.value = spec.current != null ? spec.current : 0; + slider.dataset.param = spec.param_key; + row.appendChild(slider); + + var valInput = document.createElement('input'); + valInput.type = 'text'; + valInput.className = 'knob-val'; + valInput.value = spec.current != null ? 
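`buildKnobRow` switches purely on `spec.type`, so the whole panel is data-driven. An illustrative controls payload of the shape the code expects (field names follow the spec comment above; the concrete values here are invented, not captured from a real server):

```js
// Illustrative /api/config "controls" payload for buildControls().
var exampleControls = [
  {param_key: 'lr', type: 'log_float', label: 'Learning rate',
   group: 'optimizer', min: 1e-6, max: 1e-1, step: 0.01,
   current: 3e-4, state: 'untouched'},
  {param_key: 'weight_decay', type: 'log_float', label: 'Weight decay',
   group: 'optimizer', min: 1e-7, max: 1e-1, current: 1e-2},
  {param_key: 'weight_a', type: 'float', label: 'Task A weight',
   group: 'loss', min: 0, max: 1, step: 0.01, current: 0.5},
  {param_key: 'backbone_frozen', type: 'bool', label: 'Freeze backbone',
   group: 'optimizer', current: false},
];
// Log sliders store log10(value): lr = 3e-4 sits at slider position
// Math.log10(3e-4) ≈ -3.52 between min log10(1e-6) = -6 and max log10(1e-1) = -1.
```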
spec.current : '';
+ valInput.dataset.param = spec.param_key;
+ row.appendChild(valInput);
+ }
+
+ // State indicator
+ if (spec.state && spec.state !== 'untouched' && spec.state !== 'verified') {
+ var stateEl = document.createElement('span');
+ stateEl.className = 'knob-state knob-state-' + spec.state;
+ stateEl.textContent = spec.state;
+ row.appendChild(stateEl);
+ }
+
+ return row;
+}
+
+function _wireKnobInputListeners() {
+ var panel = document.getElementById('knobPanel');
+ if (!panel) return;
+
+ // Sliders: sync display value on input
+ panel.querySelectorAll('.knob-slider').forEach(function(slider) {
+ slider.addEventListener('input', function(e) {
+ var param = e.target.dataset.param;
+ var valEl = panel.querySelector('.knob-val[data-param="' + param + '"]');
+ if (!valEl) return;
+ if (e.target.dataset.logScale === 'true') {
+ valEl.value = Math.pow(10, parseFloat(e.target.value)).toExponential(2);
+ } else {
+ valEl.value = parseFloat(e.target.value).toFixed(2);
+ }
+ _markStagedChanges();
+ });
+ });
+
+ // Toggles and selects: mark staged changes on change
+ panel.querySelectorAll('.knob-toggle, .knob-select').forEach(function(el) {
+ el.addEventListener('change', function() { _markStagedChanges(); });
+ });
+}
+
+/**
+ * Read current value from a dynamic control row.
+ */
+function _readKnobValue(row) {
+ var slider = row.querySelector('.knob-slider');
+ var toggle = row.querySelector('.knob-toggle');
+ var select = row.querySelector('.knob-select');
+
+ if (toggle) return toggle.checked;
+ if (select) return select.value;
+ if (slider) {
+ if (slider.dataset.logScale === 'true') {
+ return Math.pow(10, parseFloat(slider.value));
+ }
+ return parseFloat(slider.value);
+ }
+ return undefined;
+}
+
+/**
+ * Look up the control spec for a param_key (used by the Apply handler).
+ * Groups determine which API endpoint to use: + * "optimizer" -> POST /api/opt/set + * "loss" -> POST /api/loss/set + * Other groups -> POST /api/opt/set (generic param set) + */ +function _getControlSpec(paramKey) { + var specs = (S.config && S.config.controls) || []; + for (var i = 0; i < specs.length; i++) { + if (specs[i].param_key === paramKey) return specs[i]; + } + return null; +} + +function _markStagedChanges() { + var panel = document.getElementById('knobPanel'); + if (!panel) return; + var threshold = (S.config && S.config.ui && S.config.ui.staged_change_threshold) || 0.005; + panel.querySelectorAll('.knob-row[data-param]').forEach(function(row) { + var param = row.dataset.param; + var current = _readKnobValue(row); + var applied = _appliedKnobs[param]; + + if (applied === undefined || current === undefined) { + row.classList.remove('staged'); + return; + } + if (typeof current === 'boolean') { + if (current !== applied) row.classList.add('staged'); + else row.classList.remove('staged'); + } else if (typeof current === 'string') { + if (current !== applied) row.classList.add('staged'); + else row.classList.remove('staged'); + } else { + var denom = Math.max(Math.abs(applied), 1e-12); + if (Math.abs(current - applied) / denom > threshold) { + row.classList.add('staged'); + } else { + row.classList.remove('staged'); + } + } + }); +} + +function _snapshotAppliedKnobs() { + var panel = document.getElementById('knobPanel'); + if (!panel) return; + panel.querySelectorAll('.knob-row[data-param]').forEach(function(row) { + var param = row.dataset.param; + var val = _readKnobValue(row); + if (val !== undefined) _appliedKnobs[param] = val; + }); + // Clear all staged highlights + document.querySelectorAll('.knob-row.staged').forEach(function(el) { el.classList.remove('staged'); }); +} + function debounceApply(fn, delay) { if (_applyDebounceTimer) clearTimeout(_applyDebounceTimer); _applyDebounceTimer = setTimeout(fn, delay || 300); @@ -26,6 +260,9 @@ function _clearTrainingState() { S.chatHistory = []; S.alerts = []; _healthEma = null; + _appliedKnobs = {}; + // Reset WS slider sync flag so next training session re-syncs + if (typeof _slidersInitialized !== 'undefined') _slidersInitialized = false; // Clear focus/zoom state if (S.focusMetric) { S.focusMetric = null; @@ -44,6 +281,15 @@ function _clearTrainingState() { // Clear forecast cache _forecastCache = {}; _highlightedMutationStep = null; + // Clear metric toggle state so old metric names don't persist + if (typeof _metricToggleState !== 'undefined') _metricToggleState = {}; + if (typeof _metricDropdownShowAll !== 'undefined') _metricDropdownShowAll = false; + if (typeof _lastMetricCount !== 'undefined') _lastMetricCount = 0; + // Reset run-reset detection + if (typeof _lastSeenMaxStep !== 'undefined') _lastSeenMaxStep = 0; + // Reset step range to "All" for fresh run + if (typeof _chartStepRange !== 'undefined') _chartStepRange = 'all'; + if (typeof _updateRangeButtons === 'function') _updateRangeButtons(); clearTimelineDedup(); updateMetricToggles(); updateChart(); @@ -57,69 +303,79 @@ function _clearTrainingState() { } function _updateConfigControls(configId) { - var multitaskEls = document.querySelectorAll('.multitask-only'); - var finetuneEls = document.querySelectorAll('.finetune-only'); - var singleLossEls = document.querySelectorAll('.single-loss-only'); - - var isMultitask = configId === 'multitask'; - var isFinetune = configId === 'finetune'; - - multitaskEls.forEach(function(el) { el.style.display = isMultitask ? 
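The staged check uses a relative difference so it behaves sensibly across magnitudes; an absolute epsilon would flag every tick of a log slider near 1e-6. A worked check under the default 0.005 threshold (the function name is illustrative):

```js
// Relative-difference test used for numeric knobs (default threshold 0.005).
function isStaged(current, applied, threshold) {
  var denom = Math.max(Math.abs(applied), 1e-12);
  return Math.abs(current - applied) / denom > (threshold || 0.005);
}
// lr nudged from 1e-3 to 1.004e-3 -> 0.4% change -> not staged:
console.assert(isStaged(1.004e-3, 1e-3) === false);
// lr changed to 1.2e-3 -> 20% change -> staged:
console.assert(isStaged(1.2e-3, 1e-3) === true);
```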
'' : 'none'; }); - finetuneEls.forEach(function(el) { el.style.display = isFinetune ? '' : 'none'; }); - singleLossEls.forEach(function(el) { el.style.display = isMultitask ? 'none' : ''; }); + // No-op: controls are now generated dynamically from actuator metadata. + // Kept as a stub because initControls and launcher still reference it. } function initControls() { - // Knob sliders - $('#knobLr').addEventListener('input', function(e) { - var v = lrFromSlider(e.target.value); - $('#knobLrVal').value = v.toExponential(2); - }); - $('#knobWd').addEventListener('input', function(e) { - var v = Math.pow(10, parseFloat(e.target.value)); - $('#knobWdVal').value = v.toExponential(2); - }); - $('#knobLossW').addEventListener('input', function(e) { - $('#knobLossWVal').value = parseFloat(e.target.value).toFixed(2); - }); - $('#knobWeightA').addEventListener('input', function(e) { - $('#knobWeightAVal').value = parseFloat(e.target.value).toFixed(2); - }); - $('#knobWeightB').addEventListener('input', function(e) { - $('#knobWeightBVal').value = parseFloat(e.target.value).toFixed(2); - }); + // Dynamic controls — build from config if available + if (S.config && S.config.controls && S.config.controls.length) { + buildControls(S.config.controls); + } - // Apply + // Apply — dynamic: collect all changed params from knobPanel $('#btnApply').addEventListener('click', function() { var btn = $('#btnApply'); btn.textContent = 'Queued...'; btn.disabled = true; debounceApply(async function() { - var lr = lrFromSlider($('#knobLr').value); - var wd = Math.pow(10, parseFloat($('#knobWd').value)); - var optParams = {lr: lr, weight_decay: wd}; - // Include opt_idx for multi-optimizer setups - var optIdxSel = document.getElementById('knobOptIdx'); - if (optIdxSel && optIdxSel.parentElement.style.display !== 'none') { - var idx = parseInt(optIdxSel.value); - if (idx > 0) optParams.opt_idx = idx; + var anyChanged = false; + var threshold = (S.config && S.config.ui && S.config.ui.staged_change_threshold) || 0.005; + + function _isDiff(current, baseline) { + if (baseline === undefined) return true; + if (typeof current === 'boolean') return current !== baseline; + if (typeof current === 'string') return current !== baseline; + var denom = Math.max(Math.abs(baseline), 1e-12); + return Math.abs(current - baseline) / denom > threshold; } - await api('POST', '/api/opt/set', {params: optParams}); - var configId = $('#trainConfig').value; - if (configId === 'multitask') { - var wa = parseFloat($('#knobWeightA').value); - var wb = parseFloat($('#knobWeightB').value); - await api('POST', '/api/loss/set', {params: {weight_a: wa, weight_b: wb}}); - } else { - var lw = parseFloat($('#knobLossW').value); - if (lw !== 1.0) await api('POST', '/api/loss/set', {params: {weight: lw}}); + // Collect changed params grouped by module + var changedByGroup = {}; // group -> {param_key: value} + var panel = document.getElementById('knobPanel'); + if (panel) { + panel.querySelectorAll('.knob-row[data-param]').forEach(function(row) { + var param = row.dataset.param; + var current = _readKnobValue(row); + if (current === undefined) return; + if (!_isDiff(current, _appliedKnobs[param])) return; + + var spec = _getControlSpec(param); + var group = (spec && spec.group) || 'optimizer'; + if (!changedByGroup[group]) changedByGroup[group] = {}; + changedByGroup[group][param] = current; + }); } - if (configId === 'finetune') { - var frozen = $('#knobBackbone').value === '1'; - await api('POST', '/api/cb/set_params', {params: {backbone_frozen: 
frozen}}); + // Send commands per group + var groups = Object.keys(changedByGroup); + for (var gi = 0; gi < groups.length; gi++) { + var group = groups[gi]; + var params = changedByGroup[group]; + if (group === 'optimizer') { + await api('POST', '/api/opt/set', {params: params}); + } else if (group === 'loss') { + await api('POST', '/api/loss/set', {params: params}); + } else { + // Generic: send each param individually as opt set + var pkeys = Object.keys(params); + for (var pi = 0; pi < pkeys.length; pi++) { + var p = {}; + p[pkeys[pi]] = params[pkeys[pi]]; + await api('POST', '/api/opt/set', {params: p}); + } + } + anyChanged = true; } + + if (!anyChanged) { + btn.textContent = 'No changes'; + btn.disabled = false; + setTimeout(function() { btn.textContent = 'Apply'; }, 1500); + return; + } + + _snapshotAppliedKnobs(); updateChart(); // Snapshot forecast at mutation point var allSteps = []; @@ -139,9 +395,25 @@ function initControls() { $('#btnScheduleSubmit').addEventListener('click', async function() { var step = parseInt($('#schedStep').value); if (!step || step <= 0) return; - var lr = lrFromSlider($('#knobLr').value); - var wd = Math.pow(10, parseFloat($('#knobWd').value)); - await api('POST', '/api/schedule', {at_step: step, module: 'opt', op: 'set_params', params: {lr: lr, weight_decay: wd}}); + // Collect current values from all dynamic controls + var panel = document.getElementById('knobPanel'); + var params = {}; + if (panel) { + panel.querySelectorAll('.knob-row[data-param]').forEach(function(row) { + var param = row.dataset.param; + var val = _readKnobValue(row); + if (val !== undefined) params[param] = val; + }); + } + if (Object.keys(params).length === 0) { + alert('No control values to schedule. Adjust controls first.'); + return; + } + // Determine module from first param's group + var firstKey = Object.keys(params)[0]; + var spec = _getControlSpec(firstKey); + var module = (spec && spec.group === 'loss') ? 'loss' : 'opt'; + await api('POST', '/api/schedule', {at_step: step, module: module, op: 'set_params', params: params}); closeModal('modalSchedule'); }); @@ -341,6 +613,33 @@ function initControls() { $('#themeSelect').addEventListener('change', function(e) { setTheme(e.target.value); }); + + // Hydrate controls from server state (works for both launcher and external training) + hydrateControlsFromServer().then(function() { _snapshotAppliedKnobs(); }); + + // Poll for controls periodically — actuator file may appear after training starts. + // Keep polling until we get MORE controls than the defaults (lr + weight_decay = 2). + var _controlsPollCount = 0; + var _controlsPollMax = 30; // stop after ~90 seconds + setInterval(function() { + _controlsPollCount++; + if (_controlsPollCount > _controlsPollMax) return; + var currentCount = (S.config && S.config.controls) ? 
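Apply fans staged changes out per group, so one click can hit several endpoints. An illustrative fan-out for a click that staged a new lr plus both loss weights (values invented):

```js
// Hypothetical requests emitted by one Apply click (inside the async handler):
await api('POST', '/api/opt/set',  {params: {lr: 5e-4}});
await api('POST', '/api/loss/set', {params: {weight_a: 0.7, weight_b: 0.3}});
// Params in any group other than optimizer/loss fall back to /api/opt/set,
// sent one param per request.
```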
S.config.controls.length : 0; + api('GET', '/api/state/controls').then(function(state) { + if (!state || !state.controls || !state.controls.length) return; + // Only rebuild if we got MORE controls or controls changed + if (state.controls.length <= currentCount && currentCount > 2) return; + if (state.controls.length > currentCount || currentCount <= 2) { + buildControls(state.controls); + if (S.config) S.config.controls = state.controls; + if (state.last_opt_params) syncSlidersFromApplied(state.last_opt_params); + if (state.last_loss_params) syncSlidersFromApplied(state.last_loss_params); + _snapshotAppliedKnobs(); + // Stop polling once we have real controls (more than defaults) + if (state.controls.length > 2) _controlsPollCount = _controlsPollMax + 1; + } + }); + }, 3000); } /* ================================================================ */ @@ -352,78 +651,123 @@ async function loadCapabilities() { var caps = await api('GET', '/api/capabilities'); if (!caps || !caps.detected) return; S.capabilities = caps; + // Capabilities are now informational only — controls are generated from config. + } catch(e) { /* ignore */ } +} - // Multi-optimizer: show optimizer selector - var numOpts = caps.num_optimizers || 1; - if (numOpts > 1) { - var multiOptEls = document.querySelectorAll('.multi-opt-only'); - multiOptEls.forEach(function(el) { el.style.display = ''; }); - var sel = document.getElementById('knobOptIdx'); - if (sel) { - sel.innerHTML = ''; - var names = caps.optimizer_names || []; - for (var i = 0; i < numOpts; i++) { - var opt = document.createElement('option'); - opt.value = i; - var label = names[i] ? names[i] : ('optimizer ' + i); - var pg = (caps.num_param_groups || [])[i]; - if (pg) label += ' (' + pg + ' groups)'; - opt.textContent = label; - sel.appendChild(opt); - } +function syncSlidersFromApplied(params) { + if (!params || typeof params !== 'object') return; + var panel = document.getElementById('knobPanel'); + if (!panel) return; + + Object.keys(params).forEach(function(k) { + var value = params[k]; + if (typeof value !== 'number' && typeof value !== 'boolean' && typeof value !== 'string') return; + + // Update baseline + _appliedKnobs[k] = value; + + // Also check aliases: weight_decay -> weight_decay param key + var paramKey = k; + + // Find the matching dynamic control row + var row = panel.querySelector('.knob-row[data-param="' + paramKey + '"]'); + if (!row) return; + + var slider = row.querySelector('.knob-slider'); + var valEl = row.querySelector('.knob-val'); + var toggle = row.querySelector('.knob-toggle'); + var select = row.querySelector('.knob-select'); + + if (toggle && typeof value === 'boolean') { + toggle.checked = value; + } else if (select && typeof value === 'string') { + select.value = value; + } else if (slider && typeof value === 'number') { + if (slider.dataset.logScale === 'true' && value > 0) { + slider.value = Math.log10(value); + if (valEl) valEl.value = value.toExponential(2); + } else { + slider.value = value; + if (valEl) valEl.value = parseFloat(value).toFixed(2); } } + }); + + _markStagedChanges(); +} - // Loss state: show/hide loss controls based on detected keys - if (caps.loss_state_detected && caps.loss_state_keys && caps.loss_state_keys.length > 0) { - // Show loss controls — they might be hidden if no demo config selected - var lossRows = document.querySelectorAll('[data-param="loss_w"]'); - lossRows.forEach(function(el) { el.style.display = ''; }); +async function hydrateControlsFromServer() { + try { + var state = await api('GET', 
'/api/state/controls'); + if (!state) return; + + // Build dynamic controls from server-provided controls list + if (state.controls && state.controls.length) { + buildControls(state.controls); + // Update S.config.controls so Apply handler can look up specs + if (S.config) S.config.controls = state.controls; } - // Grad clip info - if (caps.grad_clip_value !== null && caps.grad_clip_value !== undefined) { - var clipInfo = caps.grad_clip_wired ? 'wired' : 'advisory'; + // Demo mode gate: hide entire Training card when not in demo mode + if (state.demo_mode === false) { + var trainPanel = document.getElementById('trainPanel'); + if (trainPanel) { + var trainCard = trainPanel.closest('.card'); + if (trainCard) trainCard.style.display = 'none'; + } } - } catch(e) { /* ignore */ } -} -function syncSlidersFromApplied(params) { - if (!params || typeof params !== 'object') return; + // Sync sliders from last applied params + if (state.last_opt_params) syncSlidersFromApplied(state.last_opt_params); + if (state.last_loss_params) syncSlidersFromApplied(state.last_loss_params); - // lr - if ('lr' in params && typeof params.lr === 'number' && params.lr > 0) { - var lrSlider = $('#knobLr'); - var lrDisplay = $('#knobLrVal'); - if (lrSlider && lrDisplay) { - lrSlider.value = Math.log10(params.lr); - lrDisplay.value = params.lr.toExponential(2); + // Module activity detection — controls are now dynamic, no CSS class hiding needed + + // External training: hide demo config dropdown, show attached label + if (state.is_external) { + var trainConfig = document.getElementById('trainConfig'); + var trainConfigDesc = document.getElementById('trainConfigDesc'); + if (trainConfig) trainConfig.style.display = 'none'; + if (trainConfigDesc) trainConfigDesc.textContent = 'External Training (attached)'; } - } - // weight_decay / wd - var wd = ('weight_decay' in params) ? params.weight_decay : ('wd' in params ? params.wd : null); - if (wd !== null && typeof wd === 'number' && wd > 0) { - var wdSlider = $('#knobWd'); - var wdDisplay = $('#knobWdVal'); - if (wdSlider && wdDisplay) { - wdSlider.value = Math.log10(wd); - wdDisplay.value = wd.toExponential(2); + // Sync config from run.json (for non-launcher runs) + var runCfg = state.run_config || {}; + if (runCfg.config_id) { + var sel = document.getElementById('trainConfig'); + // Only sync if it's a known config + if (sel) { + var found = false; + for (var i = 0; i < sel.options.length; i++) { + if (sel.options[i].value === runCfg.config_id) { found = true; break; } + } + if (found) { + sel.value = runCfg.config_id; + _updateConfigControls(runCfg.config_id); + } + } + if (runCfg.max_steps) { + var msEl = document.getElementById('trainMaxSteps'); + if (msEl) msEl.value = runCfg.max_steps; + } } - } - // weight / loss_w / weight_a (linear 0-1) - var lw = ('weight' in params) ? params.weight : - ('loss_w' in params) ? params.loss_w : - ('weight_a' in params) ? 
params.weight_a : null; - if (lw !== null && typeof lw === 'number') { - var lwSlider = $('#knobLossW'); - var lwDisplay = $('#knobLossWVal'); - if (lwSlider && lwDisplay) { - lwSlider.value = lw; - lwDisplay.value = parseFloat(lw).toFixed(2); + // Sync step counter + if (state.latest_step) { + var stepEl = document.getElementById('stepValue'); + if (stepEl) stepEl.textContent = state.latest_step; } - } + + // Sync autopilot mode + if (state.autopilot_mode && state.autopilot_mode !== 'off') { + var modeSelect = document.getElementById('autopilotMode'); + if (modeSelect) { + modeSelect.value = state.autopilot_mode; + _updateAIConfigVisibility(state.autopilot_mode); + } + } + } catch (e) { /* ignore hydration errors */ } } function setTheme(theme) { @@ -438,6 +782,17 @@ function setTheme(theme) { if (S.featureCtx && S.featureCtx.scene) { S.featureCtx.scene.background = new THREE.Color(bgColor); } + // Update chart tooltip colors for new theme + if (S.chartInstance) { + var cs = getComputedStyle(document.documentElement); + S.chartInstance.options.plugins.tooltip.backgroundColor = cs.getPropertyValue('--bg-card').trim(); + S.chartInstance.options.plugins.tooltip.borderColor = cs.getPropertyValue('--border-bright').trim(); + S.chartInstance.options.scales.x.grid.color = cs.getPropertyValue('--border').trim(); + S.chartInstance.options.scales.y.grid.color = cs.getPropertyValue('--border').trim(); + S.chartInstance.options.scales.x.ticks.color = cs.getPropertyValue('--text-muted').trim(); + S.chartInstance.options.scales.y.ticks.color = cs.getPropertyValue('--text-muted').trim(); + S.chartInstance.update('none'); + } } /** @@ -574,6 +929,10 @@ function setHealth(score, desc) { $('#healthDesc').textContent = desc; } +var _lastSyncedConfigId = null; +var _wasTrainingRunning = false; +var _runConfigName = ''; + async function pollTrainStatus() { try { var res = await api('GET', '/api/train/status'); @@ -582,25 +941,51 @@ async function pollTrainStatus() { var btnStart = $('#btnTrainStart'); var btnStop = $('#btnTrainStop'); if (res.running) { - var info = 'Running since ' + res.started_at; - if (res.config && res.config.max_steps) { - info += ' (' + res.config.max_steps + ' steps)'; - } - if (res.config && res.config.seed !== undefined) { - info += ' seed=' + res.config.seed; - // Backfill seed input so user can see/copy it - var seedInput = document.getElementById('trainSeed'); - if (seedInput && !seedInput.value) seedInput.value = res.config.seed; - } + _wasTrainingRunning = true; + var cfg = res.config || {}; + _runConfigName = cfg.config_name || cfg.config_id || ''; + var info = 'Running: ' + (cfg.config_name || cfg.config_id || '?'); + if (cfg.max_steps) info += ' (' + cfg.max_steps + ' steps)'; + if (cfg.seed !== undefined) info += ' seed=' + cfg.seed; el.textContent = info; el.style.color = 'var(--green, #4ade80)'; btnStart.disabled = true; btnStop.disabled = false; + + // Sync controls to match the running config (once per config change) + if (cfg.config_id && cfg.config_id !== _lastSyncedConfigId) { + _lastSyncedConfigId = cfg.config_id; + var sel = document.getElementById('trainConfig'); + if (sel && sel.value !== cfg.config_id) { + sel.value = cfg.config_id; + _updateConfigControls(cfg.config_id); + var desc = document.getElementById('trainConfigDesc'); + if (desc && cfg.config_name) desc.textContent = cfg.config_name; + } + if (cfg.max_steps) { + var msEl = document.getElementById('trainMaxSteps'); + if (msEl) msEl.value = cfg.max_steps; + } + if (cfg.step_delay !== undefined) { + var 
sdEl = document.getElementById('trainStepDelay'); + if (sdEl) sdEl.value = cfg.step_delay; + } + if (cfg.seed !== undefined) { + var seedInput = document.getElementById('trainSeed'); + if (seedInput) seedInput.value = cfg.seed; + } + } } else { + // Detect running → stopped transition + if (_wasTrainingRunning) { + _wasTrainingRunning = false; + showRunSummary(_runConfigName); + } el.textContent = 'Stopped'; el.style.color = 'var(--text-muted)'; btnStart.disabled = false; btnStop.disabled = true; + _lastSyncedConfigId = null; } } catch (e) { /* ignore poll errors */ } } @@ -786,9 +1171,11 @@ async function pollAutopilotStatus() { var ruleId = action.rule_id || '?'; var condition = action.condition_met || ''; if (condition.length > 80) condition = condition.substring(0, 77) + '...'; + var fullCondition = action.condition_met || ''; div.innerHTML = badge + ' step ' + step + ' ' + '' + ruleId + '
' + - '' + condition + ''; + '' + condition + ''; // Add "Apply" button for proposed (suggest-mode) actions if (action.status === 'proposed' && action.action_id) { diff --git a/src/hotcb/server/static/js/init.js b/src/hotcb/server/static/js/init.js index fee29dd..659a349 100644 --- a/src/hotcb/server/static/js/init.js +++ b/src/hotcb/server/static/js/init.js @@ -2,7 +2,25 @@ * hotcb dashboard — initialization and data loading */ +function dismissChartWaiting() { + var el = document.getElementById('chartWaiting'); + if (el && !el.classList.contains('hidden')) { + el.classList.add('hidden'); + setTimeout(function() { el.style.display = 'none'; }, 500); + } +} + async function initialLoad() { + // Fetch centralized config before other init + var cfg = await api('GET', '/api/config'); + if (cfg) { + S.config = cfg; + // Build dynamic controls from config + if (cfg.controls && cfg.controls.length) { + buildControls(cfg.controls); + } + } + // Status var status = await api('GET', '/api/status'); if (status) { @@ -10,9 +28,10 @@ async function initialLoad() { if (status.run_dir) S.runs = [status.run_dir]; } - // Metric history - var hist = await api('GET', '/api/metrics/history?last_n=2000'); + // Metric history — load full run for external projects (LTTB handles rendering) + var hist = await api('GET', '/api/metrics/history?last_n=50000'); if (hist && hist.records && hist.records.length > 0) { + dismissChartWaiting(); hist.records.forEach(function(rec) { var step = rec.step || 0; var metrics = rec.metrics || {}; @@ -63,34 +82,26 @@ async function initialLoad() { // Restore controls from server state (overrides stale localStorage) var ctrlState = await api('GET', '/api/state/controls'); if (ctrlState) { - // Sync sliders from latest metrics - var m = ctrlState.latest_metrics || {}; - if (m.lr && m.lr > 0) { - var lrSlider = document.getElementById('knobLr'); - var lrDisplay = document.getElementById('knobLrVal'); - if (lrSlider && lrDisplay) { - lrSlider.value = Math.log10(m.lr); - lrDisplay.value = m.lr.toExponential(2); - } + // Build/rebuild controls from live MutableState data + if (ctrlState.controls && ctrlState.controls.length) { + buildControls(ctrlState.controls); + if (S.config) S.config.controls = ctrlState.controls; } - if (m.weight_decay && m.weight_decay > 0) { - var wdSlider = document.getElementById('knobWd'); - var wdDisplay = document.getElementById('knobWdVal'); - if (wdSlider && wdDisplay) { - wdSlider.value = Math.log10(m.weight_decay); - wdDisplay.value = m.weight_decay.toExponential(2); - } + + // Sync sliders from latest metrics using dynamic sync + var m = ctrlState.latest_metrics || {}; + if (Object.keys(m).length > 0) { + syncSlidersFromApplied(m); } // Sync from last applied opt params as fallback var op = ctrlState.last_opt_params || {}; - if (!m.lr && op.lr && op.lr > 0) { - var lrSlider = document.getElementById('knobLr'); - var lrDisplay = document.getElementById('knobLrVal'); - if (lrSlider && lrDisplay) { - lrSlider.value = Math.log10(op.lr); - lrDisplay.value = op.lr.toExponential(2); - } + if (Object.keys(op).length > 0) { + syncSlidersFromApplied(op); + } + if (ctrlState.last_loss_params && Object.keys(ctrlState.last_loss_params).length > 0) { + syncSlidersFromApplied(ctrlState.last_loss_params); } + // Sync training config var rc = ctrlState.run_config || {}; if (rc.config_id) { @@ -119,12 +130,13 @@ async function initialLoad() { } } - // Load training capabilities and adapt controls + // Load training capabilities (informational) loadCapabilities(); // 
Periodic updates startForecastPolling(); - setInterval(fetchAlerts, 15000); + var _alertPollMs = (S.config && S.config.ui) ? S.config.ui.alert_poll_interval : 15000; + S._alertInterval = setInterval(fetchAlerts, _alertPollMs); // Show tour for first-time users (with delay to let UI settle) if (shouldShowTour()) { @@ -144,6 +156,18 @@ async function initialLoad() { initTabs(); initControls(); + // Health card collapse toggle + var healthToggle = document.getElementById('healthToggle'); + var healthDetails = document.getElementById('healthDetails'); + if (healthToggle && healthDetails) { + // Start collapsed + healthDetails.classList.add('collapsed'); + healthToggle.classList.add('collapsed'); + healthToggle.addEventListener('click', function() { + healthDetails.classList.toggle('collapsed'); + healthToggle.classList.toggle('collapsed'); + }); + } $('#btnTour').addEventListener('click', startTour); initRecipeEditor(); initAutopilotRulesEditor(); @@ -155,6 +179,17 @@ async function initialLoad() { initConfigWizard(); initCompare(); createMetricsChart(); + initStepRangeControls(); + // Normalize toggle — manual click disables auto-detection + var normBtn = document.getElementById('btnNormalize'); + if (normBtn) { + normBtn.addEventListener('click', function() { + _chartNormalizeAuto = false; // user took control + _chartNormalize = !_chartNormalize; + normBtn.classList.toggle('btn-accent', _chartNormalize); + updateChart(); + }); + } initialLoad(); connectWS(); @@ -176,19 +211,11 @@ async function initialLoad() { if (savedState.pinnedMetrics && savedState.pinnedMetrics.length) { S._pendingPinnedMetrics = savedState.pinnedMetrics; } - if (savedState.knobs) { - if (savedState.knobs.lr) { - var lr = document.getElementById('knobLr'); - if (lr) { lr.value = savedState.knobs.lr; lr.dispatchEvent(new Event('input')); } - } - if (savedState.knobs.wd) { - var wd = document.getElementById('knobWd'); - if (wd) { wd.value = savedState.knobs.wd; wd.dispatchEvent(new Event('input')); } - } - } + // Knob state is now handled by dynamic controls from server } // Persist UI state periodically and before page unload - setInterval(saveUIState, 5000); + var _stateSaveMs = (S.config && S.config.ui) ? 
S.config.ui.state_save_interval : 5000; + S._saveStateInterval = setInterval(saveUIState, _stateSaveMs); window.addEventListener('beforeunload', saveUIState); })(); diff --git a/src/hotcb/server/static/js/manifold3d.js b/src/hotcb/server/static/js/manifold3d.js index 65b5716..0f66441 100644 --- a/src/hotcb/server/static/js/manifold3d.js +++ b/src/hotcb/server/static/js/manifold3d.js @@ -71,9 +71,16 @@ function render3DPoints(ctx, points, interventionSteps) { if (!ctx) return; var scene = ctx.scene; - // Remove old data points + // Remove old data points and dispose GPU resources var old = scene.children.filter(function(c) { return c.userData && c.userData.isDataPoint; }); - old.forEach(function(o) { scene.remove(o); }); + old.forEach(function(o) { + scene.remove(o); + if (o.geometry) o.geometry.dispose(); + if (o.material) { + if (o.material.map) o.material.map.dispose(); + o.material.dispose(); + } + }); if (!points || points.length === 0) return; diff --git a/src/hotcb/server/static/js/panels.js b/src/hotcb/server/static/js/panels.js index 6698b26..8b12f64 100644 --- a/src/hotcb/server/static/js/panels.js +++ b/src/hotcb/server/static/js/panels.js @@ -9,6 +9,7 @@ var _manifoldRefreshInterval = null; function _startManifoldAutoRefresh() { if (_manifoldRefreshInterval) return; + var _manifoldMs = (S.config && S.config.ui) ? S.config.ui.manifold_refresh_interval : 10000; _manifoldRefreshInterval = setInterval(function() { var activeSubtab = document.querySelector('[data-subtab].active'); if (activeSubtab && activeSubtab.dataset.subtab === 'feature-space') { @@ -16,7 +17,7 @@ function _startManifoldAutoRefresh() { } else { fetchManifold(); } - }, 10000); + }, _manifoldMs); } function _stopManifoldAutoRefresh() { @@ -36,6 +37,9 @@ function initTabs() { area.querySelectorAll('.tab-content[data-tab]').forEach(function(x) { x.classList.remove('active'); }); var target = area.querySelector('.tab-content[data-tab="' + t.dataset.tab + '"]'); if (target) target.classList.add('active'); + // Compare mode: hide right-col panes, show compact status bar + document.body.classList.toggle('compare-active-mode', t.dataset.tab === 'compare'); + if (t.dataset.tab === 'compare') fetchCompareRuns(); if (t.dataset.tab === 'manifold') { fetchManifold(); _startManifoldAutoRefresh(); @@ -92,16 +96,32 @@ function addTimelineItem(rec) { var step = rec.step || '?'; var mod = rec.module || '?'; var desc = rec.op || ''; - var params = rec.params ? JSON.stringify(rec.params) : ''; var decision = rec.decision || rec.status || 'applied'; var source = rec.source || 'interactive'; var sourceColor = source === 'recipe' ? 'var(--yellow, #facc15)' : source === 'autopilot' ? 'var(--cyan, #22d3ee)' : 'var(--text-muted)'; + // Build param capsules instead of raw JSON + var paramCapsules = ''; + var paramSrc = (rec.params && typeof rec.params === 'object') ? rec.params : + (rec.payload && typeof rec.payload === 'object') ? rec.payload : null; + if (paramSrc) { + var keys = Object.keys(paramSrc); + keys.forEach(function(k) { + var v = paramSrc[k]; + if (typeof v === 'number') { + v = v < 0.01 || v > 1e4 ? 
v.toExponential(1) : parseFloat(v.toPrecision(3)); + } else if (typeof v === 'object' && v !== null) { + v = JSON.stringify(v); + } + paramCapsules += '' + k + '' + v + ''; + }); + } div.innerHTML = 'step ' + step + '' + '' + mod + '' + - '' + desc + ' ' + params + '' + + '' + desc + '' + + '' + paramCapsules + '' + '' + decision + '' + '' + source + ''; @@ -118,6 +138,8 @@ function addTimelineItem(rec) { div.classList.remove('tl-active'); var existing = document.getElementById('impactSummary'); if (existing) existing.remove(); + // Restore user's chosen range instead of staying locked + _applyChartStepRange(); if (S.chartInstance) S.chartInstance.update('none'); return; } @@ -157,7 +179,8 @@ async function fetchRecipe() { function _startRecipeAutoRefresh() { if (_recipeAutoRefresh) return; - _recipeAutoRefresh = setInterval(fetchRecipe, 5000); + var _recipeMs = (S.config && S.config.ui) ? S.config.ui.recipe_refresh_interval : 5000; + _recipeAutoRefresh = setInterval(fetchRecipe, _recipeMs); } function _stopRecipeAutoRefresh() { if (_recipeAutoRefresh) { clearInterval(_recipeAutoRefresh); _recipeAutoRefresh = null; } @@ -166,10 +189,42 @@ function _stopRecipeAutoRefresh() { function renderRecipe() { var list = $('#recipeList'); list.innerHTML = ''; - if (S.recipeEntries.length === 0) { + + // Show applied mutations from current run + if (S.appliedData && S.appliedData.length > 0) { + var appliedHeader = document.createElement('div'); + appliedHeader.style.cssText = 'font-size:9px;font-weight:700;text-transform:uppercase;color:var(--text-muted);padding:6px 4px 4px;letter-spacing:0.5px;border-bottom:1px solid var(--border);margin-bottom:4px;'; + appliedHeader.textContent = 'Applied This Run (' + S.appliedData.length + ')'; + list.appendChild(appliedHeader); + S.appliedData.forEach(function(rec) { + var div = document.createElement('div'); + div.style.cssText = 'display:grid;grid-template-columns:60px 44px 1fr;gap:4px;align-items:center;padding:3px 4px;font-size:9px;font-family:var(--font-mono);color:var(--text-muted);opacity:0.7;'; + var capsules = ''; + if (rec.params && typeof rec.params === 'object') { + Object.keys(rec.params).forEach(function(k) { + var v = rec.params[k]; + if (typeof v === 'number') v = v < 0.01 || v > 1e4 ? v.toExponential(1) : parseFloat(v.toPrecision(3)); + capsules += '' + k + '' + v + ''; + }); + } + div.innerHTML = 'step ' + (rec.step || '?') + '' + + '' + (rec.module || '?') + '' + + '' + (rec.op || '') + ' ' + capsules + ''; + list.appendChild(div); + }); + } + + if (S.recipeEntries.length === 0 && (!S.appliedData || S.appliedData.length === 0)) { list.innerHTML = '
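The capsule formatter compresses extreme values to scientific notation and trims the rest to three significant digits. Extracted for illustration:

```js
// Capsule number formatting, same expression as in the patch:
function fmtCapsule(v) {
  return v < 0.01 || v > 1e4 ? v.toExponential(1) : parseFloat(v.toPrecision(3));
}
// fmtCapsule(0.0003) -> "3.0e-4"  (small -> exponential)
// fmtCapsule(0.25)   -> 0.25     (trimmed to 3 significant digits)
// fmtCapsule(123456) -> "1.2e+5"  (large -> exponential)
```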
No recipe entries. Click + Add to create one, or use Schedule from Controls.
'; return; } + if (S.recipeEntries.length === 0) return; + + // Recipe section header + var recipeHeader = document.createElement('div'); + recipeHeader.style.cssText = 'font-size:9px;font-weight:700;text-transform:uppercase;color:var(--accent);padding:8px 4px 4px;letter-spacing:0.5px;border-bottom:1px solid var(--border);margin-bottom:4px;'; + recipeHeader.textContent = 'Scheduled Recipe (' + S.recipeEntries.length + ')'; + list.appendChild(recipeHeader); S.recipeEntries.forEach(function(entry, idx) { var step = entry.at_step !== undefined ? entry.at_step : (entry.at ? entry.at.step : (entry.step || '?')); var mod = entry.module || '?'; @@ -635,7 +690,8 @@ async function fetchAutopilotRules() { function _startRulesAutoRefresh() { if (_rulesAutoRefresh) return; - _rulesAutoRefresh = setInterval(fetchAutopilotRules, 5000); + var _rulesMs = (S.config && S.config.ui) ? S.config.ui.recipe_refresh_interval : 5000; + _rulesAutoRefresh = setInterval(fetchAutopilotRules, _rulesMs); } function _stopRulesAutoRefresh() { if (_rulesAutoRefresh) { clearInterval(_rulesAutoRefresh); _rulesAutoRefresh = null; } @@ -850,7 +906,12 @@ function initAutopilotRulesEditor() { /* Compare Runs */ /* ================================================================ */ var _compareChart = null; -var _compareRunColorPalette = ['#00d4aa', '#3d9eff', '#ff9833', '#ff4d5e', '#9966ff', '#33dd77', '#ff66aa', '#66ddff', '#aadd33', '#dd66ff']; +var _compareRunColorPalette = [ + '#00d4aa', '#3d9eff', '#ff9833', '#ff4d5e', '#9966ff', + '#33dd77', '#ff66aa', '#66ddff', '#aadd33', '#dd66ff', + '#ff8800', '#00aaff', '#cc44cc', '#44cc88', '#ffcc00', + '#ee5577', '#77ccee', '#bbaa33', '#aa55ee', '#55bbaa', +]; var _compareRunColorMap = {}; // runId -> color (stable mapping) var _selectedCompareRuns = new Set(); var _compareAllData = {}; // runId -> records[] @@ -858,6 +919,9 @@ var _compareMetricNames = new Set(); var _compareEnabledMetrics = {}; // name -> bool var _compareZoomed = false; var _compareRunMeta = {}; // runId -> run metadata +var _compareNormalize = false; +// External directories loaded for comparison +var _compareExternalRuns = []; // [{run_id, dir, label, ...}] function _getCompareRunColor(runId) { if (!_compareRunColorMap[runId]) { @@ -868,8 +932,10 @@ function _getCompareRunColor(runId) { return _compareRunColorPalette[i]; } } - // All colors used, cycle - _compareRunColorMap[runId] = _compareRunColorPalette[Object.keys(_compareRunColorMap).length % _compareRunColorPalette.length]; + // Generate a unique color via HSL when palette exhausted + var idx = Object.keys(_compareRunColorMap).length; + var hue = (idx * 137.508) % 360; // golden angle for distinct hues + _compareRunColorMap[runId] = 'hsl(' + Math.round(hue) + ',70%,55%)'; } return _compareRunColorMap[runId]; } @@ -886,26 +952,89 @@ function initCompare() { if (_compareChart) _compareChart.resize(); _updateCompareOverlayInfo(); }); + + // Normalize toggle for compare chart + var normBtn = document.getElementById('btnNormalizeCompare'); + if (normBtn) { + normBtn.addEventListener('click', function() { + _compareNormalize = !_compareNormalize; + normBtn.classList.toggle('btn-accent', _compareNormalize); + _rebuildCompareChart(); + }); + } + + // External dir loader + var loadDirBtn = document.getElementById('btnCompareLoadDir'); + var loadDirForm = document.getElementById('compareLoadDirForm'); + if (loadDirBtn && loadDirForm) { + loadDirBtn.addEventListener('click', function() { + loadDirForm.style.display = loadDirForm.style.display === 'none' ? 
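Stepping the hue by the golden angle (≈137.508°) is a standard trick for generating arbitrarily many visually distinct colors: consecutive hues never land near each other and the sequence never exactly repeats. A self-contained sketch:

```js
// Generate n well-separated colors by walking the hue wheel in golden-angle steps.
function goldenHues(n) {
  var colors = [];
  for (var i = 0; i < n; i++) {
    colors.push('hsl(' + Math.round((i * 137.508) % 360) + ',70%,55%)');
  }
  return colors;
}
// goldenHues(3) -> ["hsl(0,70%,55%)", "hsl(138,70%,55%)", "hsl(275,70%,55%)"]
```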
'block' : 'none'; + }); + var cancelBtn = document.getElementById('btnCompareLoadDirCancel'); + if (cancelBtn) cancelBtn.addEventListener('click', function() { + loadDirForm.style.display = 'none'; + }); + var submitBtn = document.getElementById('btnCompareLoadDirSubmit'); + if (submitBtn) submitBtn.addEventListener('click', _loadExternalDir); + } } -async function fetchCompareRuns() { - var data = await api('GET', '/api/train/runs/history'); - if (!data || !data.runs) return; +async function _loadExternalDir() { + var dirInput = document.getElementById('compareExternalDir'); + if (!dirInput || !dirInput.value.trim()) return; + var dirPath = dirInput.value.trim(); - // Store run metadata + var data = await api('POST', '/api/runs/load-external', {dir: dirPath}); + if (!data || !data.runs || data.runs.length === 0) { + alert('No runs found in: ' + dirPath); + return; + } + + // Store as external runs and add to run list data.runs.forEach(function(run) { - if (run.run_id) _compareRunMeta[run.run_id] = run; + run._external = true; + run._source_dir = dirPath; + _compareExternalRuns.push(run); + _compareRunMeta[run.run_id] = run; + }); + + // Hide form, refresh the run list display + document.getElementById('compareLoadDirForm').style.display = 'none'; + dirInput.value = ''; + _refreshCompareRunList(); +} + +function _refreshCompareRunList() { + var list = $('#compareRunList'); + list.innerHTML = ''; + + // Combine discovered + external runs + var allRuns = []; + // Add external runs + _compareExternalRuns.forEach(function(run) { allRuns.push(run); }); + + // Also trigger normal discover to add local runs + api('GET', '/api/runs/discover').then(function(data) { + if (data && data.runs) { + data.runs.forEach(function(run) { + _compareRunMeta[run.run_id] = run; + allRuns.push(run); + }); + } + _renderCompareRunList(allRuns); }); +} +function _renderCompareRunList(runs) { var list = $('#compareRunList'); list.innerHTML = ''; - if (data.runs.length === 0) { - list.innerHTML = '
No completed runs yet. Start and complete a training run first.
'; + if (runs.length === 0) { + list.innerHTML = '
No runs found.
'; return; } - data.runs.forEach(function(run, idx) { + runs.forEach(function(run) { var div = document.createElement('div'); var color = _getCompareRunColor(run.run_id); var isSelected = _selectedCompareRuns.has(run.run_id); @@ -914,15 +1043,14 @@ async function fetchCompareRuns() { 'background:' + (isSelected ? color + '11' : 'transparent') + ';transition:all 0.15s;'; var dot = ''; - var configLabel = run.config_name || run.config_id || '?'; - var runId = run.run_id || '?'; - var finalLoss = run.final_metrics && run.final_metrics.train_loss - ? run.final_metrics.train_loss.toFixed(4) : '--'; + var configLabel = run.label || run.run_id || '?'; + var stepInfo = run.step_count ? run.step_count + ' steps' : '--'; + var sourceTag = run._external ? 'EXT' : ''; div.innerHTML = dot + '
' + - '
' + configLabel + '
' + - '
' + runId + ' · loss: ' + finalLoss + '
' + + '
' + configLabel + sourceTag + '
' + + '
' + run.run_id + ' · ' + stepInfo + '
' + '
'; div.addEventListener('click', function() { @@ -942,6 +1070,22 @@ async function fetchCompareRuns() { }); } +async function fetchCompareRuns() { + var data = await api('GET', '/api/runs/discover'); + var allRuns = []; + if (data && data.runs) { + data.runs.forEach(function(run) { + if (run.run_id) _compareRunMeta[run.run_id] = run; + allRuns.push(run); + }); + } + // Merge external runs + _compareExternalRuns.forEach(function(run) { + allRuns.push(run); + }); + _renderCompareRunList(allRuns); +} + function _updateCompareMetricToggles() { var container = $('#compareMetricToggles'); if (!container) return; @@ -1069,32 +1213,57 @@ function _rebuildCompareChart() { // Mutation annotation plugin for compare chart var compareAnnotations = []; - // Build metric-level dash patterns: first metric solid, rest dashed variants - var metricDashPatterns = [[], [6, 3], [2, 2], [8, 4, 2, 4], [4, 2], [10, 3]]; + // Color by metric name (consistent across runs), dash pattern by run/experiment + var runDashPatterns = [[], [6, 3], [3, 3], [8, 3, 2, 3], [4, 2], [10, 3]]; + + // Pre-compute per-metric min/max for normalization (across all runs) + var _cmpMetricRange = {}; + if (_compareNormalize) { + enabledMetrics.forEach(function(metricName) { + var mn = Infinity, mx = -Infinity; + runIds.forEach(function(runId) { + var records = _compareAllData[runId] || []; + records.forEach(function(rec) { + var v = (rec.metrics || {})[metricName]; + if (typeof v === 'number') { + if (v < mn) mn = v; + if (v > mx) mx = v; + } + }); + }); + if (mx === mn) { mn -= 0.5; mx += 0.5; } + _cmpMetricRange[metricName] = {min: mn, max: mx}; + }); + } runIds.forEach(function(runId, runIdx) { var records = _compareAllData[runId] || []; - var color = _getCompareRunColor(runId); + var dashPattern = runDashPatterns[runIdx % runDashPatterns.length]; enabledMetrics.forEach(function(metricName, metricIdx) { var points = []; + var range = _cmpMetricRange[metricName]; records.forEach(function(rec) { var metrics = rec.metrics || {}; if (metricName in metrics) { - points.push({x: rec.step, y: metrics[metricName]}); + var v = metrics[metricName]; + var y = (_compareNormalize && range) ? (v - range.min) / (range.max - range.min) : v; + points.push({x: rec.step, y: y, _rawY: v}); } }); if (points.length === 0) return; - // All metrics for the same run share the same color, differentiated by dash pattern - var dashPattern = metricDashPatterns[metricIdx % metricDashPatterns.length]; + // Color by metric name for consistency across runs + var color = typeof getColor === 'function' ? 
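Normalization is plain per-metric min-max over all selected runs, with the degenerate flat-series case widened so the division stays finite; the raw value is stashed in `_rawY` so tooltips can still show real units. The mapping, as a standalone sketch:

```js
// Min-max normalize one value given the metric's global range across runs.
// A flat series (mx === mn) is widened by ±0.5 so y lands at 0.5, not NaN.
function normValue(v, mn, mx) {
  if (mx === mn) { mn -= 0.5; mx += 0.5; }
  return (v - mn) / (mx - mn);
}
// train_loss 2.3 with range [0.4, 2.5] -> (2.3 - 0.4) / 2.1 ≈ 0.905
```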
getColor(metricName) : _getCompareRunColor(runId); + var meta = _compareRunMeta[runId] || {}; + var runLabel = (meta.label || meta.config_name || runId).substring(0, 12); datasets.push({ - label: runId.substring(0, 8) + ' · ' + metricName, + label: runLabel + ' · ' + metricName, data: points, borderColor: color, backgroundColor: 'transparent', - tension: 0.3, + tension: 0.15, pointRadius: 0, borderWidth: 2, borderDash: dashPattern, @@ -1116,7 +1285,7 @@ function _rebuildCompareChart() { animation: false, scales: { x: {type: 'linear', title: {display: true, text: 'Step', color: '#7a8fa3', font:{size:11}}, ticks: {color:'#7a8fa3'}, grid: {color: 'rgba(30,46,68,0.5)'}}, - y: {title: {display: false}, ticks: {color:'#7a8fa3'}, grid: {color: 'rgba(30,46,68,0.3)'}}, + y: {title: {display: _compareNormalize, text: 'Normalized [0,1]', color: '#7a8fa3', font:{size:10}}, ticks: {color:'#7a8fa3'}, grid: {color: 'rgba(30,46,68,0.3)'}}, }, plugins: { legend: { @@ -1136,6 +1305,19 @@ function _rebuildCompareChart() { tooltip: { backgroundColor:'#121c2b', borderColor:'#2a4060', borderWidth:1, titleFont:{family:'JetBrains Mono',size:11}, bodyFont:{family:'JetBrains Mono',size:10}, + usePointStyle: false, boxWidth: 12, boxHeight: 2, + intersect: false, mode: 'index', axis: 'x', + filter: function(item) { + var label = item.dataset.label || ''; + return label.indexOf('mutations') === -1; + }, + itemSort: function(a, b) { + var chart = a.chart; + var cursorY = (chart && chart._lastEvent) ? chart._lastEvent.y : 0; + var ay = a.element ? a.element.y : 0; + var by = b.element ? b.element.y : 0; + return Math.abs(ay - cursorY) - Math.abs(by - cursorY); + }, callbacks: { label: function(ctx) { var raw = ctx.raw; @@ -1157,15 +1339,18 @@ function _rebuildCompareChart() { if (raw._metric) parts.push('on: ' + raw._metric); return parts; } - // Regular line tooltip — show run color swatch info + // Regular line tooltip — show raw value when normalized var ds = ctx.dataset; var val = typeof ctx.parsed.y === 'number' ? ctx.parsed.y.toPrecision(5) : ctx.parsed.y; - return ds.label + ': ' + val; + if (_compareNormalize && raw && raw._rawY !== undefined) { + return ' ' + ds.label + ': ' + fmtNum(raw._rawY); + } + return ' ' + ds.label + ': ' + val; } } }, }, - elements: { point: {radius:0, hoverRadius:3}, line: {tension:0.3} }, + elements: { point: {radius:0, hoverRadius:3}, line: {tension:0.15} }, }, }); @@ -1309,19 +1494,27 @@ async function updateCompareChart() { for (var i = 0; i < runIds.length; i++) { (function(runId) { - // Fetch metrics - var p1 = api('GET', '/api/train/runs/' + runId + '/metrics').then(function(data) { + var meta = _compareRunMeta[runId] || {}; + var isExternal = !!meta._external; + var runDir = meta.dir || ''; + + // Fetch metrics — use external API for external runs + var metricsUrl = isExternal + ? 
'/api/runs/external/metrics?dir=' + encodeURIComponent(runDir) + '&last_n=50000' + : '/api/train/runs/' + runId + '/metrics'; + var p1 = api('GET', metricsUrl).then(function(data) { if (data && data.records) { _compareAllData[runId] = data.records; - // Discover metric names + // Discover metric names — default only losses/key metric ON data.records.forEach(function(rec) { var metrics = rec.metrics || {}; Object.keys(metrics).forEach(function(name) { if (typeof metrics[name] === 'number') { _compareMetricNames.add(name); if (!(name in _compareEnabledMetrics)) { - // Default: enable all discovered metrics - _compareEnabledMetrics[name] = true; + var lower = name.toLowerCase(); + var isLoss = lower.indexOf('loss') !== -1; + _compareEnabledMetrics[name] = isLoss; } } }); @@ -1329,7 +1522,10 @@ async function updateCompareChart() { } }); // Fetch applied data for mutation markers - var p2 = api('GET', '/api/train/runs/' + runId + '/applied').then(function(data) { + var appliedUrl = isExternal + ? '/api/runs/external/applied?dir=' + encodeURIComponent(runDir) + '&last_n=200' + : '/api/train/runs/' + runId + '/applied'; + var p2 = api('GET', appliedUrl).then(function(data) { if (data && data.records) { _compareAllData['_applied_' + runId] = data.records; } @@ -1362,6 +1558,165 @@ async function updateCompareChart() { _rebuildCompareChart(); } +/* ================================================================ */ +/* End-of-run summary */ +/* ================================================================ */ + +function showRunSummary(configName) { + // Gather metric data + var metricNames = Object.keys(S.metricsData); + if (metricNames.length === 0) return; // no data, nothing to show + + // Build per-metric start/end/delta + var rows = []; + var maxStep = 0; + metricNames.forEach(function(name) { + var series = S.metricsData[name]; + if (!series || series.length === 0) return; + var first = series[0]; + var last = series[series.length - 1]; + if (last.step > maxStep) maxStep = last.step; + var delta = last.value - first.value; + var pctChange = first.value !== 0 ? (delta / Math.abs(first.value)) * 100 : 0; + rows.push({ + name: name, + start: first.value, + end: last.value, + delta: delta, + pctChange: pctChange, + direction: delta < -0.001 ? 'down' : delta > 0.001 ? 'up' : 'flat' + }); + }); + + // Sort by absolute delta (largest change first) + rows.sort(function(a, b) { return Math.abs(b.delta) - Math.abs(a.delta); }); + + // Count mutations + var mutationCount = S.appliedData ? S.appliedData.length : 0; + + // Find most impactful mutation + var bestMove = _findBestMutation(); + + // Build modal HTML + var overlay = document.createElement('div'); + overlay.className = 'run-summary-overlay'; + overlay.addEventListener('click', function(e) { + if (e.target === overlay) overlay.remove(); + }); + + var panel = document.createElement('div'); + panel.className = 'run-summary-panel'; + + var title = configName || 'Training'; + panel.innerHTML = '

<h2>Run Complete — ' + _esc(title) + '</h2>' +
+    '<div class="run-summary-stats">' + maxStep + ' steps · ' +
+    metricNames.length + ' metrics · ' + mutationCount + ' mutations applied</div>';
+
+  // Metrics table. NOTE: the markup strings in this function are a
+  // reconstruction; the original tag names were lost. Only the button IDs
+  // (summaryBtnCompare, summaryBtnClose, wired below) and the
+  // positive/negative/neutral change classes are known from surrounding code.
+  var table = '<table class="run-summary-table">' +
+    '<tr><th>Metric</th><th>Start</th><th>End</th><th>Change</th></tr>';
+  var showCount = Math.min(rows.length, 12);
+  for (var i = 0; i < showCount; i++) {
+    var r = rows[i];
+    var cls = r.direction === 'down' ? 'negative' : r.direction === 'up' ? 'positive' : 'neutral';
+    var arrow = r.direction === 'down' ? '▼' : r.direction === 'up' ? '▲' : '▬';
+    var sign = r.delta > 0 ? '+' : '';
+    table += '<tr>' +
+      '<td>' + _esc(r.name) + '</td>' +
+      '<td>' + _fmtNum(r.start) + '</td>' +
+      '<td>' + _fmtNum(r.end) + '</td>' +
+      '<td class="' + cls + '">' + sign + _fmtNum(r.delta) + ' (' + sign + r.pctChange.toFixed(1) + '%) ' + arrow + '</td>' +
+      '</tr>';
+  }
+  if (rows.length > showCount) {
+    table += '<tr><td colspan="4">' + (rows.length - showCount) + ' more metrics...</td></tr>';
+  }
+  table += '</table>';
+  panel.innerHTML += table;
+
+  // Best move highlight
+  if (bestMove) {
+    panel.innerHTML += '<div class="run-summary-best">Best move: ' +
+      _esc(bestMove.desc) + '</div>';
+  }
+
+  // Action buttons
+  panel.innerHTML += '<div class="run-summary-actions">' +
+    '<button id="summaryBtnCompare">Open in Compare</button>' +
+    '<button id="summaryBtnClose">Close</button>' +
+    '</div>
'; + + overlay.appendChild(panel); + document.body.appendChild(overlay); + + // Wire button events + var btnClose = document.getElementById('summaryBtnClose'); + if (btnClose) btnClose.addEventListener('click', function() { overlay.remove(); }); + var btnCompare = document.getElementById('summaryBtnCompare'); + if (btnCompare) btnCompare.addEventListener('click', function() { + overlay.remove(); + // Switch to Compare tab + var tab = document.querySelector('.tab[data-tab="compare"]'); + if (tab) tab.click(); + }); +} + +function _findBestMutation() { + if (!S.appliedData || S.appliedData.length === 0) return null; + + var best = null; + var bestImpact = 0; + + S.appliedData.forEach(function(mut) { + var step = mut.step; + if (!step) return; + + // Look at metric deltas in the 20 steps after this mutation + var metricNames = Object.keys(S.metricsData); + metricNames.forEach(function(name) { + var series = S.metricsData[name]; + if (!series || series.length < 3) return; + + // Find metric value at mutation step and 20 steps after + var atMut = null, afterMut = null; + for (var i = 0; i < series.length; i++) { + if (series[i].step >= step && atMut === null) atMut = series[i].value; + if (series[i].step >= step + 20 && afterMut === null) { + afterMut = series[i].value; + break; + } + } + if (atMut !== null && afterMut !== null) { + var impact = Math.abs(afterMut - atMut); + if (impact > bestImpact) { + bestImpact = impact; + var pct = atMut !== 0 ? ((afterMut - atMut) / Math.abs(atMut) * 100).toFixed(1) : '?'; + var dir = afterMut < atMut ? 'dropped' : 'increased'; + var desc = (mut.module || '?') + '.' + (mut.op || '?'); + if (mut.params) { + var keys = Object.keys(mut.params); + if (keys.length > 0) desc += ' (' + keys.map(function(k) { return k + '=' + mut.params[k]; }).join(', ') + ')'; + } + best = { desc: desc + ' at step ' + step + ' → ' + name + ' ' + dir + ' ' + pct + '%' }; + } + } + }); + }); + + return best; +} + +function _fmtNum(v) { + if (Math.abs(v) < 0.001 && v !== 0) return v.toExponential(2); + if (Math.abs(v) >= 1000) return v.toFixed(1); + return v.toPrecision(4); +} + +function _esc(s) { + var d = document.createElement('div'); + d.textContent = s; + return d.innerHTML; +} + /* ================================================================ */ /* Multi-run selector */ /* ================================================================ */ diff --git a/src/hotcb/server/static/js/state.js b/src/hotcb/server/static/js/state.js index da6645d..89f93c2 100644 --- a/src/hotcb/server/static/js/state.js +++ b/src/hotcb/server/static/js/state.js @@ -5,6 +5,7 @@ const COLORS = ['#00d4aa','#3d9eff','#ff9833','#ff4d5e','#9966ff','#33dd77','#ffd233','#33ccdd','#ff66aa','#aabb44']; const S = { + config: null, // populated from /api/config at startup ws: null, metricsData: {}, // {metricName: [{step, value}]} appliedData: [], @@ -54,11 +55,19 @@ function saveUIState() { state.pinnedMetrics = Array.from(S.pinnedMetrics || []); } // Don't persist focusMetric — it should reset with the session - // Knobs - var lr = document.getElementById('knobLr'); - var wd = document.getElementById('knobWd'); - if (lr) state.knobs.lr = lr.value; - if (wd) state.knobs.wd = wd.value; + // Knobs — save all dynamic control values + var panel = document.getElementById('knobPanel'); + if (panel) { + panel.querySelectorAll('.knob-row[data-param]').forEach(function(row) { + var param = row.dataset.param; + var slider = row.querySelector('.knob-slider'); + var toggle = row.querySelector('.knob-toggle'); + var 
select = row.querySelector('.knob-select'); + if (toggle) state.knobs[param] = toggle.checked; + else if (select) state.knobs[param] = select.value; + else if (slider) state.knobs[param] = slider.value; + }); + } // Metric visibility if (S.metricNames) { S.metricNames.forEach(function(name) { diff --git a/src/hotcb/server/static/js/utils.js b/src/hotcb/server/static/js/utils.js index 6d4d8c9..b45cb97 100644 --- a/src/hotcb/server/static/js/utils.js +++ b/src/hotcb/server/static/js/utils.js @@ -10,6 +10,10 @@ async function api(method, path, body) { if (body) opts.body = JSON.stringify(body); try { const r = await fetch(path, opts); + if (!r.ok) { + console.warn('API', r.status, path); + try { return await r.json(); } catch(_) { return null; } + } return await r.json(); } catch (e) { console.error('API error:', path, e); return null; } } diff --git a/src/hotcb/server/static/js/websocket.js b/src/hotcb/server/static/js/websocket.js index 33ff630..0858456 100644 --- a/src/hotcb/server/static/js/websocket.js +++ b/src/hotcb/server/static/js/websocket.js @@ -2,12 +2,18 @@ * hotcb dashboard — WebSocket connection and message handling */ +var _wsRetryCount = 0; +var _wsMaxRetries = (S.config && S.config.server) ? S.config.server.ws_max_retries : 20; +var _slidersInitialized = false; +var _lastSeenMaxStep = 0; // track for run-reset detection + function connectWS() { var proto = location.protocol === 'https:' ? 'wss:' : 'ws:'; var ws = new WebSocket(proto + '//' + location.host + '/ws'); S.ws = ws; ws.onopen = function() { + _wsRetryCount = 0; $('#wsStatus').className = 'status-dot ok'; $('#wsLabel').textContent = 'connected'; ws.send(JSON.stringify({channels: ['metrics', 'applied', 'mutations', 'segments']})); @@ -16,7 +22,13 @@ function connectWS() { ws.onclose = function() { $('#wsStatus').className = 'status-dot error'; $('#wsLabel').textContent = 'disconnected'; - setTimeout(connectWS, 3000); + if (_wsRetryCount < _wsMaxRetries) { + var _wsRetryBaseMs = ((S.config && S.config.server) ? S.config.server.ws_retry_base : 3) * 1000; + var _wsRetryCapMs = ((S.config && S.config.server) ? S.config.server.ws_retry_cap : 30) * 1000; + var delay = Math.min(_wsRetryBaseMs * Math.pow(1.5, _wsRetryCount), _wsRetryCapMs); + _wsRetryCount++; + setTimeout(connectWS, delay); + } }; ws.onerror = function() { @@ -32,6 +44,53 @@ function connectWS() { if (!Array.isArray(data)) return; if (ch === 'metrics') { + if (typeof dismissChartWaiting === 'function') dismissChartWaiting(); + // Skip WS initial burst for metrics — REST initialLoad() already fetched + // the full history (50k records). WS burst is only ~500 records and would + // cause duplicates / data loss. Only accept live incremental updates. 
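+      // Assumed message shapes, inferred from this handler (the actual server
+      // payload may carry more fields):
+      //   replay burst: {channel: 'metrics', initial: true, data: [{step, metrics: {...}}, ...]}
+      //   live update:  {channel: 'metrics', data: [{step, metrics: {...}}, ...]}
+      // The burst is only used to seed _lastSeenMaxStep and the one-time
+      // slider sync below; it is never appended to S.metricsData.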
+ if (msg.initial) { + // Just update _lastSeenMaxStep from the burst so run-reset detection works + if (data.length > 0) { + var lastRec = data[data.length - 1]; + var bStep = lastRec.step || 0; + if (bStep > _lastSeenMaxStep) _lastSeenMaxStep = bStep; + } + // Sync sliders from burst data if not yet done + if (!_slidersInitialized && data.length > 0) { + _slidersInitialized = true; + var last = data[data.length - 1]; + var lm = last.metrics || {}; + var syncObj = {}; + var specs = (S.config && S.config.controls) || []; + specs.forEach(function(spec) { + if (lm[spec.param_key] !== undefined) syncObj[spec.param_key] = lm[spec.param_key]; + }); + if (lm.lr && lm.lr > 0) syncObj.lr = lm.lr; + if (lm.weight_decay && lm.weight_decay > 0) syncObj.weight_decay = lm.weight_decay; + if (Object.keys(syncObj).length > 0) syncSlidersFromApplied(syncObj); + } + return; // skip adding initial burst data — already loaded via REST + } + // Detect run reset: if incoming steps jump backwards, clear stale state + if (data.length > 0) { + var firstIncoming = data[0].step || 0; + if (_lastSeenMaxStep > 0 && firstIncoming < _lastSeenMaxStep - 10) { + // Steps went backwards — new run started. Flush stale caches. + S.metricsData = {}; + S.appliedData = []; + S.metricNames = new Set(); + S.latestMetrics = {}; + if (typeof _forecastCache !== 'undefined') _forecastCache = {}; + if (typeof _highlightedMutationStep !== 'undefined') _highlightedMutationStep = null; + _lastSeenMaxStep = 0; + _slidersInitialized = false; + var tl = document.getElementById('timelineList'); + if (tl) tl.innerHTML = ''; + if (typeof clearTimelineDedup === 'function') clearTimelineDedup(); + var mc = document.getElementById('mutationCount'); + if (mc) mc.textContent = '0'; + } + } var prevSize = S.metricNames.size; data.forEach(function(rec) { var step = rec.step || 0; @@ -42,10 +101,10 @@ function connectWS() { if (typeof value !== 'number') return; S.metricNames.add(name); if (!S.metricsData[name]) S.metricsData[name] = []; - S.metricsData[name].push({step: step, value: value}); - if (S.metricsData[name].length > 5000) { - S.metricsData[name] = S.metricsData[name].slice(-4000); - } + // Skip duplicate steps (can happen with REST + WS initial burst overlap) + var arr = S.metricsData[name]; + if (arr.length > 0 && arr[arr.length - 1].step >= step) return; + arr.push({step: step, value: value}); }); }); if (S.metricNames.size !== prevSize) { @@ -61,8 +120,12 @@ function connectWS() { var pts = S.metricsData[name] || []; if (pts.length) maxStep = Math.max(maxStep, pts[pts.length-1].step); }); + _lastSeenMaxStep = maxStep; var stepEl = document.getElementById('stepValue'); if (stepEl) stepEl.textContent = maxStep; + // Update compare status bar step counter + var cmpStepEl = document.getElementById('cmpStepStatus'); + if (cmpStepEl) cmpStepEl.textContent = 'Step: ' + maxStep; // Trigger forecast refresh on new data if (typeof onNewMetricsForForecast === 'function') onNewMetricsForForecast(maxStep); } @@ -80,9 +143,10 @@ function connectWS() { data.forEach(function(rec) { S.appliedData.push(rec); addTimelineItem(rec); - // Sync slider knobs from applied params - if (rec.params) { - syncSlidersFromApplied(rec.params); + // Sync slider knobs from applied params (check both fields) + var syncParams = rec.params || rec.payload; + if (syncParams) { + syncSlidersFromApplied(syncParams); } }); } diff --git a/src/hotcb/server/tailer.py b/src/hotcb/server/tailer.py index 43d4493..69d7882 100644 --- a/src/hotcb/server/tailer.py +++ 
b/src/hotcb/server/tailer.py @@ -119,7 +119,11 @@ async def poll_once(self) -> Dict[str, List[dict]]: async def _poll_target(self, target: TailTarget) -> List[dict]: """Read new records from one target and dispatch to subscribers.""" try: - records, new_cursor = read_new_jsonl(target.cursor) + # Run blocking file I/O in a thread to avoid blocking the event loop + loop = asyncio.get_running_loop() + records, new_cursor = await loop.run_in_executor( + None, read_new_jsonl, target.cursor + ) target.cursor = new_cursor except Exception as e: log.warning("Tailer error on %s: %s", target.name, e) diff --git a/src/hotcb/tests/test_actuator_unified.py b/src/hotcb/tests/test_actuator_unified.py new file mode 100644 index 0000000..1ab8882 --- /dev/null +++ b/src/hotcb/tests/test_actuator_unified.py @@ -0,0 +1,786 @@ +"""Comprehensive tests for the unified actuator types (Phase 2). + +Tests cover: ActuatorType validation, state machine transitions, mutation +tracking, apply_fn behaviour, snapshot/restore, describe_space, and the +convenience constructors (optimizer_actuators, loss_actuators, mutable_state). +""" + +from __future__ import annotations + +import pytest + +from hotcb.actuators import ( + ApplyResult, + ValidationResult, + mutable_state, + optimizer_actuators, + loss_actuators, +) +from hotcb.actuators.actuator import ( + ActuatorState, + ActuatorType, + HotcbActuator, + Mutation, + _INIT_SENTINEL, +) +from hotcb.actuators.state import MutableState + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class MockOptimizer: + def __init__(self, lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)): + self.param_groups = [{"lr": lr, "weight_decay": weight_decay, "betas": betas}] + + +class MockScheduler: + def __init__(self, base_lrs): + self.base_lrs = list(base_lrs) + + +def _noop_apply(value, env): + return ApplyResult(success=True, detail={"applied": value}) + + +def _failing_apply(value, env): + return ApplyResult(success=False, error="intentional_failure") + + +def _raising_apply(value, env): + raise RuntimeError("boom") + + +def _make_float_actuator(key="x", min_v=0.0, max_v=1.0, current=0.5, apply_fn=None): + return HotcbActuator( + param_key=key, + type=ActuatorType.FLOAT, + apply_fn=apply_fn or _noop_apply, + min_value=min_v, + max_value=max_v, + current_value=current, + ) + + +# =================================================================== +# ActuatorType & validation +# =================================================================== + +class TestActuatorValidation: + + def test_float_actuator_validate_in_bounds(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_noop_apply, + min_value=0.0, max_value=1.0, + ) + vr = act.validate(0.5) + assert vr.valid + assert vr.errors == [] + + def test_float_actuator_validate_out_of_bounds(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_noop_apply, + min_value=0.0, max_value=1.0, + ) + vr = act.validate(1.5) + assert not vr.valid + assert any("above max" in e for e in vr.errors) + + vr_low = act.validate(-0.1) + assert not vr_low.valid + assert any("below min" in e for e in vr_low.errors) + + def test_log_float_actuator_validate(self): + act = HotcbActuator( + param_key="lr", type=ActuatorType.LOG_FLOAT, apply_fn=_noop_apply, + min_value=1e-7, max_value=1.0, + ) + vr = act.validate(1e-4) + assert vr.valid + + vr_neg = act.validate(-1.0) + assert not 
vr_neg.valid + assert any("log_float must be positive" in e for e in vr_neg.errors) + + def test_bool_actuator_validate(self): + act = HotcbActuator( + param_key="flag", type=ActuatorType.BOOL, apply_fn=_noop_apply, + ) + assert act.validate(True).valid + assert act.validate(False).valid + + vr = act.validate("yes") + assert not vr.valid + assert any("expected bool" in e for e in vr.errors) + + def test_int_actuator_validate(self): + act = HotcbActuator( + param_key="n", type=ActuatorType.INT, apply_fn=_noop_apply, + min_value=0, max_value=100, + ) + assert act.validate(50).valid + + vr_float = act.validate(50.5) + assert not vr_float.valid + assert any("expected int" in e for e in vr_float.errors) + + vr_bool = act.validate(True) + assert not vr_bool.valid + assert any("expected int" in e for e in vr_bool.errors) + + vr_oob = act.validate(101) + assert not vr_oob.valid + assert any("above max" in e for e in vr_oob.errors) + + def test_choice_actuator_validate(self): + act = HotcbActuator( + param_key="opt", type=ActuatorType.CHOICE, apply_fn=_noop_apply, + choices=["adam", "sgd", "adamw"], + ) + assert act.validate("adam").valid + + vr = act.validate("rmsprop") + assert not vr.valid + assert any("not in choices" in e for e in vr.errors) + + def test_tuple_actuator_validate(self): + act = HotcbActuator( + param_key="betas", type=ActuatorType.TUPLE, apply_fn=_noop_apply, + ) + assert act.validate((0.9, 0.999)).valid + assert act.validate([0.9, 0.999]).valid + + vr = act.validate("not a tuple") + assert not vr.valid + assert any("expected tuple/list" in e for e in vr.errors) + + +# =================================================================== +# State machine +# =================================================================== + +class TestStateMachine: + + def test_initial_state_is_init(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_noop_apply, + ) + assert act.state == ActuatorState.INIT + assert act.current_value is _INIT_SENTINEL + + def test_initialize_transitions_to_untouched(self): + opt = MockOptimizer(lr=1e-3) + acts = optimizer_actuators(opt) + ms = MutableState(acts) + + lr_act = ms.get("lr") + assert lr_act is not None + assert lr_act.state == ActuatorState.INIT # set by constructor, but current_value populated + + ms.initialize(env={}) + assert lr_act.state == ActuatorState.UNTOUCHED + + def test_apply_transitions_to_unverified(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + assert act.state == ActuatorState.UNTOUCHED + + result = ms.apply("x", 0.7, {}, step=10) + assert result.success + assert act.state == ActuatorState.UNVERIFIED + + def test_verify_transitions_to_verified(self): + act = _make_float_actuator(key="x", current=0.5) + act.metrics_dict_name = "x" + ms = MutableState([act]) + ms.initialize(env={}) + ms.apply("x", 0.7, {}, step=10) + assert act.state == ActuatorState.UNVERIFIED + + verified = ms.verify("x", {"x": 0.7}) + assert verified + assert act.state == ActuatorState.VERIFIED + + def test_apply_after_verified_goes_back_to_unverified(self): + act = _make_float_actuator(key="x", current=0.5) + act.metrics_dict_name = "x" + ms = MutableState([act]) + ms.initialize(env={}) + ms.apply("x", 0.7, {}, step=10) + ms.verify("x", {"x": 0.7}) + assert act.state == ActuatorState.VERIFIED + + result = ms.apply("x", 0.3, {}, step=20) + assert result.success + assert act.state == ActuatorState.UNVERIFIED + + def test_disabled_actuator_rejects_apply(self): + act = 
_make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + ms.disable("x") + assert act.state == ActuatorState.DISABLED + + result = ms.apply("x", 0.7, {}, step=10) + assert not result.success + assert "actuator_disabled" in result.error + + def test_disable_actuator(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + assert act.state == ActuatorState.UNTOUCHED + + ms.disable("x") + assert act.state == ActuatorState.DISABLED + + def test_enable_after_disable(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + ms.disable("x") + assert act.state == ActuatorState.DISABLED + + ms.enable("x") + assert act.state == ActuatorState.UNTOUCHED + + +# =================================================================== +# Mutation tracking +# =================================================================== + +class TestMutationTracking: + + def test_mutation_recorded_on_apply(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + + result = ms.apply("x", 0.7, {}, step=10) + assert result.success + assert len(act.mutations) == 1 + m = act.mutations[0] + assert m.step == 10 + assert m.old_value == 0.5 + assert m.new_value == 0.7 + assert m.verified is False + + def test_multiple_mutations_accumulated(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + + ms.apply("x", 0.6, {}, step=10) + ms.apply("x", 0.7, {}, step=20) + ms.apply("x", 0.8, {}, step=30) + + assert len(act.mutations) == 3 + assert [m.new_value for m in act.mutations] == [0.6, 0.7, 0.8] + + def test_last_changed_step_updated(self): + act = _make_float_actuator(current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + + ms.apply("x", 0.7, {}, step=50) + assert act.last_changed_step == 50 + + +# =================================================================== +# apply_fn behaviour +# =================================================================== + +class TestApplyFn: + + def test_apply_fn_receives_value_and_env(self): + received = {} + + def capture_apply(value, env): + received["value"] = value + received["env"] = env + return ApplyResult(success=True) + + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=capture_apply, + min_value=0.0, max_value=1.0, current_value=0.5, + ) + ms = MutableState([act]) + ms.initialize(env={}) + + test_env = {"key": "val"} + ms.apply("x", 0.8, test_env, step=1) + + assert received["value"] == 0.8 + assert received["env"] is test_env + + def test_apply_fn_failure_does_not_mutate_state(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_failing_apply, + min_value=0.0, max_value=1.0, current_value=0.5, + ) + ms = MutableState([act]) + ms.initialize(env={}) + + result = ms.apply("x", 0.7, {}, step=10) + assert not result.success + # State should not change + assert act.current_value == 0.5 + assert len(act.mutations) == 0 + assert act.state == ActuatorState.UNTOUCHED + + def test_apply_fn_exception_caught(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_raising_apply, + min_value=0.0, max_value=1.0, current_value=0.5, + ) + ms = MutableState([act]) + ms.initialize(env={}) + + result = ms.apply("x", 0.7, {}, step=10) + assert not result.success + assert "apply_fn_exception" in result.error + assert "boom" in result.error + # State not corrupted + assert act.current_value == 0.5 + assert 
len(act.mutations) == 0 + assert act.state == ActuatorState.UNTOUCHED + + +# =================================================================== +# Snapshot / restore +# =================================================================== + +class TestSnapshotRestore: + + def test_snapshot_all(self): + a1 = _make_float_actuator(key="lr", current=1e-3) + a2 = _make_float_actuator(key="wd", current=1e-4) + ms = MutableState([a1, a2]) + ms.initialize(env={}) + + snap = ms.snapshot_all() + assert "lr" in snap + assert "wd" in snap + assert snap["lr"]["value"] == 1e-3 + assert snap["lr"]["state"] == "untouched" + assert snap["wd"]["value"] == 1e-4 + + def test_restore_from_snapshot(self): + # Track the "live" value so we can verify restore actually calls apply_fn + live = {"lr": 1e-3, "wd": 1e-4} + + def make_apply(key): + def _apply(value, env): + live[key] = value + return ApplyResult(success=True) + return _apply + + a1 = HotcbActuator( + param_key="lr", type=ActuatorType.FLOAT, apply_fn=make_apply("lr"), + min_value=0.0, max_value=1.0, current_value=1e-3, + ) + a2 = HotcbActuator( + param_key="wd", type=ActuatorType.FLOAT, apply_fn=make_apply("wd"), + min_value=0.0, max_value=1.0, current_value=1e-4, + ) + ms = MutableState([a1, a2]) + ms.initialize(env={}) + + # Snapshot + snap = ms.snapshot_all() + + # Apply mutations + ms.apply("lr", 5e-4, {}, step=10) + ms.apply("wd", 5e-5, {}, step=10) + assert live["lr"] == 5e-4 + assert live["wd"] == 5e-5 + + # Restore + results = ms.restore_all(snap, {}) + assert results["lr"].success + assert results["wd"].success + assert live["lr"] == 1e-3 + assert live["wd"] == 1e-4 + assert a1.current_value == 1e-3 + assert a2.current_value == 1e-4 + + +# =================================================================== +# describe_space +# =================================================================== + +class TestDescribeSpace: + + def test_describe_space_includes_all_fields(self): + act = HotcbActuator( + param_key="lr", + type=ActuatorType.LOG_FLOAT, + apply_fn=_noop_apply, + label="Learning Rate", + group="optimizer", + min_value=1e-7, + max_value=1.0, + step_size=0.01, + log_base=10.0, + current_value=1e-3, + ) + d = act.describe_space() + + assert d["param_key"] == "lr" + assert d["type"] == "log_float" + assert d["label"] == "Learning Rate" + assert d["group"] == "optimizer" + assert d["min"] == 1e-7 + assert d["max"] == 1.0 + assert d["step"] == 0.01 + assert d["log_base"] == 10.0 + assert d["choices"] is None + assert d["current"] == 1e-3 + assert d["state"] == "init" + + def test_describe_space_current_none_for_init_sentinel(self): + act = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_noop_apply, + ) + d = act.describe_space() + assert d["current"] is None + + def test_describe_space_log_base_only_for_log_float(self): + act_float = HotcbActuator( + param_key="x", type=ActuatorType.FLOAT, apply_fn=_noop_apply, + ) + assert act_float.describe_space()["log_base"] is None + + act_log = HotcbActuator( + param_key="y", type=ActuatorType.LOG_FLOAT, apply_fn=_noop_apply, + log_base=2.0, + ) + assert act_log.describe_space()["log_base"] == 2.0 + + def test_describe_all(self): + a1 = _make_float_actuator(key="a") + a2 = _make_float_actuator(key="b") + a3 = _make_float_actuator(key="c") + ms = MutableState([a1, a2, a3]) + + descs = ms.describe_all() + assert len(descs) == 3 + keys = [d["param_key"] for d in descs] + assert keys == ["a", "b", "c"] + + def test_describe_all_excludes_disabled(self): + a1 = _make_float_actuator(key="a") + 
a2 = _make_float_actuator(key="b") + ms = MutableState([a1, a2]) + ms.disable("b") + + descs = ms.describe_all() + assert len(descs) == 1 + assert descs[0]["param_key"] == "a" + + +# =================================================================== +# Convenience constructors — optimizer_actuators +# =================================================================== + +class TestOptimizerActuators: + + def test_optimizer_actuators_from_torch_optimizer(self): + opt = MockOptimizer(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)) + acts = optimizer_actuators(opt) + + assert len(acts) == 3 + keys = {a.param_key for a in acts} + assert keys == {"lr", "weight_decay", "betas"} + + lr_act = next(a for a in acts if a.param_key == "lr") + assert lr_act.type == ActuatorType.LOG_FLOAT + assert lr_act.current_value == 1e-3 + assert lr_act.group == "optimizer" + + wd_act = next(a for a in acts if a.param_key == "weight_decay") + assert wd_act.type == ActuatorType.LOG_FLOAT + assert wd_act.current_value == 1e-4 + + betas_act = next(a for a in acts if a.param_key == "betas") + assert betas_act.type == ActuatorType.TUPLE + assert betas_act.current_value == (0.9, 0.999) + + def test_optimizer_actuators_bounds(self): + opt = MockOptimizer() + acts = optimizer_actuators(opt, lr_bounds=(1e-6, 0.1), wd_bounds=(0.0, 0.5)) + + lr_act = next(a for a in acts if a.param_key == "lr") + assert lr_act.min_value == 1e-6 + assert lr_act.max_value == 0.1 + + wd_act = next(a for a in acts if a.param_key == "weight_decay") + assert wd_act.min_value == 0.0 + assert wd_act.max_value == 0.5 + + def test_optimizer_actuators_apply_fn_sets_param_groups(self): + opt = MockOptimizer(lr=1e-3) + # Add a second param group + opt.param_groups.append({"lr": 1e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}) + acts = optimizer_actuators(opt) + + lr_act = next(a for a in acts if a.param_key == "lr") + result = lr_act.apply_fn(5e-4, {"optimizer": opt}) + assert result.success + + # All param groups updated + for g in opt.param_groups: + assert g["lr"] == 5e-4 + + def test_optimizer_actuators_apply_fn_coordinates_scheduler(self): + opt = MockOptimizer(lr=1e-3) + sched = MockScheduler(base_lrs=[1e-3]) + acts = optimizer_actuators(opt) + + lr_act = next(a for a in acts if a.param_key == "lr") + result = lr_act.apply_fn(5e-4, {"optimizer": opt, "scheduler": sched}) + assert result.success + + assert opt.param_groups[0]["lr"] == 5e-4 + assert sched.base_lrs == [5e-4] + + def test_optimizer_actuators_wd_apply_fn(self): + opt = MockOptimizer(weight_decay=1e-4) + opt.param_groups.append({"lr": 1e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}) + acts = optimizer_actuators(opt) + + wd_act = next(a for a in acts if a.param_key == "weight_decay") + result = wd_act.apply_fn(5e-5, {}) + assert result.success + for g in opt.param_groups: + assert g["weight_decay"] == 5e-5 + + def test_optimizer_actuators_betas_apply_fn(self): + opt = MockOptimizer(betas=(0.9, 0.999)) + acts = optimizer_actuators(opt) + + betas_act = next(a for a in acts if a.param_key == "betas") + result = betas_act.apply_fn([0.85, 0.99], {}) + assert result.success + assert opt.param_groups[0]["betas"] == (0.85, 0.99) + + def test_optimizer_without_betas(self): + """Optimizer without betas (e.g. 
SGD) produces only lr + wd actuators.""" + opt = MockOptimizer() + del opt.param_groups[0]["betas"] + acts = optimizer_actuators(opt) + keys = {a.param_key for a in acts} + assert "betas" not in keys + assert "lr" in keys + assert "weight_decay" in keys + + +# =================================================================== +# Convenience constructors — loss_actuators +# =================================================================== + +class TestLossActuators: + + def test_loss_actuators_from_dict(self): + weights = {"recon": 1.0, "kl": 0.5, "perceptual": 0.3} + acts = loss_actuators(weights) + + assert len(acts) == 3 + keys = {a.param_key for a in acts} + assert keys == {"recon", "kl", "perceptual"} + + for a in acts: + assert a.type == ActuatorType.FLOAT + assert a.group == "loss" + + recon_act = next(a for a in acts if a.param_key == "recon") + assert recon_act.current_value == 1.0 + + def test_loss_actuators_apply_fn_mutates_dict(self): + weights = {"recon": 1.0, "kl": 0.5} + acts = loss_actuators(weights) + + recon_act = next(a for a in acts if a.param_key == "recon") + result = recon_act.apply_fn(2.0, {}) + assert result.success + assert weights["recon"] == 2.0 # original dict mutated + + def test_loss_actuators_bounds(self): + weights = {"recon": 1.0, "kl": 0.5} + acts = loss_actuators(weights, global_bounds=(0.0, 10.0)) + + for a in acts: + assert a.min_value == 0.0 + assert a.max_value == 10.0 + + def test_loss_actuators_key_bounds(self): + weights = {"recon": 1.0, "kl": 0.5} + acts = loss_actuators(weights, key_bounds={"kl": (0.0, 2.0)}) + + kl_act = next(a for a in acts if a.param_key == "kl") + recon_act = next(a for a in acts if a.param_key == "recon") + + assert kl_act.min_value == 0.0 + assert kl_act.max_value == 2.0 + # recon uses global_bounds default + assert recon_act.min_value == 0.0 + assert recon_act.max_value == 100.0 + + +# =================================================================== +# Convenience constructors — mutable_state +# =================================================================== + +class TestMutableStateConstructor: + + def test_mutable_state_constructor(self): + a1 = _make_float_actuator(key="lr") + a2 = _make_float_actuator(key="wd") + a3 = _make_float_actuator(key="recon_w") + + ms = mutable_state([a1, a2, a3]) + assert isinstance(ms, MutableState) + assert ms.keys() == ["lr", "wd", "recon_w"] + assert len(ms) == 3 + assert "lr" in ms + assert "missing" not in ms + + +# =================================================================== +# MutableState container basics +# =================================================================== + +class TestMutableStateContainer: + + def test_get_returns_actuator(self): + act = _make_float_actuator(key="x") + ms = MutableState([act]) + assert ms.get("x") is act + + def test_get_returns_none_for_missing(self): + ms = MutableState([]) + assert ms.get("missing") is None + + def test_apply_unknown_key(self): + ms = MutableState([]) + result = ms.apply("missing", 1.0, {}, step=0) + assert not result.success + assert "unknown_param" in result.error + + def test_verify_nonexistent_key(self): + ms = MutableState([]) + assert not ms.verify("missing", {}) + + def test_verify_no_metrics_dict_name(self): + act = _make_float_actuator(key="x", current=0.5) + ms = MutableState([act]) + ms.initialize(env={}) + ms.apply("x", 0.7, {}, step=10) + # No metrics_dict_name set + assert not ms.verify("x", {"x": 0.7}) + + def test_verify_wrong_state(self): + act = _make_float_actuator(key="x", current=0.5) + 
act.metrics_dict_name = "x" + ms = MutableState([act]) + ms.initialize(env={}) + # In UNTOUCHED, not UNVERIFIED + assert not ms.verify("x", {"x": 0.5}) + + def test_disable_nonexistent_key(self): + ms = MutableState([]) + # Should not raise + ms.disable("missing") + + def test_enable_nonexistent_key(self): + ms = MutableState([]) + # Should not raise + ms.enable("missing") + + def test_enable_non_disabled_is_noop(self): + act = _make_float_actuator(key="x") + ms = MutableState([act]) + ms.initialize(env={}) + assert act.state == ActuatorState.UNTOUCHED + ms.enable("x") # Should be noop since not disabled + assert act.state == ActuatorState.UNTOUCHED + + +# =================================================================== +# Integration: full end-to-end flow +# =================================================================== + +class TestIntegration: + + def test_full_optimizer_flow(self): + """Full lifecycle: create, initialize, apply, verify, snapshot, restore.""" + opt = MockOptimizer(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)) + sched = MockScheduler(base_lrs=[1e-3]) + acts = optimizer_actuators(opt) + ms = MutableState(acts) + + # Initialize + ms.initialize(env={}) + for a in acts: + assert a.state == ActuatorState.UNTOUCHED + + # Snapshot + snap = ms.snapshot_all() + + # Apply lr change + env = {"optimizer": opt, "scheduler": sched} + result = ms.apply("lr", 5e-4, env, step=10) + assert result.success + assert opt.param_groups[0]["lr"] == 5e-4 + assert sched.base_lrs == [5e-4] + + lr_act = ms.get("lr") + assert lr_act.state == ActuatorState.UNVERIFIED + assert lr_act.current_value == 5e-4 + + # Verify + ms.verify("lr", {"lr": 5e-4}) + assert lr_act.state == ActuatorState.VERIFIED + + # Restore snapshot + results = ms.restore_all(snap, env) + assert all(r.success for r in results.values()) + assert opt.param_groups[0]["lr"] == 1e-3 + + def test_full_loss_flow(self): + """Full lifecycle for loss actuators.""" + weights = {"recon": 1.0, "kl": 0.5} + acts = loss_actuators(weights) + ms = MutableState(acts) + + ms.initialize(env={}) + + # Apply + result = ms.apply("recon", 2.0, {}, step=5) + assert result.success + assert weights["recon"] == 2.0 + + # Describe + descs = ms.describe_all() + assert len(descs) == 2 + recon_desc = next(d for d in descs if d["param_key"] == "recon") + assert recon_desc["current"] == 2.0 + assert recon_desc["state"] == "unverified" + + def test_mixed_actuators(self): + """Optimizer + loss actuators in one MutableState.""" + opt = MockOptimizer(lr=1e-3, weight_decay=1e-4, betas=(0.9, 0.999)) + weights = {"recon": 1.0, "kl": 0.5} + + all_acts = optimizer_actuators(opt) + loss_actuators(weights) + ms = MutableState(all_acts) + ms.initialize(env={}) + + assert len(ms) == 5 # lr, wd, betas, recon, kl + assert ms.keys() == ["lr", "weight_decay", "betas", "recon", "kl"] + + descs = ms.describe_all() + groups = {d["group"] for d in descs} + assert groups == {"optimizer", "loss"} diff --git a/src/hotcb/tests/test_dashboard_config.py b/src/hotcb/tests/test_dashboard_config.py new file mode 100644 index 0000000..26d414c --- /dev/null +++ b/src/hotcb/tests/test_dashboard_config.py @@ -0,0 +1,712 @@ +"""Tests for hotcb.server.config — DashboardConfig dataclasses + loader.""" +from __future__ import annotations + +import json +import os +import tempfile + +import pytest + +from hotcb.server.config import ( + AutopilotConfig, + ChartConfig, + DashboardConfig, + ServerConfig, + UIConfig, +) + +# Skip all endpoint tests if fastapi not installed +fastapi = 
pytest.importorskip("fastapi") +httpx = pytest.importorskip("httpx") + + +@pytest.fixture +def tmp_dir(): + with tempfile.TemporaryDirectory() as d: + yield d + + +class TestConfigDefaults: + """test_config_defaults — all sub-configs have documented defaults.""" + + def test_server_defaults(self): + s = ServerConfig() + assert s.host == "0.0.0.0" + assert s.port == 8421 + assert s.poll_interval == 0.5 + assert s.history_limit_metrics == 500 + assert s.history_limit_applied == 200 + assert s.ws_initial_burst == 200 + assert s.ws_max_retries == 20 + assert s.ws_retry_base == 3.0 + assert s.ws_retry_cap == 30.0 + + def test_chart_defaults(self): + c = ChartConfig() + assert c.max_render_points == 2000 + assert c.line_tension == 0.15 + assert c.forecast_dash == (6, 3) + assert c.mutation_dash == (3, 4) + assert c.annotation_stagger_rows == 10 + assert c.annotation_min_distance == 70 + + def test_autopilot_defaults(self): + a = AutopilotConfig() + assert a.divergence_threshold == 2.0 + assert a.ratio_threshold == 0.5 + assert a.ai_min_interval == 10 + assert a.ai_max_wait == 200 + assert a.ai_default_cadence == 50 + + def test_ui_defaults(self): + u = UIConfig() + assert u.state_save_interval == 5000 + assert u.alert_poll_interval == 15000 + assert u.manifold_refresh_interval == 10000 + assert u.recipe_refresh_interval == 5000 + assert u.forecast_poll_interval == 5000 + assert u.forecast_step_cadence == 10 + assert u.forecast_batch_size == 8 + assert u.staged_change_threshold == 0.005 + assert u.health_ema_alpha == 0.1 + + def test_dashboard_config_all_defaults(self): + cfg = DashboardConfig() + assert isinstance(cfg.server, ServerConfig) + assert isinstance(cfg.chart, ChartConfig) + assert isinstance(cfg.autopilot, AutopilotConfig) + assert isinstance(cfg.ui, UIConfig) + assert cfg.run_dir == "" + assert cfg.controls == [] + + +class TestConfigFromYaml: + """test_config_from_yaml — YAML overrides applied, others preserved.""" + + def test_config_from_yaml(self, tmp_dir): + yaml_path = os.path.join(tmp_dir, "dashboard.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n port: 9000\nchart:\n line_tension: 0.3\n") + + cfg = DashboardConfig.load(tmp_dir, yaml_path=yaml_path) + # Overrides applied + assert cfg.server.port == 9000 + assert cfg.chart.line_tension == 0.3 + # Other defaults preserved + assert cfg.server.host == "0.0.0.0" + assert cfg.server.poll_interval == 0.5 + assert cfg.chart.max_render_points == 2000 + assert cfg.autopilot.divergence_threshold == 2.0 + + def test_config_from_yaml_missing_file(self, tmp_dir): + """Nonexistent YAML -> all defaults, no error.""" + cfg = DashboardConfig.load( + tmp_dir, yaml_path=os.path.join(tmp_dir, "nonexistent.yaml") + ) + assert cfg.server.port == 8421 + assert cfg.chart.line_tension == 0.15 + assert cfg.run_dir == tmp_dir + + +class TestConfigFromEnv: + """test_config_from_env — HOTCB_PORT=9000 etc. 
override defaults.""" + + def test_config_from_env(self, tmp_dir, monkeypatch): + monkeypatch.setenv("HOTCB_PORT", "9000") + monkeypatch.setenv("HOTCB_POLL_INTERVAL", "1.0") + cfg = DashboardConfig.load(tmp_dir) + assert cfg.server.port == 9000 + assert cfg.server.poll_interval == 1.0 + + def test_config_env_overrides_yaml(self, tmp_dir, monkeypatch): + """Env beats YAML.""" + yaml_path = os.path.join(tmp_dir, "dashboard.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n port: 9000\n") + + monkeypatch.setenv("HOTCB_PORT", "8000") + cfg = DashboardConfig.load(tmp_dir, yaml_path=yaml_path) + assert cfg.server.port == 8000 + + +class TestConfigCliOverrides: + """test_config_cli_overrides_all — CLI overrides beat both.""" + + def test_config_cli_overrides_all(self, tmp_dir, monkeypatch): + yaml_path = os.path.join(tmp_dir, "dashboard.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n port: 9000\n") + + monkeypatch.setenv("HOTCB_PORT", "8000") + cfg = DashboardConfig.load(tmp_dir, yaml_path=yaml_path, port=7000) + assert cfg.server.port == 7000 + + +class TestConfigToDict: + """test_config_to_dict_roundtrip — to_dict() is JSON serializable, has all sections.""" + + def test_config_to_dict_roundtrip(self): + cfg = DashboardConfig() + d = cfg.to_dict() + # Must be JSON-serializable + serialized = json.dumps(d) + roundtripped = json.loads(serialized) + # All sections present + assert "server" in roundtripped + assert "chart" in roundtripped + assert "autopilot" in roundtripped + assert "ui" in roundtripped + assert "run_dir" in roundtripped + assert "controls" in roundtripped + + def test_config_run_dir_in_dict(self): + cfg = DashboardConfig(run_dir="/tmp/x") + d = cfg.to_dict() + assert d["run_dir"] == "/tmp/x" + + def test_chart_tuples_serialize_as_arrays(self): + """forecast_dash and mutation_dash are tuples in Python, arrays in JSON.""" + cfg = DashboardConfig() + d = cfg.to_dict() + # In the dict they should be lists (JSON arrays) + assert isinstance(d["chart"]["forecast_dash"], list) + assert isinstance(d["chart"]["mutation_dash"], list) + assert d["chart"]["forecast_dash"] == [6, 3] + assert d["chart"]["mutation_dash"] == [3, 4] + + +class TestConfigEndpoint: + """test_config_endpoint_returns_full — GET /api/config returns all sections.""" + + @pytest.fixture + def tmp_run_dir(self): + with tempfile.TemporaryDirectory() as d: + # Create metrics file so _resolve_active_run_dir returns this dir + with open(os.path.join(d, "hotcb.metrics.jsonl"), "w") as f: + pass + yield d + + @pytest.fixture + def client(self, tmp_run_dir): + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + app = create_app(tmp_run_dir, poll_interval=60) + return TestClient(app) + + @pytest.fixture + def client_with_yaml(self, tmp_run_dir): + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + yaml_path = os.path.join(tmp_run_dir, "hotcb.dashboard.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n port: 9999\n") + + app = create_app(tmp_run_dir, poll_interval=60, config_yaml=yaml_path) + return TestClient(app) + + def test_config_endpoint_returns_full(self, client): + r = client.get("/api/config") + assert r.status_code == 200 + data = r.json() + assert "server" in data + assert "chart" in data + assert "autopilot" in data + assert "ui" in data + assert "run_dir" in data + assert "controls" in data + + def test_config_endpoint_reflects_overrides(self, client_with_yaml): + r = client_with_yaml.get("/api/config") + 
assert r.status_code == 200 + data = r.json() + assert data["server"]["port"] == 9999 + + +# ==================================================================== +# Phase 4: Dynamic Controls from Actuators +# ==================================================================== + + +class TestControlsFromMutableState: + """Phase 4 — /api/config controls populated from MutableState.""" + + @pytest.fixture + def tmp_run_dir(self): + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "hotcb.metrics.jsonl"), "w") as f: + pass + yield d + + @pytest.fixture + def _make_mutable_state(self): + """Helper to create a MutableState with common actuators.""" + from hotcb.actuators import ( + ActuatorType, + ApplyResult, + HotcbActuator, + MutableState, + ) + + def _factory(specs=None): + if specs is None: + specs = [ + dict( + param_key="lr", + type=ActuatorType.LOG_FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + label="Learning Rate", + group="optimizer", + min_value=1e-7, + max_value=1.0, + current_value=1e-3, + ), + dict( + param_key="weight_decay", + type=ActuatorType.LOG_FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + label="Weight Decay", + group="optimizer", + min_value=0.0, + max_value=1.0, + current_value=1e-4, + ), + dict( + param_key="recon_w", + type=ActuatorType.FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + label="Reconstruction Weight", + group="loss", + min_value=0.0, + max_value=100.0, + current_value=1.0, + ), + ] + actuators = [HotcbActuator(**s) for s in specs] + return MutableState(actuators) + + return _factory + + @pytest.fixture + def client_with_ms(self, tmp_run_dir, _make_mutable_state): + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + ms = _make_mutable_state() + app = create_app(tmp_run_dir, poll_interval=60) + app.state.mutable_state = ms + return TestClient(app) + + @pytest.fixture + def client_no_ms(self, tmp_run_dir): + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + app = create_app(tmp_run_dir, poll_interval=60) + return TestClient(app) + + def test_config_controls_from_mutable_state(self, client_with_ms): + """Controls populated from MutableState.describe_all().""" + r = client_with_ms.get("/api/config") + assert r.status_code == 200 + data = r.json() + controls = data["controls"] + assert len(controls) == 3 + keys = {c["param_key"] for c in controls} + assert keys == {"lr", "weight_decay", "recon_w"} + # Each entry has required fields + for c in controls: + assert "param_key" in c + assert "type" in c + assert "label" in c + assert "group" in c + assert "current" in c + + def test_config_controls_defaults_when_no_mutable_state(self, client_no_ms): + """No MutableState -> default optimizer controls are returned.""" + r = client_no_ms.get("/api/config") + assert r.status_code == 200 + data = r.json() + # Should have default lr and weight_decay controls + assert len(data["controls"]) >= 2 + keys = [c["param_key"] for c in data["controls"]] + assert "lr" in keys + assert "weight_decay" in keys + + def test_config_controls_types_match_actuators( + self, tmp_run_dir, _make_mutable_state + ): + """Control types match actuator types.""" + from starlette.testclient import TestClient + from hotcb.server.app import create_app + from hotcb.actuators import ActuatorType, ApplyResult, HotcbActuator, MutableState + + ms = _make_mutable_state( + [ + dict( + param_key="lr", + type=ActuatorType.LOG_FLOAT, + apply_fn=lambda v, e: 
ApplyResult(success=True), + group="optimizer", + min_value=1e-7, + max_value=1.0, + current_value=1e-3, + ), + dict( + param_key="recon_w", + type=ActuatorType.FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + group="loss", + min_value=0.0, + max_value=10.0, + current_value=1.0, + ), + dict( + param_key="use_augment", + type=ActuatorType.BOOL, + apply_fn=lambda v, e: ApplyResult(success=True), + group="custom", + current_value=True, + ), + ] + ) + + app = create_app(tmp_run_dir, poll_interval=60) + app.state.mutable_state = ms + client = TestClient(app) + + r = client.get("/api/config") + controls = r.json()["controls"] + type_map = {c["param_key"]: c["type"] for c in controls} + assert type_map["lr"] == "log_float" + assert type_map["recon_w"] == "float" + assert type_map["use_augment"] == "bool" + + def test_config_controls_groups_present(self, client_with_ms): + """Controls have correct group field.""" + r = client_with_ms.get("/api/config") + controls = r.json()["controls"] + group_map = {c["param_key"]: c["group"] for c in controls} + assert group_map["lr"] == "optimizer" + assert group_map["weight_decay"] == "optimizer" + assert group_map["recon_w"] == "loss" + + def test_control_state_endpoint_uses_mutable_state(self, client_with_ms): + """GET /api/state/controls returns live values from MutableState.""" + r = client_with_ms.get("/api/state/controls") + assert r.status_code == 200 + data = r.json() + assert "controls" in data + controls = data["controls"] + assert len(controls) == 3 + # Verify current values are present + val_map = {c["param_key"]: c["current"] for c in controls} + assert val_map["lr"] == pytest.approx(1e-3) + assert val_map["recon_w"] == pytest.approx(1.0) + + def test_control_state_endpoint_defaults_when_no_ms(self, client_no_ms): + """GET /api/state/controls returns default controls when no MutableState.""" + r = client_no_ms.get("/api/state/controls") + assert r.status_code == 200 + data = r.json() + # Should have default lr and weight_decay controls + assert len(data["controls"]) >= 2 + keys = [c["param_key"] for c in data["controls"]] + assert "lr" in keys + + +class TestControlsFromMutableStateFunction: + """Unit tests for controls_from_mutable_state().""" + + def test_none_returns_empty(self): + from hotcb.server.config import controls_from_mutable_state + assert controls_from_mutable_state(None) == [] + + def test_with_mutable_state(self): + from hotcb.server.config import controls_from_mutable_state + from hotcb.actuators import ( + ActuatorType, + ApplyResult, + HotcbActuator, + MutableState, + ) + + ms = MutableState([ + HotcbActuator( + param_key="lr", + type=ActuatorType.LOG_FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + group="optimizer", + min_value=1e-7, + max_value=1.0, + current_value=0.001, + ), + ]) + result = controls_from_mutable_state(ms) + assert len(result) == 1 + assert result[0]["param_key"] == "lr" + assert result[0]["type"] == "log_float" + assert result[0]["current"] == pytest.approx(0.001) + + +# ==================================================================== +# Phase 6: Magic Number Extraction +# ==================================================================== + + +class TestTailerUsesConfigPollInterval: + """Phase 6 — tailer created with config.server.poll_interval.""" + + @pytest.fixture + def tmp_run_dir(self): + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "hotcb.metrics.jsonl"), "w") as f: + pass + yield d + + def test_tailer_uses_config_poll_interval(self, tmp_run_dir): + 
"""Tailer created with config's poll_interval, not hardcoded default.""" + from hotcb.server.app import create_app + + app = create_app(tmp_run_dir, poll_interval=2.5) + tailer = app.state.tailer + assert tailer._poll_interval == 2.5 + + def test_tailer_uses_default_poll_interval(self, tmp_run_dir): + """When no poll_interval override, tailer uses default 0.5.""" + from hotcb.server.app import create_app + + app = create_app(tmp_run_dir) + tailer = app.state.tailer + assert tailer._poll_interval == 0.5 + + +class TestHistoryLimitsFromConfig: + """Phase 6 — metrics/applied history endpoints respect config limits.""" + + @pytest.fixture + def tmp_run_dir(self): + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "hotcb.metrics.jsonl"), "w") as f: + pass + yield d + + @pytest.fixture + def populated_run_dir(self): + """Run dir with 50 metrics records and 50 applied records.""" + with tempfile.TemporaryDirectory() as d: + metrics_path = os.path.join(d, "hotcb.metrics.jsonl") + with open(metrics_path, "w") as f: + for i in range(50): + f.write(json.dumps({"step": i, "metrics": {"loss": 1.0 - i * 0.01}}) + "\n") + applied_path = os.path.join(d, "hotcb.applied.jsonl") + with open(applied_path, "w") as f: + for i in range(50): + f.write(json.dumps({"step": i, "module": "opt", "decision": "applied", "params": {"lr": 0.001}}) + "\n") + yield d + + def test_metrics_history_uses_config_limit(self, populated_run_dir): + """GET /api/metrics/history with no last_n uses config limit.""" + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + yaml_path = os.path.join(populated_run_dir, "cfg.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n history_limit_metrics: 10\n") + + app = create_app(populated_run_dir, poll_interval=60, config_yaml=yaml_path) + client = TestClient(app) + r = client.get("/api/metrics/history") + assert r.status_code == 200 + records = r.json()["records"] + assert len(records) == 10 + + def test_metrics_history_explicit_last_n_overrides_config(self, populated_run_dir): + """GET /api/metrics/history?last_n=5 overrides config limit.""" + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + yaml_path = os.path.join(populated_run_dir, "cfg.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n history_limit_metrics: 10\n") + + app = create_app(populated_run_dir, poll_interval=60, config_yaml=yaml_path) + client = TestClient(app) + r = client.get("/api/metrics/history?last_n=5") + assert r.status_code == 200 + records = r.json()["records"] + assert len(records) == 5 + + def test_applied_history_uses_config_limit(self, populated_run_dir): + """GET /api/applied/history with no last_n uses config limit.""" + from starlette.testclient import TestClient + from hotcb.server.app import create_app + + yaml_path = os.path.join(populated_run_dir, "cfg.yaml") + with open(yaml_path, "w") as f: + f.write("server:\n history_limit_applied: 7\n") + + app = create_app(populated_run_dir, poll_interval=60, config_yaml=yaml_path) + client = TestClient(app) + r = client.get("/api/applied/history") + assert r.status_code == 200 + records = r.json()["records"] + assert len(records) == 7 + + +class TestAutopilotThresholdsFromConfig: + """Phase 6 — autopilot engine uses config thresholds as defaults.""" + + def test_divergence_uses_config_threshold(self): + """Divergence rule uses config threshold when rule doesn't specify one.""" + from hotcb.server.autopilot import AutopilotEngine, AutopilotRule + + 
config = AutopilotConfig(divergence_threshold=5.0) + engine = AutopilotEngine(run_dir="/tmp/test", mode="suggest", config=config) + engine.add_rule(AutopilotRule( + rule_id="div1", + condition="divergence", + metric_name="val_loss", + params={"window": 3}, # no threshold specified — should use config's 5.0 + action={"module": "opt", "op": "set_params", "params": {"lr": 0.0001}}, + confidence="high", + )) + + # Feed metric history that diverges by 4.0 (below config threshold of 5.0) + for i in range(3): + engine.evaluate(i, {"val_loss": 1.0}) + actions = engine.evaluate(3, {"val_loss": 5.0}) + # 5.0 - 1.0 = 4.0, which is < 5.0 threshold, so no divergence + assert len(actions) == 0 + + def test_divergence_rule_threshold_overrides_config(self): + """Rule-specified threshold takes precedence over config.""" + from hotcb.server.autopilot import AutopilotEngine, AutopilotRule + + config = AutopilotConfig(divergence_threshold=100.0) # very high config threshold + engine = AutopilotEngine(run_dir="/tmp/test", mode="suggest", config=config) + engine.add_rule(AutopilotRule( + rule_id="div2", + condition="divergence", + metric_name="val_loss", + params={"window": 3, "threshold": 1.0}, # rule specifies threshold=1.0 + action={"module": "opt", "op": "set_params", "params": {"lr": 0.0001}}, + confidence="high", + )) + + for i in range(3): + engine.evaluate(i, {"val_loss": 1.0}) + actions = engine.evaluate(3, {"val_loss": 5.0}) + # 5.0 - 1.0 = 4.0, which is > rule threshold 1.0, so divergence detected + assert len(actions) == 1 + assert actions[0].condition_met.startswith("Metric diverged") + + def test_overfitting_uses_config_ratio_threshold(self): + """Overfitting rule uses config ratio_threshold when rule doesn't specify one.""" + from hotcb.server.autopilot import AutopilotEngine, AutopilotRule + + config = AutopilotConfig(ratio_threshold=0.01) # very low threshold + engine = AutopilotEngine(run_dir="/tmp/test", mode="suggest", config=config) + engine.add_rule(AutopilotRule( + rule_id="over1", + condition="overfitting", + metric_name="", + params={"train_metric": "train_loss", "val_metric": "val_loss"}, + action={"module": "opt", "op": "set_params", "params": {"lr": 0.0001}}, + confidence="high", + )) + + # train/val ratio = 0.1/1.0 = 0.1, which is > 0.01 threshold (no overfitting) + actions = engine.evaluate(0, {"train_loss": 0.1, "val_loss": 1.0}) + assert len(actions) == 0 + + def test_autopilot_no_config_uses_hardcoded_defaults(self): + """Without config, autopilot uses hardcoded defaults (2.0, 0.5).""" + from hotcb.server.autopilot import AutopilotEngine, AutopilotRule + + engine = AutopilotEngine(run_dir="/tmp/test", mode="suggest") # no config + engine.add_rule(AutopilotRule( + rule_id="div_default", + condition="divergence", + metric_name="val_loss", + params={"window": 3}, + action={"module": "opt", "op": "set_params", "params": {"lr": 0.0001}}, + confidence="high", + )) + + # Feed metric history that diverges by 3.0 (above default 2.0) + for i in range(3): + engine.evaluate(i, {"val_loss": 1.0}) + actions = engine.evaluate(3, {"val_loss": 4.0}) + # 4.0 - 1.0 = 3.0, > default threshold 2.0 + assert len(actions) == 1 + + +class TestAICadenceFromConfig: + """Phase 6 — AI engine uses config cadence params.""" + + def test_ai_engine_uses_config_min_interval(self): + """AI engine min_interval from AutopilotConfig.""" + from hotcb.server.ai_engine import LLMAutopilotEngine, AIConfig + + ap_config = AutopilotConfig(ai_min_interval=25) + engine = LLMAutopilotEngine( + run_dir="/tmp/test", + 
config=AIConfig(api_key="test-key"),
+ autopilot_config=ap_config,
+ )
+ assert engine._min_interval == 25
+
+ def test_ai_engine_uses_config_max_wait(self):
+ """AI engine max_wait from AutopilotConfig."""
+ from hotcb.server.ai_engine import LLMAutopilotEngine, AIConfig
+
+ ap_config = AutopilotConfig(ai_max_wait=500)
+ engine = LLMAutopilotEngine(
+ run_dir="/tmp/test",
+ config=AIConfig(api_key="test-key"),
+ autopilot_config=ap_config,
+ )
+ assert engine._max_wait == 500
+
+ def test_ai_engine_uses_config_default_cadence(self):
+ """AI engine cadence from AutopilotConfig when AIConfig has default."""
+ from hotcb.server.ai_engine import LLMAutopilotEngine, AIConfig
+
+ ap_config = AutopilotConfig(ai_default_cadence=100)
+ engine = LLMAutopilotEngine(
+ run_dir="/tmp/test",
+ config=AIConfig(api_key="test-key"),
+ autopilot_config=ap_config,
+ )
+ assert engine.config.cadence == 100
+
+ def test_ai_engine_min_interval_governs_should_invoke(self):
+ """should_invoke respects config min_interval."""
+ from hotcb.server.ai_engine import LLMAutopilotEngine, AIConfig
+
+ ap_config = AutopilotConfig(ai_min_interval=30)
+ engine = LLMAutopilotEngine(
+ run_dir="/tmp/test",
+ config=AIConfig(api_key="test-key"),
+ autopilot_config=ap_config,
+ )
+ engine._last_invoked_step = 0
+
+ # Step 15 is within min_interval=30, should not invoke
+ assert engine.should_invoke(15, []) is False
+ # Step 31 is past min_interval=30, periodic cadence check applies
+ assert engine.should_invoke(31, [{"alert": "test"}]) is True
+
+ def test_ai_engine_defaults_without_autopilot_config(self):
+ """AI engine uses hardcoded defaults when no autopilot_config."""
+ from hotcb.server.ai_engine import LLMAutopilotEngine, AIConfig
+
+ engine = LLMAutopilotEngine(
+ run_dir="/tmp/test",
+ config=AIConfig(),
+ )
+ assert engine._min_interval == 10
+ assert engine._max_wait == 200
+ assert engine.config.cadence == 50
diff --git a/src/hotcb/tests/test_freeze_modes.py b/src/hotcb/tests/test_freeze_modes.py
index 290804e..de832fa 100644
--- a/src/hotcb/tests/test_freeze_modes.py
+++ b/src/hotcb/tests/test_freeze_modes.py
@@ -7,6 +7,7 @@
 import pytest
+from hotcb.actuators import optimizer_actuators, mutable_state
 from hotcb.kernel import HotKernel
@@ -120,9 +121,10 @@ def test_replay_applies_recipe_ignores_external(
 )
 write_freeze(mode="replay", recipe_path=recipe_path)
- kernel = HotKernel(run_dir=run_dir, debounce_steps=1)
 opt = _make_optimizer(lr=1e-4)
+ ms = mutable_state(optimizer_actuators(opt))
+ kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms)
 # Append an external conflicting opt command
 write_commands({"module": "opt", "op": "set_params", "id": "main", "params": {"lr": 9e-4}})
@@ -180,8 +181,9 @@ def test_adjusted_lr(self, run_dir, make_env, write_recipe, write_freeze, read_l
 write_freeze(mode="replay_adjusted", recipe_path=recipe_path, adjust_path=adjust_path)
- kernel = HotKernel(run_dir=run_dir, debounce_steps=1)
 opt = _make_optimizer(lr=1e-4)
+ ms = mutable_state(optimizer_actuators(opt))
+ kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms)
 for step in range(1, 6):
 env = make_env(step=step, optimizer=opt)
@@ -211,8 +213,9 @@ def test_unfreeze_restores_normal_operation(
 """
 write_freeze(mode="prod")
- kernel = HotKernel(run_dir=run_dir, debounce_steps=1)
 opt = _make_optimizer(lr=1e-4)
+ ms = mutable_state(optimizer_actuators(opt))
+ kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms)
 # Step 1: external opt command -- should be ignored
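The freeze-mode hunks above repeat one wiring change that recurs through the rest of this diff: instead of constructing a bare HotKernel and letting per-module controllers resolve the optimizer out of the env, the optimizer is wrapped in actuators and a single MutableState is handed to the kernel up front. A condensed sketch of that setup; the hotcb imports and keyword arguments are taken from the hunks above, while `build_kernel` itself is just an illustrative helper:

```python
# Condensed from the test_freeze_modes.py hunks above; build_kernel is an
# illustrative helper, not part of hotcb.
from hotcb.actuators import optimizer_actuators, mutable_state
from hotcb.kernel import HotKernel

def build_kernel(run_dir: str, opt) -> HotKernel:
    """Wrap an optimizer in actuators and hand the kernel one MutableState."""
    ms = mutable_state(optimizer_actuators(opt))  # actuators bind to param groups
    return HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms)
```

With this shape, ops on module "opt" route through the kernel's default stream even when the env carries no optimizer, which is what the `no_mutable_state` failure cases later in the diff rely on.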
write_commands({"module": "opt", "op": "set_params", "id": "main", "params": {"lr": 5e-4}}) diff --git a/src/hotcb/tests/test_hotloss.py b/src/hotcb/tests/test_hotloss.py index a161410..a395114 100644 --- a/src/hotcb/tests/test_hotloss.py +++ b/src/hotcb/tests/test_hotloss.py @@ -1,154 +1,351 @@ -"""Unit tests for HotLossController (spec §19.7).""" +"""Tests for loss weight mutations via kernel default stream + MutableState.""" from __future__ import annotations +import os + import pytest -from hotcb.modules.loss import HotLossController +from hotcb.actuators import loss_actuators, mutable_state, ApplyResult +from hotcb.actuators.actuator import ActuatorState, ActuatorType, HotcbActuator +from hotcb.actuators.state import MutableState +from hotcb.kernel import HotKernel from hotcb.ops import HotOp -def _make_loss_state(): - return {"weights": {}, "terms": {}, "ramps": {}} +def _kernel_with_loss(run_dir, loss_weights, extra_actuators=None, **bounds): + """Build a kernel with MutableState containing loss actuators.""" + acts = loss_actuators(loss_weights, **bounds) + if extra_actuators: + acts += extra_actuators + ms = mutable_state(acts) + return HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms), loss_weights def _op(op="set_params", params=None, id="main"): return HotOp(module="loss", op=op, id=id, params=params) +# ------------------------------------------------------------------ # +# 1. Weight mutation via direct key +# ------------------------------------------------------------------ # class TestWeightsMutation: - def test_distill_w_suffix(self): - ctrl = HotLossController() - ls = _make_loss_state() - result = ctrl.apply_op(_op(params={"distill_w": 0.2, "depth_w": 1.5}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["weights"]["distill"] == 0.2 - assert ls["weights"]["depth"] == 1.5 - - def test_fallback_to_weights_bucket(self): - ctrl = HotLossController() - ls = _make_loss_state() - result = ctrl.apply_op(_op(params={"custom_metric": 0.5}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["weights"]["custom_metric"] == 0.5 - - -class TestTermsToggle: - def test_terms_dot_prefix(self): - ctrl = HotLossController() - ls = _make_loss_state() - result = ctrl.apply_op(_op(params={"terms.aux_depth": False, "terms.aux_heatmap": True}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["terms"]["aux_depth"] is False - assert ls["terms"]["aux_heatmap"] is True - - def test_terms_as_dict(self): - ctrl = HotLossController() - ls = _make_loss_state() - result = ctrl.apply_op(_op(params={"terms": {"aux_depth": False, "aux_heatmap": True}}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["terms"]["aux_depth"] is False - assert ls["terms"]["aux_heatmap"] is True - - -class TestRamps: - def test_ramps_dot_prefix(self): - ctrl = HotLossController() - ls = _make_loss_state() - ramp_cfg = {"type": "linear", "warmup_frac": 0.2, "end": 2.0} - result = ctrl.apply_op(_op(params={"ramps.depth": ramp_cfg}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["ramps"]["depth"] == ramp_cfg - - def test_ramps_as_dict(self): - ctrl = HotLossController() - ls = _make_loss_state() - ramp_cfg = {"type": "linear", "end": 2.0} - result = ctrl.apply_op(_op(params={"ramps": {"depth": ramp_cfg}}), {"loss_state": ls}) - assert result.decision == "applied" - assert ls["ramps"]["depth"] == ramp_cfg - - -class TestMissingLossState: - def test_no_loss_state_in_env(self): - ctrl = HotLossController() - 
result = ctrl.apply_op(_op(params={"distill_w": 0.2}), {}) - assert result.decision == "failed" - assert result.error == "missing_loss_state" - - def test_resolve_loss_state_callable(self): - ctrl = HotLossController() - ls = _make_loss_state() - env = {"resolve_loss_state": lambda: ls} - result = ctrl.apply_op(_op(params={"distill_w": 0.3}), env) - assert result.decision == "applied" - assert ls["weights"]["distill"] == 0.3 - - + def test_set_weight_by_key(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0, "depth": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 0.2}), env, "train_step", 1, + ) + + assert weights["distill"] == pytest.approx(0.2) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" + + def test_set_weight_direct_format(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0, "depth": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"distill": 0.2, "depth": 1.5}), env, "train_step", 1, + ) + + assert weights["distill"] == pytest.approx(0.2) + assert weights["depth"] == pytest.approx(1.5) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" + + +# ------------------------------------------------------------------ # +# 2. Bounds enforcement +# ------------------------------------------------------------------ # +class TestBounds: + def test_out_of_bounds_rejected(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights, global_bounds=(0.0, 10.0)) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 50.0}), env, "train_step", 1, + ) + + assert weights["distill"] == pytest.approx(1.0) # unchanged + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "above max" in ledger[0]["error"] + + +# ------------------------------------------------------------------ # +# 3. Missing MutableState +# ------------------------------------------------------------------ # +class TestMissingMutableState: + def test_no_mutable_state_fails(self, run_dir, make_env, read_ledger): + kernel = HotKernel(run_dir=run_dir, debounce_steps=1) # no mutable_state + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 0.2}), env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "no_mutable_state" in ledger[0]["error"] + + +# ------------------------------------------------------------------ # +# 4. Unknown param key +# ------------------------------------------------------------------ # +class TestUnknownKey: + def test_unknown_key_fails(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, _ = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "nonexistent", "value": 0.5}), env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "unknown_param" in ledger[0]["error"] + + +# ------------------------------------------------------------------ # +# 5. 
Enable / disable +# ------------------------------------------------------------------ # class TestEnableDisable: - def test_disabled_handle_skips(self): - ctrl = HotLossController() - ls = _make_loss_state() - env = {"loss_state": ls} - ctrl.apply_op(HotOp(module="loss", op="disable", id="main"), env) - result = ctrl.apply_op(_op(params={"distill_w": 0.2}), env) - assert result.decision == "skipped_noop" - assert result.notes == "handle_disabled" - assert ls["weights"] == {} - - def test_re_enable_then_apply(self): - ctrl = HotLossController() - ls = _make_loss_state() - env = {"loss_state": ls} - ctrl.apply_op(HotOp(module="loss", op="disable", id="main"), env) - ctrl.apply_op(HotOp(module="loss", op="enable", id="main"), env) - result = ctrl.apply_op(_op(params={"distill_w": 0.2}), env) - assert result.decision == "applied" - assert ls["weights"]["distill"] == 0.2 - - -class TestAutoDisableOnError: - def test_error_disables_handle(self): - ctrl = HotLossController(auto_disable_on_error=True) - - class BrokenDict(dict): - def setdefault(self, key, default=None): - raise RuntimeError("broken") - - env = {"loss_state": BrokenDict()} - result = ctrl.apply_op(_op(params={"distill_w": 0.2}), env) - assert result.decision == "failed" - assert "broken" in result.error - # handle should be disabled - assert ctrl.handles["main"].enabled is False - - def test_no_auto_disable_when_off(self): - ctrl = HotLossController(auto_disable_on_error=False) - - class BrokenDict(dict): - def setdefault(self, key, default=None): - raise RuntimeError("broken") - - env = {"loss_state": BrokenDict()} - result = ctrl.apply_op(_op(params={"distill_w": 0.2}), env) - assert result.decision == "failed" - assert ctrl.handles["main"].enabled is True - - -class TestStatus: - def test_status_structure(self): - ctrl = HotLossController() - ls = _make_loss_state() - ctrl.apply_op(_op(params={"distill_w": 0.2}), {"loss_state": ls}) - status = ctrl.status() - assert "main" in status - assert status["main"]["enabled"] is True - assert "distill_w" in status["main"]["last_params"] - assert status["main"]["last_error"] is None - - + def test_disabled_rejects_set(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + HotOp(module="loss", op="disable", params={"key": "distill"}), + env, "train_step", 1, + ) + kernel._apply_single( + _op(params={"key": "distill", "value": 0.2}), env, "train_step", 2, + ) + + assert weights["distill"] == pytest.approx(1.0) # unchanged + ledger = read_ledger() + set_entries = [e for e in ledger if e["op"] == "set_params"] + assert set_entries[0]["decision"] == "failed" + assert "disabled" in set_entries[0]["error"] + + def test_re_enable_then_apply(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + HotOp(module="loss", op="disable", params={"key": "distill"}), + env, "train_step", 1, + ) + kernel._apply_single( + HotOp(module="loss", op="enable", params={"key": "distill"}), + env, "train_step", 2, + ) + kernel._apply_single( + _op(params={"key": "distill", "value": 0.3}), env, "train_step", 3, + ) + + assert weights["distill"] == pytest.approx(0.3) + + +# ------------------------------------------------------------------ # +# 6. 
Ledger format +# ------------------------------------------------------------------ # +class TestLedgerFormat: + def test_module_preserved_in_ledger(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, _ = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 0.5}), env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["module"] == "loss" + assert ledger[0]["decision"] == "applied" + + +# ------------------------------------------------------------------ # +# 7. Multiple weights in single op +# ------------------------------------------------------------------ # +class TestMultipleWeights: + def test_set_multiple_weights(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0, "depth": 1.0, "kl": 0.5} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"distill": 0.2, "depth": 1.5, "kl": 0.8}), + env, "train_step", 1, + ) + + assert weights["distill"] == pytest.approx(0.2) + assert weights["depth"] == pytest.approx(1.5) + assert weights["kl"] == pytest.approx(0.8) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" + + +# ------------------------------------------------------------------ # +# 8. Apply error +# ------------------------------------------------------------------ # +class TestApplyError: + def test_apply_fn_failure_recorded(self, run_dir, make_env, read_ledger): + def _bad_apply(value, env): + raise RuntimeError("broken") + + bad_act = HotcbActuator( + param_key="bad_weight", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + group="loss", + min_value=0.0, + max_value=100.0, + current_value=1.0, + ) + ms = mutable_state([bad_act]) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) + + kernel._apply_single( + HotOp(module="loss", op="set_params", params={"key": "bad_weight", "value": 2.0}), + env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "broken" in ledger[0]["error"] + + +# ------------------------------------------------------------------ # +# 9. Unknown op +# ------------------------------------------------------------------ # class TestUnknownOp: - def test_unknown_op_ignored(self): - ctrl = HotLossController() - result = ctrl.apply_op(HotOp(module="loss", op="reset", id="main"), {"loss_state": {}}) - assert result.decision == "ignored" - assert "unknown_op" in (result.notes or "") + def test_unknown_op_ignored(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, _ = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + HotOp(module="loss", op="reset", id="main"), env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "ignored" + assert "unknown_op" in (ledger[0].get("notes") or "") + + +# ------------------------------------------------------------------ # +# 10. 
Custom module routes through default stream +# ------------------------------------------------------------------ # +class TestCustomModule: + def test_custom_module_routes_to_mutable_state(self, run_dir, make_env, read_ledger): + """Any unknown module name routes through MutableState if it has the key.""" + custom_act = HotcbActuator( + param_key="dropout", + type=ActuatorType.FLOAT, + apply_fn=lambda v, e: ApplyResult(success=True), + group="custom", + min_value=0.0, + max_value=1.0, + current_value=0.5, + ) + ms = mutable_state([custom_act]) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) + + kernel._apply_single( + HotOp(module="custom", op="set_params", params={"key": "dropout", "value": 0.3}), + env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" + assert ledger[0]["module"] == "custom" + + +# ------------------------------------------------------------------ # +# 11. No params produces error +# ------------------------------------------------------------------ # +class TestNoParams: + def test_empty_params_fails(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, _ = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={}), env, "train_step", 1, + ) + + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "no_params" in ledger[0]["error"] + + +# ------------------------------------------------------------------ # +# 12. Partial success with errors +# ------------------------------------------------------------------ # +class TestPartialSuccess: + def test_mixed_known_and_unknown_keys(self, run_dir, make_env, read_ledger): + weights = {"distill": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"distill": 0.5, "nonexistent": 2.0}), + env, "train_step", 1, + ) + + # Should succeed for known key, note error for unknown + assert weights["distill"] == pytest.approx(0.5) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" + assert "unknown_param" in (ledger[0].get("notes") or "") + + +# ------------------------------------------------------------------ # +# 13. Freeze enforcement for default stream +# ------------------------------------------------------------------ # +class TestFreezeEnforcement: + def test_freeze_blocks_default_stream(self, run_dir, make_env, write_freeze, read_ledger): + weights = {"distill": 1.0} + write_freeze(mode="prod") + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 0.2}), + env, "train_step", 1, + ) + + assert weights["distill"] == pytest.approx(1.0) # unchanged + ledger = read_ledger() + assert ledger[0]["decision"] == "ignored_freeze" + + +# ------------------------------------------------------------------ # +# 14. 
Mutation tracking +# ------------------------------------------------------------------ # +class TestMutationTracking: + def test_mutation_recorded_in_actuator(self, run_dir, make_env): + weights = {"distill": 1.0} + kernel, weights = _kernel_with_loss(run_dir, weights) + env = make_env(step=1) + + kernel._apply_single( + _op(params={"key": "distill", "value": 0.5}), env, "train_step", 10, + ) + + act = kernel._mutable_state.get("distill") + assert act is not None + assert len(act.mutations) == 1 + assert act.mutations[0].step == 10 + assert act.mutations[0].new_value == pytest.approx(0.5) + assert act.state == ActuatorState.UNVERIFIED diff --git a/src/hotcb/tests/test_hotopt.py b/src/hotcb/tests/test_hotopt.py index e229b34..a16eb2d 100644 --- a/src/hotcb/tests/test_hotopt.py +++ b/src/hotcb/tests/test_hotopt.py @@ -1,10 +1,15 @@ -"""Tests for HotOptController — spec section 19.6.""" +"""Tests for optimizer param mutations via kernel default stream + MutableState.""" from __future__ import annotations +import os + import pytest -from hotcb.modules.opt import HotOptController, OptHandle +from hotcb.actuators import optimizer_actuators, mutable_state, ApplyResult +from hotcb.actuators.actuator import ActuatorState, ActuatorType, HotcbActuator +from hotcb.actuators.state import MutableState +from hotcb.kernel import HotKernel from hotcb.ops import HotOp @@ -17,6 +22,15 @@ def _make_param_groups(n=2, lr=1e-3, weight_decay=0.01): return [{"lr": lr, "weight_decay": weight_decay, "params": []} for _ in range(n)] +def _kernel_with_opt(run_dir, opt, extra_actuators=None): + """Build a kernel with MutableState containing optimizer actuators.""" + acts = optimizer_actuators(opt) + if extra_actuators: + acts += extra_actuators + ms = mutable_state(acts) + return HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + + def _op(params=None, op="set_params", id="main"): return HotOp(module="opt", op=op, id=id, params=params) @@ -25,214 +39,265 @@ def _op(params=None, op="set_params", id="main"): # 1. Global lr update # ------------------------------------------------------------------ # class TestGlobalLrUpdate: - def test_both_groups_updated(self): - ctrl = HotOptController(auto_disable_on_error=True) + def test_both_groups_updated(self, run_dir, make_env): groups = _make_param_groups(2) - env = {"optimizer": MockOptimizer(groups)} + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) + kernel._apply_single(_op(params={"lr": 1e-4}), env, "train_step", 1) - assert result.decision == "applied" assert groups[0]["lr"] == pytest.approx(1e-4) assert groups[1]["lr"] == pytest.approx(1e-4) # ------------------------------------------------------------------ # -# 2. Group-specific lr -# ------------------------------------------------------------------ # -class TestGroupSpecificLr: - def test_only_target_group_changed(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) - env = {"optimizer": MockOptimizer(groups)} - - result = ctrl.apply_op(_op(params={"group": 1, "lr": 5e-5}), env) - - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-3) # unchanged - assert groups[1]["lr"] == pytest.approx(5e-5) - - -# ------------------------------------------------------------------ # -# 3. Weight decay update +# 2. 
Weight decay update # ------------------------------------------------------------------ # class TestWeightDecayUpdate: - def test_all_groups_updated(self): - ctrl = HotOptController(auto_disable_on_error=True) + def test_all_groups_updated(self, run_dir, make_env): groups = _make_param_groups(2, weight_decay=0.01) - env = {"optimizer": MockOptimizer(groups)} + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op(_op(params={"weight_decay": 0.05}), env) + kernel._apply_single(_op(params={"weight_decay": 0.05}), env, "train_step", 1) - assert result.decision == "applied" assert groups[0]["weight_decay"] == pytest.approx(0.05) assert groups[1]["weight_decay"] == pytest.approx(0.05) # ------------------------------------------------------------------ # -# 4. scheduler_scale +# 3. Scheduler coordination # ------------------------------------------------------------------ # -class TestSchedulerScale: - def test_lr_halved(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) - env = {"optimizer": MockOptimizer(groups)} +class TestSchedulerCoordination: + def test_scheduler_base_lrs_updated(self, run_dir, make_env): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + + class FakeScheduler: + def __init__(self): + self.base_lrs = [1e-3] - result = ctrl.apply_op(_op(params={"scheduler_scale": 0.5}), env) + sched = FakeScheduler() + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt, scheduler=sched) - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-3 * 0.5) - assert groups[1]["lr"] == pytest.approx(1e-3 * 0.5) + kernel._apply_single(_op(params={"lr": 5e-4}), env, "train_step", 1) + + assert groups[0]["lr"] == pytest.approx(5e-4) + assert sched.base_lrs[0] == pytest.approx(5e-4) # ------------------------------------------------------------------ # -# 5. scheduler_drop +# 4. Missing optimizer (no MutableState) # ------------------------------------------------------------------ # -class TestSchedulerDrop: - def test_lr_multiplied(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) - env = {"optimizer": MockOptimizer(groups)} +class TestMissingMutableState: + def test_failed_with_no_mutable_state(self, run_dir, make_env, read_ledger): + kernel = HotKernel(run_dir=run_dir, debounce_steps=1) # no mutable_state + env = make_env(step=1) - result = ctrl.apply_op(_op(params={"scheduler_drop": 0.1}), env) + kernel._apply_single(_op(params={"lr": 1e-4}), env, "train_step", 1) - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-3 * 0.1) - assert groups[1]["lr"] == pytest.approx(1e-3 * 0.1) + ledger = read_ledger() + assert len(ledger) == 1 + assert ledger[0]["decision"] == "failed" + assert "no_mutable_state" in ledger[0]["error"] # ------------------------------------------------------------------ # -# 6. clip_norm +# 5. 
Unknown param key # ------------------------------------------------------------------ # -class TestClipNorm: - def test_stored_in_group(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2) - env = {"optimizer": MockOptimizer(groups)} +class TestUnknownParamKey: + def test_unknown_param_fails(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op(_op(params={"clip_norm": 2.0}), env) + kernel._apply_single( + _op(params={"key": "nonexistent", "value": 1.0}), + env, "train_step", 1, + ) - assert result.decision == "applied" - assert groups[0]["hotcb_clip_norm"] == pytest.approx(2.0) - assert groups[1]["hotcb_clip_norm"] == pytest.approx(2.0) + ledger = read_ledger() + assert len(ledger) == 1 + assert ledger[0]["decision"] == "failed" + assert "unknown_param" in ledger[0]["error"] # ------------------------------------------------------------------ # -# 7. Per-group mapping +# 6. Enable / disable via default stream # ------------------------------------------------------------------ # -class TestPerGroupMapping: - def test_each_group_gets_specific_lr(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) - env = {"optimizer": MockOptimizer(groups)} +class TestEnableDisable: + def test_disabled_actuator_rejects_set_params(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) + + # Disable the lr actuator + kernel._apply_single( + HotOp(module="opt", op="disable", params={"key": "lr"}), + env, "train_step", 1, + ) + + # set_params should fail + kernel._apply_single(_op(params={"lr": 1e-4}), env, "train_step", 2) + assert groups[0]["lr"] == pytest.approx(1e-3) # unchanged + + ledger = read_ledger() + set_entries = [e for e in ledger if e["op"] == "set_params"] + assert len(set_entries) == 1 + assert set_entries[0]["decision"] == "failed" + assert "disabled" in set_entries[0]["error"] + + def test_re_enable_then_apply(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op( - _op(params={"groups": {"0": {"lr": 1e-3}, "1": {"lr": 2e-3}}}), env + # Disable then re-enable + kernel._apply_single( + HotOp(module="opt", op="disable", params={"key": "lr"}), + env, "train_step", 1, + ) + kernel._apply_single( + HotOp(module="opt", op="enable", params={"key": "lr"}), + env, "train_step", 2, ) - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-3) - assert groups[1]["lr"] == pytest.approx(2e-3) + # Now it should apply + kernel._apply_single(_op(params={"lr": 1e-4}), env, "train_step", 3) + assert groups[0]["lr"] == pytest.approx(1e-4) # ------------------------------------------------------------------ # -# 8. Missing optimizer +# 7. 
Ledger records preserve module field # ------------------------------------------------------------------ # -class TestMissingOptimizer: - def test_failed_with_error(self): - ctrl = HotOptController(auto_disable_on_error=True) - env = {} # no optimizer, no resolve_optimizer +class TestLedgerFormat: + def test_ledger_preserves_module_field(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) + kernel._apply_single(_op(params={"lr": 5e-4}), env, "train_step", 1) - assert result.decision == "failed" - assert "missing_optimizer" in result.error + ledger = read_ledger() + assert len(ledger) == 1 + assert ledger[0]["module"] == "opt" + assert ledger[0]["decision"] == "applied" # ------------------------------------------------------------------ # -# 9. resolve_optimizer callable +# 8. New format: explicit key+value # ------------------------------------------------------------------ # -class TestResolveOptimizer: - def test_resolver_returns_optimizer(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) +class TestNewKeyValueFormat: + def test_explicit_key_value(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) opt = MockOptimizer(groups) - env = {"resolve_optimizer": lambda: opt} + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) + kernel._apply_single( + HotOp(module="opt", op="set_params", params={"key": "lr", "value": 5e-4}), + env, "train_step", 1, + ) - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-4) - assert groups[1]["lr"] == pytest.approx(1e-4) + assert groups[0]["lr"] == pytest.approx(5e-4) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" # ------------------------------------------------------------------ # -# 10. Enable / disable +# 9. 
Validation error # ------------------------------------------------------------------ # -class TestEnableDisable: - def test_disabled_handle_skips(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2, lr=1e-3) - env = {"optimizer": MockOptimizer(groups)} - - # Disable the handle - ctrl.apply_op(_op(op="disable", id="main"), env) - - # set_params should be skipped - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) - assert result.decision == "skipped_noop" - assert result.notes == "handle_disabled" - assert groups[0]["lr"] == pytest.approx(1e-3) # unchanged +class TestValidationError: + def test_out_of_bounds_rejected(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - # Re-enable - ctrl.apply_op(_op(op="enable", id="main"), env) + # lr actuator has max_value=1.0 by default + kernel._apply_single( + _op(params={"key": "lr", "value": 5.0}), + env, "train_step", 1, + ) - # Now it should apply - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) - assert result.decision == "applied" - assert groups[0]["lr"] == pytest.approx(1e-4) + assert groups[0]["lr"] == pytest.approx(1e-3) # unchanged + ledger = read_ledger() + assert ledger[0]["decision"] == "failed" + assert "above max" in ledger[0]["error"] # ------------------------------------------------------------------ # -# 11. Auto-disable on error +# 10. Betas set # ------------------------------------------------------------------ # -class TestAutoDisableOnError: - def test_handle_disabled_after_error(self): - ctrl = HotOptController(auto_disable_on_error=True) +class TestBetasSet: + def test_betas_via_mutable_state(self, run_dir, make_env, read_ledger): + groups = [{"lr": 1e-3, "weight_decay": 0.01, "betas": (0.9, 0.999)}] + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) - class BrokenOptimizer: - @property - def param_groups(self): - raise RuntimeError("gpu exploded") + kernel._apply_single( + _op(params={"key": "betas", "value": (0.8, 0.99)}), + env, "train_step", 1, + ) - env = {"optimizer": BrokenOptimizer()} + assert groups[0]["betas"] == pytest.approx((0.8, 0.99)) + ledger = read_ledger() + assert ledger[0]["decision"] == "applied" - result = ctrl.apply_op(_op(params={"lr": 1e-4}), env) - assert result.decision == "failed" - handle = ctrl.handles["main"] - assert handle.enabled is False - assert handle.last_error == "gpu exploded" +# ------------------------------------------------------------------ # +# 11. Multiple params in single op +# ------------------------------------------------------------------ # +class TestMultipleParams: + def test_lr_and_wd_set_together(self, run_dir, make_env, read_ledger): + groups = _make_param_groups(1, lr=1e-3, weight_decay=0.01) + opt = MockOptimizer(groups) + kernel = _kernel_with_opt(run_dir, opt) + env = make_env(step=1, optimizer=opt) + + kernel._apply_single( + _op(params={"lr": 5e-4, "weight_decay": 0.05}), + env, "train_step", 1, + ) + + assert groups[0]["lr"] == pytest.approx(5e-4) + assert groups[0]["weight_decay"] == pytest.approx(0.05) # ------------------------------------------------------------------ # -# 12. Status +# 12. 
Apply error produces error in ledger # ------------------------------------------------------------------ # -class TestStatus: - def test_status_structure(self): - ctrl = HotOptController(auto_disable_on_error=True) - groups = _make_param_groups(2) - env = {"optimizer": MockOptimizer(groups)} - - ctrl.apply_op(_op(params={"lr": 1e-4}), env) +class TestApplyError: + def test_apply_fn_failure_recorded(self, run_dir, make_env, read_ledger): + def _bad_apply(value, env): + raise RuntimeError("gpu exploded") + + bad_act = HotcbActuator( + param_key="bad_param", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + min_value=0.0, + max_value=10.0, + current_value=1.0, + ) + ms = mutable_state([bad_act]) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) - status = ctrl.status() + kernel._apply_single( + HotOp(module="opt", op="set_params", params={"key": "bad_param", "value": 2.0}), + env, "train_step", 1, + ) - assert "main" in status - entry = status["main"] - assert "enabled" in entry - assert "last_params" in entry - assert "last_error" in entry - assert entry["enabled"] is True - assert entry["last_params"]["lr"] == pytest.approx(1e-4) - assert entry["last_error"] is None + ledger = read_ledger() + assert len(ledger) == 1 + assert ledger[0]["decision"] == "failed" + assert "gpu exploded" in ledger[0]["error"] diff --git a/src/hotcb/tests/test_hottune.py b/src/hotcb/tests/test_hottune.py index 58cbdfd..c06a8f9 100644 --- a/src/hotcb/tests/test_hottune.py +++ b/src/hotcb/tests/test_hottune.py @@ -10,9 +10,7 @@ import pytest -from hotcb.actuators.base import ApplyResult, ValidationResult -from hotcb.actuators.optimizer import OptimizerActuator -from hotcb.actuators.loss_state import LossStateActuator +from hotcb.actuators.base import ApplyResult, ValidationResult, BaseActuator from hotcb.modules.tune.schemas import ( AcceptanceConfig, ActuatorConfig, @@ -43,12 +41,231 @@ def __init__(self, lr=0.001, wd=0.01, betas=(0.9, 0.999)): self.param_groups = [{"lr": lr, "weight_decay": wd, "betas": betas}] +class MockOptActuator: + """Mock optimizer actuator for tune tests (implements BaseActuator protocol).""" + name = "opt" + + def __init__(self, lr_bounds=(1e-7, 1.0), wd_bounds=(0.0, 1.0)): + self.lr_bounds = lr_bounds + self.wd_bounds = wd_bounds + + def _resolve_optimizer(self, env, opt_idx=0): + opt = env.get("optimizer") + return opt + + def snapshot(self, env): + opt = self._resolve_optimizer(env) + if opt is None: + return {} + g = opt.param_groups[0] + groups = [{"lr": g.get("lr")}] + if "weight_decay" in g: + groups[0]["weight_decay"] = g["weight_decay"] + if "betas" in g: + groups[0]["betas"] = list(g["betas"]) + return {"groups": groups, "all_groups": [groups]} + + def validate(self, patch, env): + op = patch.get("op") + value = patch.get("value") + valid_ops = {"lr_mult", "lr_set", "wd_mult", "wd_set", "betas_set"} + errors = [] + if op not in valid_ops: + errors.append(f"unknown op: {op}") + return ValidationResult(valid=False, errors=errors) + if value is None: + errors.append("missing value") + return ValidationResult(valid=False, errors=errors) + if op == "lr_mult" and (not isinstance(value, (int, float)) or value <= 0): + errors.append(f"lr_mult value must be positive, got {value}") + elif op == "lr_set": + if not isinstance(value, (int, float)) or value <= 0: + errors.append(f"lr_set value must be positive, got {value}") + elif not (self.lr_bounds[0] <= value <= self.lr_bounds[1]): + errors.append(f"lr_set value {value} out of bounds 
{self.lr_bounds}") + elif op == "wd_mult" and (not isinstance(value, (int, float)) or value <= 0): + errors.append(f"wd_mult value must be positive, got {value}") + elif op == "wd_set": + if not isinstance(value, (int, float)) or value < 0: + errors.append(f"wd_set value must be non-negative, got {value}") + elif not (self.wd_bounds[0] <= value <= self.wd_bounds[1]): + errors.append(f"wd_set value {value} out of bounds {self.wd_bounds}") + elif op == "betas_set": + if not isinstance(value, (list, tuple)) or len(value) != 2: + errors.append(f"betas_set expects [beta1, beta2], got {value}") + else: + for i, b in enumerate(value): + if not isinstance(b, (int, float)) or not (0.0 <= b < 1.0): + errors.append(f"beta{i+1} must be in [0, 1), got {b}") + return ValidationResult(valid=len(errors) == 0, errors=errors) + + def apply(self, patch, env): + opt = self._resolve_optimizer(env) + if opt is None: + return ApplyResult(success=False, error="missing_optimizer") + op = patch.get("op") + value = patch.get("value") + try: + for g in opt.param_groups: + if op == "lr_mult": + new_lr = g["lr"] * float(value) + new_lr = max(self.lr_bounds[0], min(self.lr_bounds[1], new_lr)) + g["lr"] = new_lr + elif op == "lr_set": + g["lr"] = float(value) + elif op == "wd_mult": + wd = g.get("weight_decay", 0.0) + new_wd = wd * float(value) + new_wd = max(self.wd_bounds[0], min(self.wd_bounds[1], new_wd)) + g["weight_decay"] = new_wd + elif op == "wd_set": + g["weight_decay"] = float(value) + elif op == "betas_set": + g["betas"] = tuple(float(b) for b in value) + return ApplyResult(success=True, detail=patch) + except Exception as e: + return ApplyResult(success=False, error=str(e)) + + def restore(self, snapshot, env): + groups = snapshot.get("groups", []) + opt = self._resolve_optimizer(env) + if opt is None: + return ApplyResult(success=False, error="missing_optimizer") + try: + for i, snap in enumerate(groups): + if i >= len(opt.param_groups): + break + g = opt.param_groups[i] + if "lr" in snap: + g["lr"] = snap["lr"] + if "weight_decay" in snap: + g["weight_decay"] = snap["weight_decay"] + if "betas" in snap: + g["betas"] = tuple(snap["betas"]) + return ApplyResult(success=True) + except Exception as e: + return ApplyResult(success=False, error=str(e)) + + def describe_space(self): + return { + "actuator": self.name, + "mutations": { + "lr_mult": {"type": "float"}, + "lr_set": {"type": "float", "bounds": list(self.lr_bounds)}, + "wd_mult": {"type": "float"}, + "wd_set": {"type": "float", "bounds": list(self.wd_bounds)}, + "betas_set": {"type": "list[float]", "length": 2}, + }, + } + + +class MockLossActuator: + """Mock loss actuator for tune tests (implements BaseActuator protocol).""" + name = "loss" + + def __init__(self, global_bounds=(0.0, 100.0), key_bounds=None): + self.global_bounds = global_bounds + self.key_bounds = key_bounds or {} + + def _resolve_mutable_state(self, env): + return env.get("mutable_state") + + def _get_bounds(self, key): + return self.key_bounds.get(key, self.global_bounds) + + def snapshot(self, env): + import copy + ms = self._resolve_mutable_state(env) + if ms is None: + return {} + weights = ms.get("weights", {}) + return {"weights": copy.deepcopy(weights)} + + def validate(self, patch, env): + errors = [] + op = patch.get("op") + key = patch.get("key") + value = patch.get("value") + valid_ops = {"set", "mult", "delta"} + if op not in valid_ops: + errors.append(f"unknown op: {op}") + return ValidationResult(valid=False, errors=errors) + if key is None: + errors.append("missing 
key") + if value is None: + errors.append("missing value") + if not isinstance(value, (int, float)): + errors.append(f"value must be numeric, got {type(value).__name__}") + if errors: + return ValidationResult(valid=False, errors=errors) + ms = self._resolve_mutable_state(env) + if ms is not None: + weights = ms.get("weights", {}) + bounds = self._get_bounds(key) + if op == "set" and not (bounds[0] <= value <= bounds[1]): + errors.append(f"set value {value} out of bounds {bounds} for key {key}") + elif op == "mult": + if value <= 0: + errors.append(f"mult value must be positive, got {value}") + current = weights.get(key, 1.0) + result = current * value + if not (bounds[0] <= result <= bounds[1]): + errors.append(f"mult would produce {result}, out of bounds {bounds}") + return ValidationResult(valid=len(errors) == 0, errors=errors) + + def apply(self, patch, env): + import copy + ms = self._resolve_mutable_state(env) + if ms is None: + return ApplyResult(success=False, error="missing_mutable_state") + op = patch.get("op") + key = patch.get("key") + value = patch.get("value") + weights = ms.setdefault("weights", {}) + bounds = self._get_bounds(key) + try: + current = weights.get(key, 1.0) + if op == "set": + new_val = float(value) + elif op == "mult": + new_val = current * float(value) + elif op == "delta": + new_val = current + float(value) + else: + return ApplyResult(success=False, error=f"unknown op: {op}") + new_val = max(bounds[0], min(bounds[1], new_val)) + weights[key] = new_val + return ApplyResult(success=True, detail={"key": key, "old": current, "new": new_val}) + except Exception as e: + return ApplyResult(success=False, error=str(e)) + + def restore(self, snapshot, env): + import copy + ms = self._resolve_mutable_state(env) + if ms is None: + return ApplyResult(success=False, error="missing_mutable_state") + saved = snapshot.get("weights", {}) + ms["weights"] = copy.deepcopy(saved) + return ApplyResult(success=True) + + def describe_space(self): + return { + "actuator": self.name, + "mutations": { + "set": {"type": "float"}, + "mult": {"type": "float"}, + "delta": {"type": "float"}, + }, + "global_bounds": list(self.global_bounds), + } + + def make_env( step=100, epoch=1, phase="val", optimizer=None, - loss_state=None, + mutable_state=None, loss=None, metric_fn=None, max_steps=1000, @@ -61,8 +278,8 @@ def make_env( } if optimizer is not None: env["optimizer"] = optimizer - if loss_state is not None: - env["loss_state"] = loss_state + if mutable_state is not None: + env["mutable_state"] = mutable_state if loss is not None: env["loss"] = loss if metric_fn is not None: @@ -99,11 +316,11 @@ def simple_recipe(**overrides) -> TuneRecipe: # ========== Actuator tests ========== -class TestOptimizerActuator: +class TestMockOptActuator: def test_snapshot_and_restore(self): opt = FakeOptimizer(lr=0.001, wd=0.01, betas=(0.9, 0.999)) env = {"optimizer": opt} - act = OptimizerActuator() + act = MockOptActuator() snap = act.snapshot(env) assert snap["groups"][0]["lr"] == 0.001 @@ -119,14 +336,14 @@ def test_snapshot_and_restore(self): assert opt.param_groups[0]["lr"] == 0.001 def test_validate_lr_mult(self): - act = OptimizerActuator() + act = MockOptActuator() env = {"optimizer": FakeOptimizer()} assert act.validate({"op": "lr_mult", "value": 0.9}, env).valid assert not act.validate({"op": "lr_mult", "value": -1}, env).valid assert not act.validate({"op": "unknown", "value": 1}, env).valid def test_validate_betas_set(self): - act = OptimizerActuator() + act = MockOptActuator() env = 
{"optimizer": FakeOptimizer()} assert act.validate({"op": "betas_set", "value": [0.9, 0.98]}, env).valid assert not act.validate({"op": "betas_set", "value": [1.5, 0.98]}, env).valid @@ -135,7 +352,7 @@ def test_validate_betas_set(self): def test_apply_lr_mult(self): opt = FakeOptimizer(lr=0.001) env = {"optimizer": opt} - act = OptimizerActuator() + act = MockOptActuator() result = act.apply({"op": "lr_mult", "value": 0.5}, env) assert result.success assert opt.param_groups[0]["lr"] == pytest.approx(0.0005) @@ -143,7 +360,7 @@ def test_apply_lr_mult(self): def test_apply_lr_set(self): opt = FakeOptimizer(lr=0.001) env = {"optimizer": opt} - act = OptimizerActuator() + act = MockOptActuator() result = act.apply({"op": "lr_set", "value": 0.01}, env) assert result.success assert opt.param_groups[0]["lr"] == 0.01 @@ -151,7 +368,7 @@ def test_apply_lr_set(self): def test_apply_wd_mult(self): opt = FakeOptimizer(wd=0.01) env = {"optimizer": opt} - act = OptimizerActuator() + act = MockOptActuator() result = act.apply({"op": "wd_mult", "value": 2.0}, env) assert result.success assert opt.param_groups[0]["weight_decay"] == pytest.approx(0.02) @@ -159,36 +376,36 @@ def test_apply_wd_mult(self): def test_apply_betas_set(self): opt = FakeOptimizer() env = {"optimizer": opt} - act = OptimizerActuator() + act = MockOptActuator() result = act.apply({"op": "betas_set", "value": [0.85, 0.95]}, env) assert result.success assert opt.param_groups[0]["betas"] == (0.85, 0.95) def test_apply_missing_optimizer(self): - act = OptimizerActuator() + act = MockOptActuator() result = act.apply({"op": "lr_mult", "value": 0.9}, {}) assert not result.success assert "missing_optimizer" in result.error def test_describe_space(self): - act = OptimizerActuator() + act = MockOptActuator() space = act.describe_space() assert "lr_mult" in space["mutations"] def test_lr_bounds_clamping(self): opt = FakeOptimizer(lr=0.5) env = {"optimizer": opt} - act = OptimizerActuator(lr_bounds=(1e-7, 1.0)) + act = MockOptActuator(lr_bounds=(1e-7, 1.0)) result = act.apply({"op": "lr_mult", "value": 10.0}, env) assert result.success assert opt.param_groups[0]["lr"] == 1.0 # clamped -class TestLossStateActuator: +class TestMockLossActuator: def test_snapshot_and_restore(self): ls = {"weights": {"main": 1.0, "aux": 0.5}} - env = {"loss_state": ls} - act = LossStateActuator() + env = {"mutable_state": ls} + act = MockLossActuator() snap = act.snapshot(env) assert snap["weights"]["main"] == 1.0 @@ -199,51 +416,51 @@ def test_snapshot_and_restore(self): assert ls["weights"]["main"] == 1.0 def test_validate_set(self): - act = LossStateActuator(global_bounds=(0.0, 10.0)) + act = MockLossActuator(global_bounds=(0.0, 10.0)) ls = {"weights": {"main": 1.0}} - env = {"loss_state": ls} + env = {"mutable_state": ls} assert act.validate({"op": "set", "key": "main", "value": 5.0}, env).valid assert not act.validate({"op": "set", "key": "main", "value": 11.0}, env).valid def test_validate_mult_bounds(self): - act = LossStateActuator(global_bounds=(0.0, 10.0)) + act = MockLossActuator(global_bounds=(0.0, 10.0)) ls = {"weights": {"main": 5.0}} - env = {"loss_state": ls} + env = {"mutable_state": ls} assert not act.validate({"op": "mult", "key": "main", "value": 3.0}, env).valid # 15 > 10 def test_apply_set(self): ls = {"weights": {"main": 1.0}} - env = {"loss_state": ls} - act = LossStateActuator() + env = {"mutable_state": ls} + act = MockLossActuator() result = act.apply({"op": "set", "key": "main", "value": 2.0}, env) assert result.success assert 
ls["weights"]["main"] == 2.0 def test_apply_mult(self): ls = {"weights": {"main": 1.0}} - env = {"loss_state": ls} - act = LossStateActuator() + env = {"mutable_state": ls} + act = MockLossActuator() result = act.apply({"op": "mult", "key": "main", "value": 1.5}, env) assert result.success assert ls["weights"]["main"] == pytest.approx(1.5) def test_apply_delta(self): ls = {"weights": {"main": 1.0}} - env = {"loss_state": ls} - act = LossStateActuator() + env = {"mutable_state": ls} + act = MockLossActuator() result = act.apply({"op": "delta", "key": "main", "value": 0.3}, env) assert result.success assert ls["weights"]["main"] == pytest.approx(1.3) - def test_apply_missing_loss_state(self): - act = LossStateActuator() + def test_apply_missing_mutable_state(self): + act = MockLossActuator() result = act.apply({"op": "set", "key": "main", "value": 1.0}, {}) assert not result.success def test_key_bounds(self): - act = LossStateActuator(key_bounds={"main": (0.5, 1.5)}) + act = MockLossActuator(key_bounds={"main": (0.5, 1.5)}) ls = {"weights": {"main": 1.0}} - env = {"loss_state": ls} + env = {"mutable_state": ls} assert not act.validate({"op": "set", "key": "main", "value": 2.0}, env).valid assert act.validate({"op": "set", "key": "main", "value": 1.2}, env).valid @@ -523,7 +740,7 @@ def test_on_event_observe_mode_no_mutations(self): ctrl = HotTuneController(recipe=simple_recipe()) ctrl.state.mode = "observe" opt = FakeOptimizer() - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) env = make_env(optimizer=opt) ctrl.on_event("val_epoch_end", env) assert ctrl.state.mutation_counter == 0 @@ -537,14 +754,14 @@ def test_on_event_active_proposes(self, tmp_path): ) ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("loss", MockLossActuator()) def metric_fn(name, default=None): return {"val/loss": 0.5}.get(name, default) - env = make_env(optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) assert ctrl.state.mutation_counter >= 1 assert ctrl.state.active_mutation is not None @@ -557,8 +774,8 @@ def test_full_accept_cycle(self, tmp_path): ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) # First val_epoch_end: proposes and applies mutation metrics_val = [0.5] @@ -568,13 +785,13 @@ def metric_fn(name, default=None): return metrics_val[0] return default - env = make_env(step=100, optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(step=100, optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) assert ctrl.state.active_segment is not None # Second val_epoch_end: evaluates the segment metrics_val[0] = 0.3 # improvement - env = make_env(step=200, optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(step=200, optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) # Segment should have been evaluated @@ -589,8 +806,8 @@ def test_full_reject_and_rollback_cycle(self, tmp_path): 
ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) original_lr = opt.param_groups[0]["lr"] @@ -601,13 +818,13 @@ def metric_fn(name, default=None): return metrics_val[0] return default - env = make_env(step=100, optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(step=100, optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) assert ctrl.state.active_segment is not None # Regression metrics_val[0] = 0.6 - env = make_env(step=200, optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(step=200, optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) assert len(ctrl.state.history) == 1 @@ -620,7 +837,7 @@ def test_safety_block_on_nan(self, tmp_path): run_dir=str(tmp_path), ) ctrl.state.mode = "active" - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) env = make_env(optimizer=FakeOptimizer(), loss=float("nan")) ctrl.on_event("val_epoch_end", env) assert ctrl.state.mutation_counter == 0 @@ -632,7 +849,7 @@ def test_reject_streak_blocks_mutations(self, tmp_path): ) ctrl.state.mode = "active" ctrl.state.reject_streak = 10 # way over limit - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) env = make_env(optimizer=FakeOptimizer()) ctrl.on_event("val_epoch_end", env) # No mutation should be proposed due to reject streak @@ -659,7 +876,7 @@ def test_kernel_has_tune_module(self, tmp_path): def test_kernel_actuator_registry(self, tmp_path): k = HotKernel(run_dir=str(tmp_path)) - act = OptimizerActuator() + act = MockOptActuator() k.register_actuator("opt", act) assert k.get_actuator("opt") is act # Also registered in tune module @@ -772,8 +989,8 @@ def test_suggest_does_not_apply_mutation(self, tmp_path): ctrl.state.mode = "suggest" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) original_lr = opt.param_groups[0]["lr"] original_w = ls["weights"]["main_w"] @@ -781,7 +998,7 @@ def test_suggest_does_not_apply_mutation(self, tmp_path): def metric_fn(name, default=None): return {"val/loss": 0.5}.get(name, default) - env = make_env(optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) # Mutation should be logged as suggested @@ -833,7 +1050,7 @@ def test_replay_applies_mutations_in_order(self, tmp_path): ) ctrl.state.mode = "replay" opt = FakeOptimizer(lr=0.001) - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) env = make_env(step=100, optimizer=opt) ctrl.on_event("val_epoch_end", env) @@ -865,7 +1082,7 @@ def test_replay_exhausted(self, tmp_path): ) ctrl.state.mode = "replay" opt = FakeOptimizer(lr=0.001) - ctrl.register_actuator("opt", OptimizerActuator()) + ctrl.register_actuator("opt", MockOptActuator()) # First event applies ctrl.on_event("val_epoch_end", make_env(step=100, optimizer=opt)) @@ -939,8 +1156,8 @@ def _run_simulation(self, tmp_path, loss_surface, num_epochs=10): 
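One Python detail worth noting in the simulation loop below: the tests bind the current loss into metric_fn through a default argument (`_loss=current_loss` here, `_v=loss_val` in the delayed-reward test) because a plain closure would late-bind the loop variable, so every callback would read the final epoch's value. A standalone illustration of the pitfall and the fix:

```python
# Python closures capture variables, not values: all three "late" lambdas
# share the loop variable and see its final value.
late = [lambda: v for v in range(3)]
assert [f() for f in late] == [2, 2, 2]

# Binding through a default argument freezes each value at definition time,
# which is exactly what _loss=current_loss does in the tests below.
bound = [lambda _v=v: _v for v in range(3)]
assert [f() for f in bound] == [0, 1, 2]
```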
ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) for epoch in range(num_epochs): step = epoch * 100 @@ -953,7 +1170,7 @@ def metric_fn(name, default=None, _loss=current_loss): env = make_env( step=step, epoch=epoch, optimizer=opt, - loss_state=ls, loss=current_loss, + mutable_state=ls, loss=current_loss, metric_fn=metric_fn, max_steps=num_epochs * 100, ) ctrl.on_event("val_epoch_end", env) @@ -994,8 +1211,8 @@ def test_instability_blocks(self, tmp_path): ) ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) env = make_env(step=100, optimizer=opt, loss=float("nan")) ctrl.on_event("val_epoch_end", env) @@ -1015,8 +1232,8 @@ def test_delayed_reward(self, tmp_path): ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) for i, loss_val in enumerate(losses): def metric_fn(name, default=None, _v=loss_val): @@ -1024,7 +1241,7 @@ def metric_fn(name, default=None, _v=loss_val): env = make_env( step=i * 100, epoch=i, optimizer=opt, - loss_state=ls, loss=loss_val, metric_fn=metric_fn, + mutable_state=ls, loss=loss_val, metric_fn=metric_fn, ) ctrl.on_event("val_epoch_end", env) @@ -1069,15 +1286,15 @@ def test_rollback_failure_logged(self, tmp_path): ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) metrics_val = [0.5] def metric_fn(name, default=None): return metrics_val[0] if name == "val/loss" else default - env = make_env(step=100, optimizer=opt, loss_state=ls, metric_fn=metric_fn) + env = make_env(step=100, optimizer=opt, mutable_state=ls, metric_fn=metric_fn) ctrl.on_event("val_epoch_end", env) if ctrl.state.active_segment: @@ -1098,10 +1315,10 @@ def test_missing_metric_fn(self, tmp_path): ctrl.state.mode = "active" opt = FakeOptimizer(lr=0.001) ls = {"weights": {"main_w": 1.0}} - ctrl.register_actuator("opt", OptimizerActuator()) - ctrl.register_actuator("loss", LossStateActuator()) + ctrl.register_actuator("opt", MockOptActuator()) + ctrl.register_actuator("loss", MockLossActuator()) - env = make_env(step=100, optimizer=opt, loss_state=ls) + env = make_env(step=100, optimizer=opt, mutable_state=ls) # No metric_fn in env ctrl.on_event("val_epoch_end", env) # Should not crash, may or may not propose @@ -1156,22 +1373,22 @@ def test_empty_actuator_registry_observe_only(self, tmp_path): assert ctrl.state.mutation_counter == 0 def test_validate_missing_value(self): - act = OptimizerActuator() + act = MockOptActuator() env = {"optimizer": FakeOptimizer()} result = act.validate({"op": "lr_mult"}, env) assert not result.valid assert any("missing value" in e for e in result.errors) def test_loss_actuator_validate_missing_key(self): - act = 
LossStateActuator() - env = {"loss_state": {"weights": {}}} + act = MockLossActuator() + env = {"mutable_state": {"weights": {}}} result = act.validate({"op": "set", "value": 1.0}, env) assert not result.valid assert any("missing key" in e for e in result.errors) def test_loss_actuator_validate_non_numeric(self): - act = LossStateActuator() - env = {"loss_state": {"weights": {}}} + act = MockLossActuator() + env = {"mutable_state": {"weights": {}}} result = act.validate({"op": "set", "key": "x", "value": "bad"}, env) assert not result.valid diff --git a/src/hotcb/tests/test_integration_lightning.py b/src/hotcb/tests/test_integration_lightning.py index ac5b962..7ebbd1d 100644 --- a/src/hotcb/tests/test_integration_lightning.py +++ b/src/hotcb/tests/test_integration_lightning.py @@ -12,6 +12,7 @@ from hotcb.kernel import HotKernel from hotcb.adapters.lightning import HotCBLightning +from hotcb.actuators import loss_actuators, mutable_state as make_mutable_state class TinyModel(pl.LightningModule): @@ -160,16 +161,17 @@ def test_replay_same_lr_change(self, tmp_path): assert len(replay_applied) >= 1 -class TestLossStateMutation: - def test_loss_state_changed(self, tmp_path): +class TestMutableStateMutation: + def test_mutable_state_changed(self, tmp_path): run_dir = str(tmp_path / "run_loss") os.makedirs(run_dir, exist_ok=True) - _write_commands(run_dir, {"module": "loss", "op": "set_params", "id": "main", "params": {"distill_w": 0.5}}) + _write_commands(run_dir, {"module": "loss", "op": "set_params", "id": "main", "params": {"key": "distill", "value": 0.5}}) - loss_state = {"weights": {}, "terms": {}, "ramps": {}} - kernel = HotKernel(run_dir=run_dir, debounce_steps=1) - adapter = HotCBLightning(kernel=kernel, loss_state=loss_state) + loss_weights = {"distill": 1.0} + ms = make_mutable_state(loss_actuators(loss_weights)) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + adapter = HotCBLightning(kernel=kernel) model = TinyModel() trainer = pl.Trainer( @@ -179,7 +181,7 @@ def test_loss_state_changed(self, tmp_path): ) trainer.fit(model) - assert loss_state["weights"]["distill"] == 0.5 + assert loss_weights["distill"] == pytest.approx(0.5) ledger = _read_ledger(run_dir) loss_applied = [e for e in ledger if e.get("module") == "loss" and e.get("decision") == "applied"] diff --git a/src/hotcb/tests/test_kernel_core.py b/src/hotcb/tests/test_kernel_core.py index 833a1f0..b441929 100644 --- a/src/hotcb/tests/test_kernel_core.py +++ b/src/hotcb/tests/test_kernel_core.py @@ -9,6 +9,7 @@ import pytest from hotcb.kernel import HotKernel +from hotcb.actuators import optimizer_actuators, loss_actuators, mutable_state # --------------------------------------------------------------------------- @@ -159,10 +160,11 @@ def test_time_gate_allows_after_interval(self, run_dir, make_env, write_commands class TestRouting: def test_route_to_opt(self, run_dir, make_env, write_commands, read_ledger): - """Ops with module='opt' reach the opt controller.""" + """Ops with module='opt' route through default stream (MutableState).""" optimizer = _mock_optimizer(lr=0.01) write_commands({"module": "opt", "op": "set_params", "params": {"lr": 0.002}}) - kernel = HotKernel(run_dir, debounce_steps=1) + ms = mutable_state(optimizer_actuators(optimizer)) + kernel = HotKernel(run_dir, debounce_steps=1, mutable_state=ms) env = make_env(step=1, optimizer=optimizer) kernel.apply(env, ["step"]) @@ -174,18 +176,19 @@ def test_route_to_opt(self, run_dir, make_env, write_commands, read_ledger): assert 
optimizer.param_groups[0]["lr"] == pytest.approx(0.002) def test_route_to_loss(self, run_dir, make_env, write_commands, read_ledger): - """Ops with module='loss' reach the loss controller.""" - loss_state = {"weights": {}, "terms": {}, "ramps": {}} - write_commands({"module": "loss", "op": "set_params", "params": {"kl_w": 0.5}}) - kernel = HotKernel(run_dir, debounce_steps=1) - env = make_env(step=1, loss_state=loss_state) + """Ops with module='loss' route through default stream (MutableState).""" + loss_weights = {"kl": 1.0} + write_commands({"module": "loss", "op": "set_params", "params": {"key": "kl", "value": 0.5}}) + ms = mutable_state(loss_actuators(loss_weights)) + kernel = HotKernel(run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) kernel.apply(env, ["step"]) ledger = read_ledger() assert len(ledger) == 1 assert ledger[0]["decision"] == "applied" assert ledger[0]["module"] == "loss" - assert loss_state["weights"]["kl"] == 0.5 + assert loss_weights["kl"] == pytest.approx(0.5) def test_route_to_core(self, run_dir, make_env, write_commands, read_ledger): """Ops with module='core' are handled by the kernel itself.""" @@ -211,17 +214,17 @@ def test_route_to_cb(self, run_dir, make_env, write_commands, read_ledger): assert len(ledger) == 1 assert ledger[0]["module"] == "cb" - def test_opt_missing_optimizer_fails(self, run_dir, make_env, write_commands, read_ledger): - """set_params on opt without an optimizer produces failed.""" + def test_opt_without_mutable_state_fails(self, run_dir, make_env, write_commands, read_ledger): + """set_params on opt without a MutableState produces failed.""" write_commands({"module": "opt", "op": "set_params", "params": {"lr": 0.1}}) kernel = HotKernel(run_dir, debounce_steps=1) - env = make_env(step=1) # no optimizer + env = make_env(step=1) # no mutable_state kernel.apply(env, ["step"]) ledger = read_ledger() assert len(ledger) == 1 assert ledger[0]["decision"] == "failed" - assert "missing_optimizer" in (ledger[0].get("error") or "") + assert "no_mutable_state" in (ledger[0].get("error") or "") # --------------------------------------------------------------------------- @@ -265,9 +268,11 @@ def test_seq_monotonically_increasing(self, run_dir, make_env, write_commands, r def test_required_fields_populated(self, run_dir, make_env, write_commands, read_ledger): """step, event, source, decision fields are present in every ledger record.""" - write_commands({"module": "opt", "op": "enable", "id": "a"}) - kernel = HotKernel(run_dir, debounce_steps=1) - env = make_env(step=42, optimizer=_mock_optimizer()) + optimizer = _mock_optimizer() + ms = mutable_state(optimizer_actuators(optimizer)) + write_commands({"module": "opt", "op": "set_params", "params": {"lr": 0.005}}) + kernel = HotKernel(run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=42, optimizer=optimizer) kernel.apply(env, ["on_batch_end"]) ledger = read_ledger() @@ -282,8 +287,8 @@ def test_required_fields_populated(self, run_dir, make_env, write_commands, read def test_failure_includes_error_text(self, run_dir, make_env, write_commands, read_ledger): """Failed ops have error text in the ledger.""" write_commands({"module": "opt", "op": "set_params", "params": {"lr": 0.1}}) - kernel = HotKernel(run_dir, debounce_steps=1) - env = make_env(step=1) # no optimizer -> failure + kernel = HotKernel(run_dir, debounce_steps=1) # no mutable_state -> failure + env = make_env(step=1) kernel.apply(env, ["step"]) ledger = read_ledger() @@ -298,9 +303,9 @@ def 
test_failure_includes_error_text(self, run_dir, make_env, write_commands, re # --------------------------------------------------------------------------- class TestUnknownModule: - def test_unknown_module_fails(self, run_dir, make_env, write_commands, read_ledger): - """Op with an unknown module is recorded as failed with error text.""" - write_commands({"module": "xyz", "op": "enable", "id": "thing"}) + def test_unknown_module_without_mutable_state_fails(self, run_dir, make_env, write_commands, read_ledger): + """Op with an unknown module and no MutableState is recorded as failed.""" + write_commands({"module": "xyz", "op": "set_params", "params": {"key": "thing", "value": 1}}) kernel = HotKernel(run_dir, debounce_steps=1) env = make_env(step=1) kernel.apply(env, ["step"]) @@ -308,8 +313,7 @@ def test_unknown_module_fails(self, run_dir, make_env, write_commands, read_ledg ledger = read_ledger() assert len(ledger) == 1 assert ledger[0]["decision"] == "failed" - assert "unknown_module" in ledger[0]["error"] - assert "xyz" in ledger[0]["error"] + assert "no_mutable_state" in ledger[0]["error"] def test_unknown_module_still_produces_ledger_record(self, run_dir, make_env, write_commands, read_ledger): """Unknown modules still get a ledger entry with correct fields.""" diff --git a/src/hotcb/tests/test_new_features.py b/src/hotcb/tests/test_new_features.py index e0ae0dc..d9670f0 100644 --- a/src/hotcb/tests/test_new_features.py +++ b/src/hotcb/tests/test_new_features.py @@ -18,8 +18,8 @@ from hotcb.kernel import HotKernel from hotcb.modules.cb import CallbackModule, _capture_source -from hotcb.modules.opt import HotOptController -from hotcb.modules.loss import HotLossController +from hotcb.actuators import optimizer_actuators, loss_actuators, mutable_state, ApplyResult +from hotcb.actuators.actuator import ActuatorType, HotcbActuator from hotcb.ops import HotOp from hotcb.cli import ( cmd_status, @@ -169,93 +169,124 @@ def test_external_source_capture_still_works( # =========================================================================== class TestTracebackInModuleResult: - """When set_params fails with an exception, ModuleResult should carry traceback.""" - - def test_opt_set_params_failure_has_traceback(self): - """HotOptController.set_params failure produces a result with traceback.""" - ctrl = HotOptController() - # Passing group index that doesn't exist on a single-group optimizer - optimizer = _mock_optimizer(lr=0.01) - env = {"optimizer": optimizer} - op = HotOp(module="opt", op="set_params", params={"group": "99", "lr": 0.001}) - result = ctrl.apply_op(op, env) - assert result.decision == "failed" - assert result.error is not None - assert result.traceback is not None - assert "Traceback" in result.traceback - - def test_loss_set_params_failure_has_traceback(self): - """HotLossController.set_params failure produces a result with traceback.""" - ctrl = HotLossController() - # Give a loss_state where _apply_params will raise due to non-dict value - # We force an error by patching the loss_state to raise on setdefault - class BadLossState(dict): - def setdefault(self, key, default=None): - raise TypeError("intentional error for test") - - env = {"loss_state": BadLossState()} - op = HotOp(module="loss", op="set_params", params={"kl_w": 0.5}) - result = ctrl.apply_op(op, env) - assert result.decision == "failed" + """When set_params fails with an exception, error info is recorded.""" + + def test_opt_apply_fn_failure_has_error(self): + """Optimizer actuator apply_fn failure produces 
error in ApplyResult.""" + def _bad_apply(value, env): + raise RuntimeError("gpu exploded") + + bad_act = HotcbActuator( + param_key="bad_lr", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + min_value=0.0, + max_value=1.0, + current_value=0.01, + ) + from hotcb.actuators.state import MutableState + ms = MutableState([bad_act]) + result = ms.apply("bad_lr", 0.001, {}, step=1) + assert not result.success assert result.error is not None - assert result.traceback is not None - assert "Traceback" in result.traceback + assert "gpu exploded" in result.error + + def test_loss_apply_fn_failure_has_error(self): + """Loss actuator apply_fn failure produces error in ApplyResult.""" + def _bad_apply(value, env): + raise TypeError("intentional error for test") + + bad_act = HotcbActuator( + param_key="bad_weight", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + min_value=0.0, + max_value=100.0, + current_value=1.0, + ) + from hotcb.actuators.state import MutableState + ms = MutableState([bad_act]) + result = ms.apply("bad_weight", 0.5, {}, step=1) + assert not result.success + assert "intentional error" in result.error - def test_kernel_ledger_has_traceback_on_opt_failure( + def test_kernel_ledger_has_error_on_opt_failure( self, run_dir, make_env, write_commands, read_ledger, ): - """When opt set_params fails through the kernel, ledger entry has traceback.""" + """When opt set_params fails through the kernel, ledger entry has error.""" + def _bad_apply(value, env): + raise RuntimeError("gpu exploded") + + bad_act = HotcbActuator( + param_key="lr", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + min_value=0.0, + max_value=1.0, + current_value=0.01, + ) + ms = mutable_state([bad_act]) + write_commands({ "module": "opt", "op": "set_params", - "params": {"group": "99", "lr": 0.001}, + "params": {"key": "lr", "value": 0.001}, }) - kernel = HotKernel(run_dir=run_dir, debounce_steps=1) - optimizer = _mock_optimizer(lr=0.01) - env = make_env(step=1, optimizer=optimizer) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) kernel.apply(env, ["train_step_end"]) ledger = read_ledger() assert len(ledger) == 1 assert ledger[0]["decision"] == "failed" - assert ledger[0]["traceback"] is not None - assert "Traceback" in ledger[0]["traceback"] + assert ledger[0]["error"] is not None + assert "gpu exploded" in ledger[0]["error"] - def test_kernel_ledger_has_traceback_on_loss_failure( + def test_kernel_ledger_has_error_on_loss_failure( self, run_dir, make_env, write_commands, read_ledger, ): - """When loss set_params fails through the kernel, ledger entry has traceback.""" - - class BadLossState(dict): - def setdefault(self, key, default=None): - raise TypeError("intentional error for test") + """When loss set_params fails through the kernel, ledger entry has error.""" + def _bad_apply(value, env): + raise TypeError("intentional error for test") + + bad_act = HotcbActuator( + param_key="kl", + type=ActuatorType.FLOAT, + apply_fn=_bad_apply, + min_value=0.0, + max_value=100.0, + current_value=1.0, + ) + ms = mutable_state([bad_act]) write_commands({ "module": "loss", "op": "set_params", - "params": {"kl_w": 0.5}, + "params": {"key": "kl", "value": 0.5}, }) - kernel = HotKernel(run_dir=run_dir, debounce_steps=1) - env = make_env(step=1, loss_state=BadLossState()) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) + env = make_env(step=1) kernel.apply(env, ["train_step_end"]) ledger = read_ledger() assert len(ledger) == 1 assert 
ledger[0]["decision"] == "failed" - assert ledger[0]["traceback"] is not None - assert "Traceback" in ledger[0]["traceback"] + assert ledger[0]["error"] is not None + assert "intentional error" in ledger[0]["error"] def test_successful_op_has_no_traceback( self, run_dir, make_env, write_commands, read_ledger, ): """Successful ops should have traceback=None in the ledger.""" + optimizer = _mock_optimizer(lr=0.01) + ms = mutable_state(optimizer_actuators(optimizer)) + write_commands({ "module": "opt", "op": "set_params", "params": {"lr": 0.001}, }) - kernel = HotKernel(run_dir=run_dir, debounce_steps=1) - optimizer = _mock_optimizer(lr=0.01) + kernel = HotKernel(run_dir=run_dir, debounce_steps=1, mutable_state=ms) env = make_env(step=1, optimizer=optimizer) kernel.apply(env, ["train_step_end"]) diff --git a/src/hotcb/tests/test_robustness.py b/src/hotcb/tests/test_robustness.py index 6557f43..d90c239 100644 --- a/src/hotcb/tests/test_robustness.py +++ b/src/hotcb/tests/test_robustness.py @@ -58,19 +58,23 @@ def test_jsonl_partial_line(run_dir, make_env, read_ledger): assert isinstance(ledger, list) # no crash -def test_read_new_jsonl_partial_line_raises(): - """read_new_jsonl itself must raise json.JSONDecodeError on malformed JSON.""" +def test_read_new_jsonl_skips_malformed_lines(): + """read_new_jsonl should skip malformed JSON lines instead of raising.""" import tempfile with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: f.write(json.dumps({"module": "opt", "op": "enable"}) + "\n") - f.write('{"module":"opt"\n') + f.write('{"module":"opt"\n') # malformed + f.write(json.dumps({"module": "loss", "op": "set_params"}) + "\n") path = f.name try: cursor = FileCursor(path=path, offset=0) - with pytest.raises(json.JSONDecodeError): - read_new_jsonl(cursor) + records, _ = read_new_jsonl(cursor) + # Should get 2 valid records, malformed line skipped + assert len(records) == 2 + assert records[0]["module"] == "opt" + assert records[1]["module"] == "loss" finally: os.unlink(path) diff --git a/src/hotcb/tests/test_server_app.py b/src/hotcb/tests/test_server_app.py index 0911bfa..6fa3ffe 100644 --- a/src/hotcb/tests/test_server_app.py +++ b/src/hotcb/tests/test_server_app.py @@ -184,3 +184,200 @@ def test_applied_history(self, client): records = r.json()["records"] assert len(records) == 1 assert records[0]["module"] == "opt" + + +# --------------------------------------------------------------------------- +# Phase 5: Immutable run_dir tests +# --------------------------------------------------------------------------- + + +class TestImmutableRunDir: + """Phase 5: run_dir is set once at create_app() and never changes.""" + + def test_no_app_state_run_dir_attribute(self, populated_dir): + """app.state should not have a mutable run_dir; use config.run_dir.""" + app = create_app(populated_dir, poll_interval=60) + # app.state.run_dir should NOT be set (removed in Phase 5) + assert not hasattr(app.state, "run_dir"), ( + "app.state.run_dir should not exist; use app.state.config.run_dir" + ) + + def test_config_run_dir_is_set(self, populated_dir): + """app.state.config.run_dir should be set to the resolved run_dir.""" + app = create_app(populated_dir, poll_interval=60) + assert hasattr(app.state, "config") + assert app.state.config.run_dir == populated_dir + + def test_endpoints_use_config_run_dir(self, populated_dir): + """All endpoints should read from config.run_dir.""" + from starlette.testclient import TestClient + app = create_app(populated_dir, poll_interval=60) + client 
= TestClient(app) + + # /api/health returns the config run_dir + r = client.get("/api/health") + assert r.status_code == 200 + assert r.json()["run_dir"] == populated_dir + + # /api/status returns data from config run_dir + r = client.get("/api/status") + assert r.status_code == 200 + assert r.json()["run_dir"] == populated_dir + + # /api/metrics/history returns data from config run_dir + r = client.get("/api/metrics/history?last_n=3") + assert r.status_code == 200 + assert len(r.json()["records"]) == 3 + + def test_config_endpoint_returns_run_dir(self, populated_dir): + """GET /api/config should include the immutable run_dir.""" + from starlette.testclient import TestClient + app = create_app(populated_dir, poll_interval=60) + client = TestClient(app) + + r = client.get("/api/config") + assert r.status_code == 200 + data = r.json() + assert data["run_dir"] == populated_dir + + def test_discover_runs_reads_only(self, populated_dir): + """/api/runs/discover should scan read-only, never mutate run_dir.""" + from starlette.testclient import TestClient + app = create_app(populated_dir, poll_interval=60) + client = TestClient(app) + + r = client.get("/api/runs/discover") + assert r.status_code == 200 + runs = r.json()["runs"] + # populated_dir has metrics, so it should be discovered + assert len(runs) >= 1 + # config.run_dir unchanged after discover + assert app.state.config.run_dir == populated_dir + + +class TestTailerNoRewire: + """Phase 5: JsonlTailer no longer has a rewire method.""" + + def test_tailer_no_rewire_method(self): + """JsonlTailer should not have a rewire() method.""" + from hotcb.server.tailer import JsonlTailer + tailer = JsonlTailer() + assert not hasattr(tailer, "rewire"), ( + "JsonlTailer.rewire() should be removed in Phase 5" + ) + + def test_tailer_still_has_diagnostics(self): + """get_cursor_offsets should still be available for diagnostics.""" + from hotcb.server.tailer import JsonlTailer + tailer = JsonlTailer() + path = "/tmp/nonexistent.jsonl" + tailer.watch("test", path) + offsets = tailer.get_cursor_offsets() + assert "test" in offsets + assert offsets["test"] == 0 + + +class TestLauncherImmutableRunDir: + """Phase 5: Launcher writes directly to run_dir, no subdirs.""" + + def test_launcher_writes_to_run_dir_directly(self): + """Launcher.start() should write JSONL files to run_dir, not subdirs.""" + import threading + + with tempfile.TemporaryDirectory() as tmpdir: + from hotcb.server.launcher import TrainingLauncher, TrainingConfig + + # Register a minimal training config + def _noop_train(run_dir, max_steps, step_delay, stop_event): + # Write one metric record and exit + with open(os.path.join(run_dir, "hotcb.metrics.jsonl"), "a") as f: + f.write(json.dumps({"step": 0, "metrics": {"loss": 0.5}}) + "\n") + + launcher = TrainingLauncher(tmpdir) + launcher.register_config(TrainingConfig( + config_id="test", + name="Test", + description="test", + train_fn=_noop_train, + defaults={"max_steps": 1, "step_delay": 0.0}, + )) + + result = launcher.start(config_id="test", max_steps=1, step_delay=0.0) + assert result.get("started") is True + assert result["run_dir"] == tmpdir # writes to run_dir directly + + # Wait for training to finish + import time + for _ in range(50): + if not launcher.running: + break + time.sleep(0.1) + + # JSONL files should be in tmpdir, not in subdirs + assert os.path.exists(os.path.join(tmpdir, "hotcb.metrics.jsonl")) + assert os.path.exists(os.path.join(tmpdir, "hotcb.run.json")) + + # No subdirs should have been created + entries = 
os.listdir(tmpdir) + subdirs = [e for e in entries if os.path.isdir(os.path.join(tmpdir, e))] + assert subdirs == [], f"No subdirs expected, found: {subdirs}" + + def test_launcher_truncates_on_restart(self): + """Starting training again should truncate JSONL files.""" + with tempfile.TemporaryDirectory() as tmpdir: + from hotcb.server.launcher import TrainingLauncher, TrainingConfig + + step_counter = [0] + + def _counting_train(run_dir, max_steps, step_delay, stop_event): + step_counter[0] += 1 + with open(os.path.join(run_dir, "hotcb.metrics.jsonl"), "a") as f: + f.write(json.dumps({ + "step": 0, "metrics": {"loss": 0.5, "run": step_counter[0]} + }) + "\n") + + launcher = TrainingLauncher(tmpdir) + launcher.register_config(TrainingConfig( + config_id="test", + name="Test", + description="test", + train_fn=_counting_train, + defaults={"max_steps": 1, "step_delay": 0.0}, + )) + + # First run: write some data + metrics_path = os.path.join(tmpdir, "hotcb.metrics.jsonl") + with open(metrics_path, "w") as f: + f.write(json.dumps({"step": 99, "metrics": {"old": True}}) + "\n") + + # Start should truncate existing data + result = launcher.start(config_id="test", max_steps=1, step_delay=0.0) + assert result.get("started") is True + + import time + for _ in range(50): + if not launcher.running: + break + time.sleep(0.1) + + # Read the metrics file — old data (step 99) should be gone + records = [] + with open(metrics_path) as f: + for line in f: + line = line.strip() + if line: + records.append(json.loads(line)) + + # Should only have the new record, not the old step=99 + assert all(r["step"] != 99 for r in records), ( + "Old data should be truncated on restart" + ) + + def test_launcher_no_active_run_dir(self): + """TrainingLauncher should not have _active_run_dir attribute.""" + with tempfile.TemporaryDirectory() as tmpdir: + from hotcb.server.launcher import TrainingLauncher + launcher = TrainingLauncher(tmpdir) + assert not hasattr(launcher, "_active_run_dir"), ( + "_active_run_dir removed in Phase 5; launcher uses _run_dir only" + ) diff --git a/src/hotcb/util.py b/src/hotcb/util.py index a25466b..6585268 100644 --- a/src/hotcb/util.py +++ b/src/hotcb/util.py @@ -11,10 +11,17 @@ class FileCursor: """ Tracks incremental read state for an append-only file (JSONL). + + ``last_size`` stores the file size at the time of the last read, + enabling proper truncation detection: if the file shrinks, the + cursor resets and the ``truncated`` flag is set so callers can + distinguish "new file" from "appended data". """ path: str offset: int = 0 + last_size: int = 0 + truncated: bool = False def ensure_dir(path: str) -> None: @@ -35,15 +42,31 @@ def safe_mtime(path: str) -> float: def read_new_jsonl(cursor: FileCursor, max_lines: int = 10_000) -> Tuple[List[dict], FileCursor]: """ Read newly appended JSONL records starting from cursor.offset. + + Handles file truncation safely: if the file shrinks (e.g. ``open(f, 'w')`` + clears it), the cursor resets to byte 0 and reads the new content. The + returned cursor has ``truncated=True`` so callers can distinguish a fresh + file from an append to an existing one. + + When the file is *overwritten* (truncated then written with fewer bytes + than before), the cursor detects the shrink via ``last_size`` and resets + to 0 — reading only the new content, not stale leftovers. 
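+
+    A minimal usage sketch (the path and the rewrite step are illustrative)::
+
+        cursor = FileCursor(path="run_dir/hotcb.metrics.jsonl")
+        records, cursor = read_new_jsonl(cursor)  # first read: whole file
+        # ...the file is truncated and rewritten with fewer bytes...
+        records, cursor = read_new_jsonl(cursor)  # cursor.truncated is True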
""" if not os.path.exists(cursor.path): return [], cursor - # Detect file truncation (reset) file_size = os.path.getsize(cursor.path) effective_offset = cursor.offset - if file_size < effective_offset: - effective_offset = 0 # file was truncated, start from beginning + was_truncated = False + + # Detect truncation: file shrank since our last read + if file_size < cursor.last_size: + effective_offset = 0 + was_truncated = True + # Also catch: cursor past EOF (file overwritten with shorter content) + elif file_size < effective_offset: + effective_offset = 0 + was_truncated = True out: List[dict] = [] with open(cursor.path, "r", encoding="utf-8") as f: @@ -55,16 +78,30 @@ def read_new_jsonl(cursor: FileCursor, max_lines: int = 10_000) -> Tuple[List[di s = line.strip() if not s: continue - out.append(json.loads(s)) + try: + out.append(json.loads(s)) + except json.JSONDecodeError: + continue new_offset = f.tell() - return out, FileCursor(path=cursor.path, offset=new_offset) + return out, FileCursor( + path=cursor.path, + offset=new_offset, + last_size=file_size, + truncated=was_truncated, + ) def append_jsonl(path: str, obj: dict) -> None: - """Append a single JSON object to a JSONL file (creates parent dirs).""" + """Append a single JSON object as a line to a JSONL file (with file locking).""" + import fcntl ensure_dir(os.path.dirname(path)) + line = json.dumps(sanitize_floats(obj), ensure_ascii=False) + "\n" with open(path, "a", encoding="utf-8") as f: - f.write(json.dumps(obj, ensure_ascii=False) + "\n") + fcntl.flock(f, fcntl.LOCK_EX) + try: + f.write(line) + finally: + fcntl.flock(f, fcntl.LOCK_UN) def dedupe_keep_order(items: Iterable[Any]) -> List[Any]: @@ -82,6 +119,20 @@ def dedupe_keep_order(items: Iterable[Any]) -> List[Any]: return out +def sanitize_floats(obj: Any) -> Any: + """Replace NaN/inf/-inf with None recursively in dicts/lists for JSON safety.""" + import math + if isinstance(obj, float): + if math.isnan(obj) or math.isinf(obj): + return None + return obj + if isinstance(obj, dict): + return {k: sanitize_floats(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [sanitize_floats(v) for v in obj] + return obj + + def now() -> float: """Epoch seconds helper.""" return time.time() diff --git a/v3_simulator_wip_plan.md b/v3_simulator_wip_plan.md deleted file mode 100644 index 429151b..0000000 --- a/v3_simulator_wip_plan.md +++ /dev/null @@ -1,418 +0,0 @@ -# hotcb 2.0 — Live Training Control Plane: Dashboard & Interactive System Plan - -## Realism Assessment: The Paradigm Shift - -**The core insight is sound.** Training today is "fire and forget" — you set hyperparameters, walk away, and check back hours/days later. The feedback loop is glacially slow. Making it interactive (3-10 interventions/hour instead of 1/3-hours) is genuinely valuable. Here's why: - -**What makes this realistic:** -- You already have the hard part built — the kernel, actuators, freeze/replay, tune engine. Most "interactive training" projects die trying to build safe hot-mutation infrastructure. You have it. -- The JSONL streaming architecture is inherently dashboard-friendly — just tail the files. -- Recipe replay means every human intervention is reproducible, which is the killer feature for papers and production. -- XGBoost projections on loss curves are computationally cheap and surprisingly accurate for short-horizon forecasting. - -**What to be honest about:** -- The LinkedIn traction issue isn't the concept — it's that callback hot-swap alone is a power-user niche. 
The dashboard + interactive tuning story is what makes it "paradigm shift" material. **Lead with the UI, not the plumbing.** -- At-scale (100B+ params, multi-node) the human-in-the-loop model breaks down — but your freeze/recipe modes already handle this correctly. Position it as: "interactive development → freeze recipe → production replay." -- Competition: W&B Sweeps, Optuna Dashboard, Ray Tune — but none of them do **live hot-swap with human override + replay**. That's your moat. - -**Marketing pivot suggestion:** Don't call it "callback management." Call it something like **"hotcb: Live Training Cockpit"** or **"Interactive Training Control Plane."** The simulation/cockpit metaphor maps perfectly to your knobs + graphs + projections vision. - ---- - -## Architecture Overview - -``` -┌─────────────────────────────────────────────────────┐ -│ Training Process │ -│ ┌──────────┐ ┌──────────┐ ┌───────────────────┐ │ -│ │ Adapter │→│ HotKernel │→│ Actuators/Modules │ │ -│ │(LT/HF) │ │ │ │(opt/loss/cb/tune) │ │ -│ └──────────┘ └────┬─────┘ └───────────────────┘ │ -│ │ JSONL streams + metrics │ -│ ▼ │ -│ ┌──────────────────────────────────────────┐ │ -│ │ MetricsCollector (new) │ │ -│ │ - intercepts env metrics at each step │ │ -│ │ - writes hotcb.metrics.jsonl │ │ -│ │ - ring buffer for feature snapshots │ │ -│ └──────────────┬───────────────────────────┘ │ -└─────────────────┼─────────────────────────────────────┘ - │ filesystem (JSONL/YAML) - ▼ -┌─────────────────────────────────────────────────────┐ -│ hotcb-server (new, separate process) │ -│ ┌─────────────┐ ┌──────────────┐ ┌────────────┐ │ -│ │ File Tailer │ │ Projection │ │ WebSocket │ │ -│ │ (metrics, │ │ Engine │ │ Server │ │ -│ │ applied, │ │ (XGBoost, │ │ (FastAPI) │ │ -│ │ mutations, │ │ manifold, │ │ │ │ -│ │ segments) │ │ feature-PCA)│ │ │ │ -│ └─────────────┘ └──────────────┘ └────────────┘ │ -│ ┌─────────────┐ ┌──────────────┐ ┌────────────┐ │ -│ │ Notification │ │ Recipe │ │ REST API │ │ -│ │ Engine │ │ Editor │ │ (commands, │ │ -│ │ (email/slack)│ │ (trim/edit/ │ │ status) │ │ -│ │ │ │ replay) │ │ │ │ -│ └─────────────┘ └──────────────┘ └────────────┘ │ -└─────────────────────┬───────────────────────────────┘ - │ HTTP + WebSocket - ▼ -┌─────────────────────────────────────────────────────┐ -│ Dashboard (React SPA, served by hotcb-server) │ -│ │ -│ ┌─ Control Bar ──────────────────────────────────┐ │ -│ │ [Mode: Engineer ▾] [Freeze ▾] [Notifications] │ │ -│ └─────────────────────────────────────────────────┘ │ -│ ┌─ Live Metrics Panel ───────────────────────────┐ │ -│ │ streaming loss/metric charts (multi-run) │ │ -│ │ + projection overlays (dashed lines) │ │ -│ │ + intervention markers (vertical lines) │ │ -│ └─────────────────────────────────────────────────┘ │ -│ ┌─ Knob Panel ──────┐ ┌─ Projection Panel ──────┐ │ -│ │ lr: [====●===] 3e-4│ │ XGBoost forecast │ │ -│ │ wd: [==●=====] 1e-2│ │ manifold plot │ │ -│ │ loss_w: [●===] 0.5 │ │ feature PCA (3D) │ │ -│ │ [Apply] [Schedule] │ │ [Lock metrics ▾] │ │ -│ └────────────────────┘ └─────────────────────────┘ │ -│ ┌─ Mutation Timeline ─────────────────────────────┐ │ -│ │ step 100: lr 3e-4→1e-3 ✓ step 200: wd +0.01 ✗│ │ -│ │ [Edit Recipe] [Export] [Replay Preview] │ │ -│ └─────────────────────────────────────────────────┘ │ -└───────────────────────────────────────────────────────┘ -``` - -**Key design decision:** The server is a **separate process** that communicates with the training process only through the filesystem (JSONL streams). 
-This means:
-- Zero coupling to training code — no new imports in training loop
-- Works with any framework adapter (Lightning, HF, raw PyTorch)
-- Dashboard can attach/detach without affecting training
-- Multiple dashboards can observe the same run
-
----
-
-## Tech Stack
-
-| Layer | Choice | Rationale |
-|-------|--------|-----------|
-| **Server** | FastAPI + uvicorn | Async WebSocket native, lightweight, Python ecosystem |
-| **File tailing** | watchdog + incremental JSONL read | Reuse existing `FileCursor` pattern from `util.py` |
-| **Frontend** | React + TypeScript | Professional look, rich ecosystem, SSR not needed |
-| **Charts** | Plotly.js (via react-plotly) | 3D support, streaming updates, publication-quality |
-| **Knobs/Controls** | Custom React + headless UI | Sliders, toggles, scheduling modals |
-| **Projections** | XGBoost (server-side) | Multivariate forecast, cheap, well-understood |
-| **Manifolds** | UMAP/t-SNE (server-side) | For metric manifolds and feature space |
-| **Feature extraction** | PyTorch hooks (opt-in) | Register forward hooks on selected layers |
-| **Notifications** | slack_sdk + smtplib | Threshold-based alerts, anomaly projections |
-| **Build** | Vite for frontend, bundled as static assets | `hotcb serve` serves the SPA |
-| **State sync** | WebSocket (server→client), REST POST (client→server→JSONL) | Unidirectional data flow |
-
----
-
-## Phase Plan
-
-### Phase 1: Foundation — Metrics Streaming + Server Skeleton (Week 1-2)
-
-**Goal:** Get a live dashboard showing streaming metrics from a running training.
-
-**New modules:**
-- `hotcb.server` — FastAPI app
-  - `hotcb.server.app` — main app, mount static, WebSocket endpoints
-  - `hotcb.server.tailer` — background task tailing JSONL files, pushing to WebSocket
-  - `hotcb.server.api` — REST endpoints for commands, status, config
-- `hotcb.metrics` — MetricsCollector
-  - Hooks into kernel.apply() to capture `env["metric"]` values
-  - Writes `hotcb.metrics.jsonl` with step, epoch, timestamp, metric_name, value
-  - Configurable metric whitelist/blacklist
-
-**CLI addition:**
-```bash
-hotcb serve --dir <run_dir> --port 8421 --host 0.0.0.0
-```
-
-**Frontend (minimal):**
-- Single page with streaming line charts (Plotly)
-- Multi-metric overlay (select which metrics to show)
-- Intervention markers from `hotcb.applied.jsonl`
-- Basic status bar (freeze mode, tune mode, active mutations)
-
-**Kernel changes:**
-- Add `MetricsCollector` to kernel (opt-in, zero overhead when unused)
-- Emit structured metric events: `{"step": N, "metrics": {"train_loss": 0.5, "val_loss": 0.6, ...}}`
-
-### Phase 2: Interactive Controls — Knobs + Commands (Week 2-3)
-
-**Goal:** Control training from the dashboard.
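-
-Mechanically, every knob reduces to the same write path: a handler validates the
-request and appends one command record for the kernel's tailer to pick up. A
-minimal sketch, assuming FastAPI and the `module`/`op`/`params` command schema
-used elsewhere in this plan (route, handler, and env-var names are illustrative):
-
-```python
-# Sketch only: endpoint shape and HOTCB_RUN_DIR wiring are assumptions.
-import json
-import os
-
-from fastapi import FastAPI
-
-app = FastAPI()
-RUN_DIR = os.environ.get("HOTCB_RUN_DIR", ".")
-
-
-@app.post("/api/opt/set")
-def set_opt(params: dict) -> dict:
-    # The server never touches the training process directly; it appends a
-    # command record that the kernel tails from hotcb.commands.jsonl.
-    cmd = {"module": "opt", "op": "set_params", "params": params}
-    with open(os.path.join(RUN_DIR, "hotcb.commands.jsonl"), "a", encoding="utf-8") as f:
-        f.write(json.dumps(cmd) + "\n")
-    return {"queued": True, "command": cmd}
-```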
- -**Server additions:** -- REST endpoints that write to `hotcb.commands.jsonl`: - - `POST /api/opt/set` — `{lr: 0.001, wd: 0.01}` - - `POST /api/loss/set` — `{recon_w: 0.5}` - - `POST /api/tune/mode` — `{mode: "active"}` - - `POST /api/cb/{id}/enable|disable` - - `POST /api/freeze` — `{mode: "prod"}` - - `POST /api/schedule` — `{at_step: 500, module: "opt", op: "set_params", params: {...}}` -- Validation endpoint: `POST /api/validate` — dry-run a mutation against actuator bounds - -**Frontend additions:** -- **Knob panel:** Sliders for each actuator parameter with: - - Current value (live from applied ledger) - - Bounds from `actuator.describe_space()` / tune recipe - - "Apply" button → writes command - - "Schedule" button → deferred application at step N -- **Quick actions:** Enable/disable tune, freeze mode toggle -- **Command history:** Live feed of applied operations with status badges - -### Phase 3: Projections — XGBoost + Metric Forecasting (Week 3-4) - -**Goal:** Show where training is heading. - -**Server additions:** -- `hotcb.server.projections` module: - - **Univariate forecast:** XGBoost trained on recent N steps of a single metric, projects K steps ahead - - **Multivariate forecast:** Given a proposed HP change, predict impact on all tracked metrics - - Train XGBoost on (step, hp_values, metric_values) → next_metric_values - - Show "what-if" overlays: "if you change lr to X, projected loss trajectory is..." - - **Confidence bands:** Bootstrap or quantile regression for uncertainty -- WebSocket channel for projection updates (recomputed on new data or HP change preview) - -**Frontend additions:** -- Dashed projection lines on metric charts with confidence bands -- "What-if" mode: drag a knob, see projected impact before committing -- "Lock metrics" selector: choose a set of metrics to project together -- Projection horizon slider (how far ahead to forecast) - -### Phase 4: Manifolds + Feature Space (Week 4-5) - -**Goal:** Visualize the loss landscape and feature space dynamics. - -**Server additions:** -- `hotcb.server.manifolds`: - - **Metric manifold:** UMAP/t-SNE on the vector of (all tracked metrics) across steps - - Shows trajectory through metric space, colored by time - - Intervention points highlighted - - **Feature space projection (opt-in):** - - Training process registers forward hooks on selected layers - - Writes activation snapshots to `hotcb.features.bin` (memory-mapped, ring buffer) - - Server reads snapshots, runs PCA→3D - - Shows how representation space evolves - -**Frontend additions:** -- 3D Plotly scatter for metric manifold (rotatable, zoomable) -- 3D feature space viewer (toggled on-demand to avoid overhead) -- Color coding: step progression, intervention markers, segment boundaries -- Side-by-side: metric manifold + loss curve, linked brushing - -**Kernel changes:** -- Optional `FeatureCapture` hook: - ```python - kernel.enable_feature_capture(model, layer_names=["encoder.layer.4"], every_n_steps=50, max_samples=256) - ``` -- Writes compressed activations (PCA pre-reduced to 64 dims in-process to save I/O) - -### Phase 5: Management — Notifications + Alerts (Week 5-6) - -**Goal:** The dashboard works for you when you're away. 
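-
-Most of this phase is delivery plumbing around a very small core: an alert is a
-predicate evaluated over the latest tailed metrics. A minimal sketch of that core
-(the `AlertRule` shape is an assumption, not a final API; channel dispatch is
-elided):
-
-```python
-# Sketch only: rule shape and message format are placeholders.
-from dataclasses import dataclass
-from typing import Callable, List
-
-
-@dataclass
-class AlertRule:
-    metric: str                         # e.g. "val_loss"
-    predicate: Callable[[float], bool]  # e.g. lambda v: v > 2.0
-    message: str
-
-
-def check_alerts(rules: List[AlertRule], latest: dict) -> List[str]:
-    """Return one formatted message per rule whose predicate fires."""
-    fired = []
-    for rule in rules:
-        value = latest.get(rule.metric)
-        if value is not None and rule.predicate(value):
-            fired.append(f"[hotcb] {rule.message} ({rule.metric}={value:.4g})")
-    return fired
-```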
- -**Server additions:** -- `hotcb.server.notifications`: - - **Threshold alerts:** "Notify me if val_loss > X" or "if projection shows divergence" - - **Anomaly detection:** Z-score on recent metric windows, flag spikes - - **Channels:** Slack webhook, email (SMTP), desktop notification (WebSocket push) - - **Suggestion toggles:** "Pause and suggest" mode — when anomaly detected, pause tune and suggest human review -- **Scheduling:** Cron-like for recurring checks - -**Frontend additions:** -- Notification panel with alert history -- Alert configuration UI (metric, threshold, channel, action) -- "Call for help" button → sends formatted Slack/email with current state snapshot + charts - -### Phase 6: Recipe Editor + Replay Dashboard (Week 6-7) - -**Goal:** Edit and replay training recipes with the same visual quality. - -**Server additions:** -- `hotcb.server.recipe_editor`: - - Load recipe JSONL, parse into timeline - - CRUD operations on recipe entries (add, remove, modify, reorder) - - Apply adjustment overlays (shift_step, replace_params, etc.) - - Validate recipe against actuator bounds - - Export edited recipe - -**Frontend additions:** -- **Timeline editor:** Visual timeline of all recipe entries - - Drag to reorder, click to edit params, right-click to delete - - "Insert intervention" at any step -- **Replay preview:** Show what the recipe would do at each step (dry-run visualization) -- **Diff view:** Compare two recipes side-by-side -- **Replay dashboard:** Same streaming charts, but replaying a previous run's recipe - - Overlay: original run metrics vs. replay run metrics - -### Phase 7: Multi-Mode UI — Engineer / Education / Vibe-Coder (Week 7-8) - -**Goal:** Three audience modes. - -**Mode definitions:** -- **Engineer mode** (default): All knobs exposed, raw metric names, actuator details, full recipe editor, CLI integration -- **Education mode:** Simplified knobs with explanations ("Learning Rate: how fast the model learns"), tooltips, guided tutorials, "what does this do?" on every control, limited to safe mutations -- **Vibe-coder mode:** AI-assisted suggestions, natural language commands ("make it learn faster"), auto-bounds from tune recipe, simplified dashboard with just key metrics + a "health score" - -**Implementation:** -- Mode stored in dashboard state + persisted to `hotcb.ui.json` -- Components conditionally render based on mode -- Education mode: wrap controls in `` components with tooltip text from a knowledge base -- Vibe-coder mode: add a chat/command bar that translates NL → hotcb commands (could use a local LLM or rule-based for v1) - -### Phase 8: Self-Mode + Community Guidelines (Week 8-9) - -**Goal:** Autonomous operation with guardrails. - -**Server additions:** -- `hotcb.server.autopilot`: - - **Rule engine:** Load community guidelines YAML (published best practices) - ```yaml - rules: - - if: "val_loss plateau > 5 epochs" - then: "reduce lr by 0.5x" - confidence: high - - if: "train_loss < val_loss * 0.5" - then: "increase weight_decay by 2x" - confidence: medium - ``` - - **Action loop:** Monitor metrics → match rules → propose or auto-apply based on confidence - - **Human-in-the-loop:** Low confidence → notify + wait for approval. High confidence → apply + notify. 
- - **Community guideline sources:** Built-in defaults + user-contributed YAML files (future: community repo) - -**Integration with tune module:** -- Self-mode uses the existing tune controller but with rule-based proposals instead of TPE -- Mutations go through the same safety checks (constraints, cooldowns, risk levels) - -### Phase 9: Benchmarking + Paper Eval (Week 9-10) - -**Goal:** Reproducible benchmarks for publication. - -**Components:** -- `hotcb.bench` module: - - **Benchmark suite:** Standard tasks (CIFAR-10, MNIST, synthetic) with defined HP search spaces - - **Comparison modes:** - - Baseline: fixed HP, no intervention - - Auto-tune: hotcb tune in active mode (no human) - - Human-interactive: hotcb with dashboard (track human decision times, quality) - - Recipe replay: reproduce best interactive run - - **Metrics collected:** final metric, time-to-target, human intervention count, compute cost - - **Export:** LaTeX tables, matplotlib figures, raw CSV - -**Production recipe-replay benchmark:** -- Compare: original training (hours of tuning) vs. recipe replay (deterministic, no search overhead) -- Show: same final quality, fraction of the compute - ---- - -## Multi-Run Support - -One thing your vision implies but isn't explicit: **multi-run comparison.** - -- Dashboard should support attaching to multiple `run_dir`s simultaneously -- Overlay metrics from different runs (different HP configs, different recipes) -- Compare: "Run A (lr=1e-3) vs Run B (lr=1e-4)" live -- This is what makes the "simulation with knobs" metaphor really land - -Implementation: server takes `--dirs run1,run2,run3` or discovers runs in a parent directory. - ---- - -## Critical Path & Dependencies - -``` -Phase 1 (metrics + server) ← everything depends on this - ├── Phase 2 (knobs) ← needs server + WebSocket - ├── Phase 3 (projections) ← needs metrics stream - │ └── Phase 5 (notifications) ← needs projections for anomaly - ├── Phase 4 (manifolds) ← needs metrics, independent of knobs - ├── Phase 6 (recipe editor) ← needs server, independent of projections - └── Phase 7 (multi-mode) ← needs all UI components to exist - └── Phase 8 (self-mode) ← needs multi-mode + projections - └── Phase 9 (benchmarks) ← needs everything working -``` - -Phases 2, 3, 4, 6 can be parallelized after Phase 1. - ---- - -## Package Structure (proposed) - -``` -src/hotcb/ -├── ... 
(existing) -├── metrics/ -│ ├── __init__.py -│ ├── collector.py # MetricsCollector, hooks into kernel -│ └── features.py # FeatureCapture (opt-in forward hooks) -├── server/ -│ ├── __init__.py -│ ├── app.py # FastAPI app, mount everything -│ ├── tailer.py # Background JSONL tailers → WebSocket -│ ├── api.py # REST endpoints (commands, status, config) -│ ├── projections.py # XGBoost forecasting, what-if -│ ├── manifolds.py # UMAP/t-SNE computation -│ ├── notifications.py # Slack/email alerts -│ ├── recipe_editor.py # Recipe CRUD + validation -│ ├── autopilot.py # Self-mode rule engine -│ └── static/ # Built React SPA assets -├── dashboard/ # React source (separate build) -│ ├── package.json -│ ├── src/ -│ │ ├── App.tsx -│ │ ├── components/ -│ │ │ ├── MetricsChart.tsx -│ │ │ ├── KnobPanel.tsx -│ │ │ ├── ProjectionOverlay.tsx -│ │ │ ├── ManifoldViewer.tsx -│ │ │ ├── RecipeTimeline.tsx -│ │ │ ├── NotificationPanel.tsx -│ │ │ └── ModeSelector.tsx -│ │ ├── hooks/ -│ │ │ ├── useWebSocket.ts -│ │ │ └── useMetrics.ts -│ │ └── stores/ -│ │ └── dashboardStore.ts # zustand -│ └── vite.config.ts -└── bench/ - ├── __init__.py - ├── tasks.py # Benchmark task definitions - ├── runner.py # Benchmark execution - └── report.py # LaTeX/CSV export -``` - -**Optional deps update for pyproject.toml:** -```toml -[project.optional-dependencies] -tune = ["optuna>=3.0", "pyyaml>=6.0"] -dashboard = ["fastapi>=0.100", "uvicorn>=0.20", "websockets>=11.0", "xgboost>=1.7", "umap-learn>=0.5"] -bench = ["matplotlib>=3.5", "pandas>=1.5"] -all = ["hotcb[tune,dashboard,bench]"] -``` - ---- - -## What Would Make This a Paper - -**Title idea:** *"From Passive to Active: Human-in-the-Loop Training Control with Live Hyperparameter Steering"* - -**Key claims to benchmark:** -1. Human-interactive tuning reaches target metric in fewer GPU-hours than grid/random/Bayesian search alone -2. Recipe replay achieves deterministic reproduction with zero search overhead -3. XGBoost projections give actionable 80%+ accuracy on short-horizon metric forecasting -4. The combined human+auto system (self-mode with human override) outperforms either alone - -**Eval plan:** -- Tasks: CIFAR-10 ResNet, GPT-2 small fine-tune, simple GAN -- Baselines: fixed HP, Optuna standalone, W&B Sweeps -- Conditions: hotcb auto-only, hotcb human-only, hotcb human+auto -- Metrics: time-to-target, final quality, total interventions, compute cost - ---- - -## Summary - -The vision is ambitious but **architecturally grounded** — the hard infrastructure (kernel, actuators, replay) already exists. The dashboard/server layer is a natural extension that reads the same JSONL streams your CLI already produces. The paradigm shift narrative is credible if you lead with the interactive experience rather than the plumbing. - -**Immediate next step when ready to implement:** Phase 1 — MetricsCollector + FastAPI server + basic streaming charts. That alone is a demo-able product.
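-
-As a concrete sense of how thin that Phase 1 slice is: the tailer-to-WebSocket
-path is a poll loop over the existing `FileCursor` helpers in `util.py`. A rough
-sketch under those assumptions (route, env var, and poll interval are
-placeholders, not a committed design):
-
-```python
-# Sketch only: everything beyond FileCursor/read_new_jsonl is assumed.
-import asyncio
-import os
-
-from fastapi import FastAPI, WebSocket
-
-from hotcb.util import FileCursor, read_new_jsonl
-
-app = FastAPI()
-
-
-@app.websocket("/ws/metrics")
-async def stream_metrics(ws: WebSocket) -> None:
-    await ws.accept()
-    run_dir = os.environ.get("HOTCB_RUN_DIR", ".")
-    cursor = FileCursor(path=os.path.join(run_dir, "hotcb.metrics.jsonl"))
-    while True:
-        records, cursor = read_new_jsonl(cursor)  # incremental, truncation-safe
-        for rec in records:
-            await ws.send_json(rec)
-        await asyncio.sleep(0.5)  # simple polling; watchdog could replace this
-```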