From 4ff544efcda7f06eb660fc2aa1da23d23a603569 Mon Sep 17 00:00:00 2001 From: Walter Simson Date: Fri, 27 Mar 2026 10:35:07 -0700 Subject: [PATCH 1/4] Add agent-first CLI for k-Wave simulations Structured JSON I/O at every step, deterministic exit codes, file-backed session model. Commands: session init/status/reset, phantom generate, sensor define, plan, run. Acceptance test: new_api_ivp_2D.py runs entirely via CLI with exact numerical parity to the Python API. Co-Authored-By: Claude Opus 4.6 (1M context) --- kwave/cli/__init__.py | 9 ++ kwave/cli/commands/__init__.py | 0 kwave/cli/commands/phantom.py | 108 +++++++++++++++ kwave/cli/commands/plan.py | 108 +++++++++++++++ kwave/cli/commands/run.py | 94 +++++++++++++ kwave/cli/commands/sensor.py | 40 ++++++ kwave/cli/commands/session_cmd.py | 39 ++++++ kwave/cli/main.py | 33 +++++ kwave/cli/schema.py | 98 ++++++++++++++ kwave/cli/session.py | 195 +++++++++++++++++++++++++++ kwave/kspaceFirstOrder.py | 3 +- kwave/solvers/kspace_solver.py | 4 +- pyproject.toml | 6 +- tests/test_cli/__init__.py | 0 tests/test_cli/test_e2e.py | 210 ++++++++++++++++++++++++++++++ uv.lock | 16 +++ 16 files changed, 960 insertions(+), 3 deletions(-) create mode 100644 kwave/cli/__init__.py create mode 100644 kwave/cli/commands/__init__.py create mode 100644 kwave/cli/commands/phantom.py create mode 100644 kwave/cli/commands/plan.py create mode 100644 kwave/cli/commands/run.py create mode 100644 kwave/cli/commands/sensor.py create mode 100644 kwave/cli/commands/session_cmd.py create mode 100644 kwave/cli/main.py create mode 100644 kwave/cli/schema.py create mode 100644 kwave/cli/session.py create mode 100644 tests/test_cli/__init__.py create mode 100644 tests/test_cli/test_e2e.py diff --git a/kwave/cli/__init__.py b/kwave/cli/__init__.py new file mode 100644 index 00000000..ad0a52d3 --- /dev/null +++ b/kwave/cli/__init__.py @@ -0,0 +1,9 @@ +def __getattr__(name): + if name == "Session": + from kwave.cli.session import Session + + return Session + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["Session"] diff --git a/kwave/cli/commands/__init__.py b/kwave/cli/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kwave/cli/commands/phantom.py b/kwave/cli/commands/phantom.py new file mode 100644 index 00000000..f1fc2c72 --- /dev/null +++ b/kwave/cli/commands/phantom.py @@ -0,0 +1,108 @@ +"""Phantom generation and loading commands.""" + +import click +import numpy as np + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIError, CLIResponse, ValidationError, json_command + + +@click.group("phantom") +def phantom(): + """Define the simulation phantom (medium + initial pressure).""" + pass + + +@phantom.command() +@click.option("--type", "phantom_type", required=True, type=click.Choice(["disc", "spherical", "layered"])) +@click.option("--grid-size", required=True, help="Grid dimensions, e.g. 128,128") +@click.option("--spacing", required=True, type=float, help="Grid spacing in meters, e.g. 0.1e-3") +@click.option("--sound-speed", type=float, default=1500, help="Medium sound speed (m/s)") +@click.option("--density", type=float, default=1000, help="Medium density (kg/m^3)") +@click.option("--disc-center", default=None, help="Disc center, e.g. 64,64") +@click.option("--disc-radius", type=int, default=5, help="Disc radius in grid points") +@pass_session +@json_command("phantom.generate") +def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_center, disc_radius): + """Generate an analytical phantom.""" + sess.load() + + grid_n = tuple(int(x) for x in grid_size.split(",")) + ndim = len(grid_n) + grid_spacing = (spacing,) * ndim + + if phantom_type == "disc": + if ndim != 2: + raise ValidationError( + CLIError( + code="DISC_REQUIRES_2D", + field="grid_size", + value=grid_size, + constraint="disc phantom requires 2D grid", + suggestion="Use --grid-size Nx,Ny (two dimensions)", + ) + ) + from kwave.data import Vector + from kwave.utils.mapgen import make_disc + + if disc_center is None: + center = Vector([n // 2 for n in grid_n]) + else: + center = Vector([int(x) for x in disc_center.split(",")]) + + p0 = make_disc(Vector(list(grid_n)), center, disc_radius).astype(float) + + elif phantom_type == "spherical": + # Simple spherical inclusion centered in grid + center = np.array([n // 2 for n in grid_n]) + coords = np.mgrid[tuple(slice(0, n) for n in grid_n)] + dist = np.sqrt(sum((c - cn) ** 2 for c, cn in zip(coords, center))) + p0 = (dist <= disc_radius).astype(float) + + elif phantom_type == "layered": + p0 = np.zeros(grid_n) + layer_pos = grid_n[0] // 4 + p0[layer_pos, ...] = 1.0 + + # Save arrays + p0_path = sess.save_array("p0", p0) + + # Update session + sess.update( + "grid", + { + "N": list(grid_n), + "spacing": list(grid_spacing), + "sound_speed_for_time": sound_speed, + }, + ) + sess.update( + "medium", + { + "sound_speed": sound_speed, + "density": density, + }, + ) + sess.update( + "source", + { + "type": "initial-pressure", + "p0_path": p0_path, + }, + ) + + return CLIResponse( + result={ + "phantom_type": phantom_type, + "grid_size": list(grid_n), + "spacing": list(grid_spacing), + "p0_shape": list(p0.shape), + "p0_max": float(p0.max()), + "sound_speed": sound_speed, + "density": density, + }, + derived={ + "ndim": ndim, + "grid_points": int(np.prod(grid_n)), + }, + ) diff --git a/kwave/cli/commands/plan.py b/kwave/cli/commands/plan.py new file mode 100644 index 00000000..59c2b630 --- /dev/null +++ b/kwave/cli/commands/plan.py @@ -0,0 +1,108 @@ +"""Plan command: derive full simulation config, validate, estimate cost.""" + +import click +import numpy as np + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIResponse, SessionError, json_command + + +@click.command("plan") +@pass_session +@json_command("plan") +def plan(sess): + """Derive full simulation config and validate before running.""" + sess.load() + + # Check completeness + comp = sess._completeness() + missing = [k for k, v in comp.items() if not v] + if missing: + raise SessionError(f"Cannot plan: missing {', '.join(missing)}. Complete setup first.") + + kgrid = sess.make_grid() + medium = sess.make_medium() + + g = sess.state["grid"] + grid_n = tuple(g["N"]) + spacing = tuple(g["spacing"]) + ndim = len(grid_n) + grid_points = int(np.prod(grid_n)) + + # Time stepping + dt = float(kgrid.dt) + Nt = int(kgrid.Nt) + + # PPW check + c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) + max_spacing = max(spacing) + # For IVP problems, use grid-based wavelength estimate + min_wavelength = c_min * dt * Nt / 2 # rough estimate + ppw = c_min / (max_spacing * (1 / (2 * dt))) # Nyquist-based PPW + + # CFL + c_max = float(np.max(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) + cfl = c_max * dt / min(spacing) + + # Memory estimate: ~(3 + 2*ndim) fields of float64 + n_fields = 3 + 2 * ndim + memory_bytes = grid_points * n_fields * 8 + memory_mb = memory_bytes / (1024 * 1024) + + # Runtime estimate + cost_per_point_step_ns = 50 # ~50ns per grid point per step on CPU + estimated_runtime_s = grid_points * Nt * cost_per_point_step_ns / 1e9 + + # PML + pml_size = 20 + + # Warnings + warnings = [] + if ppw < 4: + warnings.append( + { + "code": "LOW_PPW", + "detail": f"PPW={ppw:.1f} is below recommended minimum of 4", + "suggestion": "Increase grid resolution or reduce maximum frequency", + } + ) + if cfl > 0.5: + warnings.append( + { + "code": "HIGH_CFL", + "detail": f"CFL={cfl:.3f} exceeds 0.5, simulation may be unstable", + "suggestion": "Reduce time step or increase grid spacing", + } + ) + + result = { + "grid": { + "N": list(grid_n), + "spacing": list(spacing), + "ndim": ndim, + "dt": dt, + "Nt": Nt, + }, + "pml": {"size": pml_size}, + "medium": { + "sound_speed": c_min if c_min == c_max else f"{c_min}-{c_max}", + }, + "source": sess.state["source"], + "sensor": sess.state["sensor"], + "backend": "python", + "device": "cpu", + } + + derived = { + "ppw": round(ppw, 2), + "cfl": round(cfl, 4), + "grid_points": grid_points, + "estimated_memory_mb": round(memory_mb, 1), + "estimated_runtime_s": round(estimated_runtime_s, 1), + } + + return CLIResponse( + result=result, + derived=derived, + warnings=warnings, + ) diff --git a/kwave/cli/commands/run.py b/kwave/cli/commands/run.py new file mode 100644 index 00000000..3aa9b8f5 --- /dev/null +++ b/kwave/cli/commands/run.py @@ -0,0 +1,94 @@ +"""Run command: execute simulation with structured JSON progress.""" + +import json +import sys +import time + +import click +import numpy as np + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIResponse, SessionError, json_command + + +def _emit_event(event: dict): + """Write a JSON event to stdout and flush.""" + click.echo(json.dumps(event, default=str)) + + +@click.command("run") +@click.option("--backend", default="python", type=click.Choice(["python", "cpp"])) +@click.option("--device", default="cpu", type=click.Choice(["cpu", "gpu"])) +@pass_session +@json_command("run") +def run(sess, backend, device): + """Execute the simulation.""" + sess.load() + + # Check completeness + comp = sess._completeness() + missing = [k for k, v in comp.items() if not v] + if missing: + raise SessionError(f"Cannot run: missing {', '.join(missing)}. Complete setup first.") + + kgrid = sess.make_grid() + medium = sess.make_medium() + source = sess.make_source() + sensor = sess.make_sensor() + + Nt = int(kgrid.Nt) + + _emit_event({"event": "started", "backend": backend, "device": device, "Nt": Nt}) + + t_start = time.time() + last_pct = -5 # emit at most every 5% + + def progress_callback(step, total): + nonlocal last_pct + pct = round(100 * step / total, 1) + if pct - last_pct >= 5 or step == total: + last_pct = pct + _emit_event( + { + "event": "progress", + "step": step, + "total": total, + "pct": pct, + "elapsed_s": round(time.time() - t_start, 2), + } + ) + + from kwave.kspaceFirstOrder import kspaceFirstOrder + + result = kspaceFirstOrder( + kgrid, + medium, + source, + sensor, + backend=backend, + device=device, + quiet=True, + progress_callback=progress_callback, + ) + + elapsed = round(time.time() - t_start, 2) + + # Save results + result_info = {} + for key, val in result.items(): + if isinstance(val, np.ndarray): + path = sess.save_array(f"result_{key}", val) + result_info[key] = {"shape": list(val.shape), "path": path} + else: + result_info[key] = val + + sess.update("result_path", str(sess.data_dir)) + + _emit_event({"event": "completed", "elapsed_s": elapsed, "output_keys": list(result.keys())}) + + return CLIResponse( + result={ + "elapsed_s": elapsed, + "outputs": result_info, + }, + ) diff --git a/kwave/cli/commands/sensor.py b/kwave/cli/commands/sensor.py new file mode 100644 index 00000000..78c9f18b --- /dev/null +++ b/kwave/cli/commands/sensor.py @@ -0,0 +1,40 @@ +"""Sensor definition command.""" + +import click + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIResponse, json_command + + +@click.group("sensor") +def sensor(): + """Define sensor configuration.""" + pass + + +@sensor.command() +@click.option("--mask", required=True, help="Sensor mask: 'full-grid' or path to .npy file") +@click.option("--record", default="p,p_final", help="Comma-separated fields to record, e.g. p,p_final,ux") +@pass_session +@json_command("sensor.define") +def define(sess, mask, record): + """Define what and where to record.""" + sess.load() + + record_fields = [r.strip() for r in record.split(",")] + + sensor_config = {"record": record_fields} + if mask == "full-grid": + sensor_config["mask_type"] = "full-grid" + else: + sensor_config["mask_type"] = "file" + sensor_config["mask_path"] = mask + + sess.update("sensor", sensor_config) + + return CLIResponse( + result={ + "mask_type": sensor_config["mask_type"], + "record": record_fields, + } + ) diff --git a/kwave/cli/commands/session_cmd.py b/kwave/cli/commands/session_cmd.py new file mode 100644 index 00000000..57991cfd --- /dev/null +++ b/kwave/cli/commands/session_cmd.py @@ -0,0 +1,39 @@ +"""Session management commands.""" + +import click + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIResponse, json_command + + +@click.group("session") +def session(): + """Manage simulation session.""" + pass + + +@session.command() +@pass_session +@json_command("session.init") +def init(sess): + """Create a new session.""" + info = sess.init() + return CLIResponse(result=info) + + +@session.command() +@pass_session +@json_command("session.status") +def status(sess): + """Return full current session state.""" + sess.load() + return CLIResponse(result=sess.status()) + + +@session.command() +@pass_session +@json_command("session.reset") +def reset(sess): + """Clear session state.""" + info = sess.reset() + return CLIResponse(result=info) diff --git a/kwave/cli/main.py b/kwave/cli/main.py new file mode 100644 index 00000000..c8e4741a --- /dev/null +++ b/kwave/cli/main.py @@ -0,0 +1,33 @@ +"""Agent-first CLI for k-Wave simulations. All commands return structured JSON.""" + +from pathlib import Path +from typing import Optional + +import click + +from kwave.cli.session import Session + +pass_session = click.make_pass_decorator(Session, ensure=True) + + +@click.group() +@click.option("--session-dir", type=click.Path(), default=None, envvar="KWAVE_SESSION_DIR", help="Session directory (default: ~/.kwave)") +@click.pass_context +def cli(ctx, session_dir): + """k-Wave agent-first CLI. All commands return structured JSON.""" + base_dir = Path(session_dir) if session_dir else None + ctx.obj = Session(base_dir=base_dir) + + +# Register command groups +from kwave.cli.commands.phantom import phantom # noqa: E402 +from kwave.cli.commands.plan import plan # noqa: E402 +from kwave.cli.commands.run import run # noqa: E402 +from kwave.cli.commands.sensor import sensor # noqa: E402 +from kwave.cli.commands.session_cmd import session # noqa: E402 + +cli.add_command(session) +cli.add_command(phantom) +cli.add_command(sensor) +cli.add_command(plan) +cli.add_command(run) diff --git a/kwave/cli/schema.py b/kwave/cli/schema.py new file mode 100644 index 00000000..f9d3024b --- /dev/null +++ b/kwave/cli/schema.py @@ -0,0 +1,98 @@ +"""Response envelope and error types for the agent-first CLI.""" + +from __future__ import annotations + +import json +import sys +import traceback +from dataclasses import asdict, dataclass, field +from functools import wraps +from typing import Any, Literal + +import click + +# Exit codes +EXIT_OK = 0 +EXIT_VALIDATION = 1 +EXIT_SESSION = 2 +EXIT_SIMULATION = 3 +EXIT_IO = 4 + + +@dataclass +class CLIError: + code: str + field: str = "" + value: Any = None + constraint: str = "" + suggestion: str = "" + + +@dataclass +class CLIResponse: + status: Literal["ok", "error", "warning"] = "ok" + step: str = "" + result: dict = field(default_factory=dict) + derived: dict = field(default_factory=dict) + warnings: list[dict] = field(default_factory=list) + errors: list[dict] = field(default_factory=list) + + def to_json(self) -> str: + return json.dumps(asdict(self), indent=2, default=str) + + +def json_command(step_name: str): + """Decorator that wraps a Click command to always return the JSON envelope.""" + + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + resp = fn(*args, **kwargs) + if not isinstance(resp, CLIResponse): + resp = CLIResponse(step=step_name, result=resp or {}) + resp.step = step_name + click.echo(resp.to_json()) + sys.exit(EXIT_OK) + except click.exceptions.Exit: + raise + except SystemExit: + raise + except ValidationError as e: + resp = CLIResponse( + status="error", + step=step_name, + errors=[asdict(e.error)], + ) + click.echo(resp.to_json()) + sys.exit(EXIT_VALIDATION) + except SessionError as e: + resp = CLIResponse( + status="error", + step=step_name, + errors=[asdict(CLIError(code="SESSION_ERROR", suggestion=str(e)))], + ) + click.echo(resp.to_json()) + sys.exit(EXIT_SESSION) + except Exception as e: + resp = CLIResponse( + status="error", + step=step_name, + errors=[asdict(CLIError(code="UNEXPECTED_ERROR", suggestion=str(e)))], + ) + click.echo(resp.to_json()) + sys.exit(EXIT_SIMULATION) + + return wrapper + + return decorator + + +class ValidationError(Exception): + def __init__(self, error: CLIError): + self.error = error + super().__init__(error.code) + + +class SessionError(Exception): + pass diff --git a/kwave/cli/session.py b/kwave/cli/session.py new file mode 100644 index 00000000..c35c8a82 --- /dev/null +++ b/kwave/cli/session.py @@ -0,0 +1,195 @@ +"""Single file-backed session for the k-Wave CLI.""" + +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +import numpy as np + +from kwave.cli.schema import SessionError + +DEFAULT_SESSION_DIR = Path.home() / ".kwave" +SESSION_FILE = "session.json" + + +def _default_state() -> dict: + return { + "grid": None, + "medium": None, + "source": None, + "sensor": None, + "modality": None, + "resolution_tier": None, + "output_intent": None, + "probe": None, + "sim_options": {}, + "result_path": None, + } + + +class Session: + """Single file-backed simulation session. + + Stores parameters as JSON-serializable dicts. Array data is stored + as .npy files in the session directory. Materializer methods construct + kWave objects on demand. + """ + + def __init__(self, base_dir: Optional[Path] = None): + self.base_dir = Path(base_dir) if base_dir else DEFAULT_SESSION_DIR + self.session_file = self.base_dir / SESSION_FILE + self.data_dir = self.base_dir / "data" + self._state: Optional[dict] = None + self._id: Optional[str] = None + self._created_at: Optional[str] = None + + @property + def state(self) -> dict: + if self._state is None: + raise SessionError("No active session. Run 'kwave session init' first.") + return self._state + + @property + def id(self) -> str: + if self._id is None: + raise SessionError("No active session. Run 'kwave session init' first.") + return self._id + + def init(self) -> dict: + """Create a new session, overwriting any existing one.""" + self.base_dir.mkdir(parents=True, exist_ok=True) + self.data_dir.mkdir(parents=True, exist_ok=True) + self._id = uuid.uuid4().hex[:12] + self._created_at = datetime.now(timezone.utc).isoformat() + self._state = _default_state() + self._save() + return {"session_id": self._id, "created_at": self._created_at} + + def load(self) -> dict: + """Load the current session from disk.""" + if not self.session_file.exists(): + raise SessionError("No active session. Run 'kwave session init' first.") + raw = json.loads(self.session_file.read_text()) + self._id = raw["id"] + self._created_at = raw["created_at"] + self._state = raw["state"] + return self.status() + + def reset(self) -> dict: + """Clear session state, keep the session ID.""" + self.load() + self._state = _default_state() + # Clean up data files + if self.data_dir.exists(): + for f in self.data_dir.iterdir(): + f.unlink() + self._save() + return {"session_id": self._id, "reset": True} + + def status(self) -> dict: + """Return full current state.""" + return { + "session_id": self.id, + "created_at": self._created_at, + "state": self.state, + "completeness": self._completeness(), + } + + def update(self, key: str, value) -> None: + """Update a state field and persist.""" + self.state[key] = value + self._save() + + def save_array(self, name: str, arr: np.ndarray) -> str: + """Save an array to the session data dir, return the path.""" + self.data_dir.mkdir(parents=True, exist_ok=True) + path = self.data_dir / f"{name}.npy" + np.save(path, arr) + return str(path) + + def load_array(self, path: str) -> np.ndarray: + """Load an array from a saved path.""" + return np.load(path) + + # --- Materializers: session state -> kWave objects --- + + def make_grid(self): + """Construct a kWaveGrid from session state.""" + from kwave.kgrid import kWaveGrid + + g = self.state["grid"] + if g is None: + raise SessionError("Grid not defined. Run 'kwave phantom generate' first.") + grid_size = tuple(g["N"]) + grid_spacing = tuple(g["spacing"]) + kgrid = kWaveGrid(grid_size, grid_spacing) + if g.get("sound_speed_for_time") is not None: + kgrid.makeTime(g["sound_speed_for_time"]) + return kgrid + + def make_medium(self): + """Construct a kWaveMedium from session state.""" + from kwave.kmedium import kWaveMedium + + m = self.state["medium"] + if m is None: + raise SessionError("Medium not defined. Run 'kwave phantom generate' first.") + kwargs = {} + for field in ("sound_speed", "density", "alpha_coeff", "alpha_power", "BonA"): + if field in m and m[field] is not None: + kwargs[field] = m[field] + return kWaveMedium(**kwargs) + + def make_source(self): + """Construct a kSource from session state.""" + from kwave.ksource import kSource + + s = self.state["source"] + if s is None: + raise SessionError("Source not defined. It was auto-set by phantom generate.") + source = kSource() + if s.get("p0_path"): + source.p0 = np.load(s["p0_path"]) + return source + + def make_sensor(self): + """Construct a kSensor from session state.""" + from kwave.ksensor import kSensor + + sen = self.state["sensor"] + if sen is None: + raise SessionError("Sensor not defined. Run 'kwave sensor define' first.") + record = sen.get("record", ["p", "p_final"]) + if sen.get("mask_type") == "full-grid": + g = self.state["grid"] + mask = np.ones(tuple(g["N"]), dtype=bool) + elif sen.get("mask_path"): + mask = np.load(sen["mask_path"]) + else: + raise SessionError("Invalid sensor mask configuration.") + sensor = kSensor(mask=mask, record=record) + return sensor + + # --- Private --- + + def _completeness(self) -> dict: + """Which steps have been completed.""" + s = self.state + return { + "grid": s["grid"] is not None, + "medium": s["medium"] is not None, + "source": s["source"] is not None, + "sensor": s["sensor"] is not None, + } + + def _save(self): + raw = { + "id": self._id, + "created_at": self._created_at, + "state": self._state, + } + self.session_file.write_text(json.dumps(raw, indent=2, default=str)) diff --git a/kwave/kspaceFirstOrder.py b/kwave/kspaceFirstOrder.py index 3a4e9eac..1b672681 100644 --- a/kwave/kspaceFirstOrder.py +++ b/kwave/kspaceFirstOrder.py @@ -100,6 +100,7 @@ def kspaceFirstOrder( debug: bool = False, num_threads: Optional[int] = None, device_num: Optional[int] = None, + progress_callback=None, ) -> dict: """Run a k-Wave simulation. @@ -203,7 +204,7 @@ def kspaceFirstOrder( smooth_p0=False, pml_size=pml_size, pml_alpha=pml_alpha, - ).run() + ).run(progress_callback=progress_callback) elif backend == "cpp": from kwave.solvers.cpp_simulation import CppSimulation diff --git a/kwave/solvers/kspace_solver.py b/kwave/solvers/kspace_solver.py index b94f4bf2..245b82ca 100644 --- a/kwave/solvers/kspace_solver.py +++ b/kwave/solvers/kspace_solver.py @@ -628,12 +628,14 @@ def step(self): self.t += 1 return self - def run(self): + def run(self, progress_callback=None): """Run simulation to completion. Returns results dict.""" if not self._is_setup: self.setup() while self.t < self.Nt: self.step() + if progress_callback is not None: + progress_callback(self.t, self.Nt) # Copy to CPU one-by-one, freeing GPU memory as we go result = {} for k in list(self.sensor_data): diff --git a/pyproject.toml b/pyproject.toml index f61d7313..34fc5dfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ docs = [ "sphinx-mdinclude==0.6.2", "sphinx-toolbox==3.8.0", "furo==2024.8.6"] dev = ["pre-commit==4.5.1"] +cli = ["click>=8.0"] + +[project.scripts] +kwave = "kwave.cli.main:cli" [tool.hatch.version] path = "kwave/__init__.py" @@ -65,7 +69,7 @@ path = "kwave/__init__.py" allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["kwave", "kwave.utils", "kwave.reconstruction", "kwave.kWaveSimulation_helper"] +packages = ["kwave", "kwave.utils", "kwave.reconstruction", "kwave.kWaveSimulation_helper", "kwave.cli", "kwave.cli.commands"] [tool.hatch.build.targets.sdist] diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cli/test_e2e.py b/tests/test_cli/test_e2e.py new file mode 100644 index 00000000..7fefa901 --- /dev/null +++ b/tests/test_cli/test_e2e.py @@ -0,0 +1,210 @@ +"""End-to-end CLI test: replicates new_api_ivp_2D.py via CLI commands.""" + +import json + +import numpy as np +import pytest +from click.testing import CliRunner + +from kwave.cli.main import cli +from kwave.data import Vector +from kwave.kgrid import kWaveGrid +from kwave.kmedium import kWaveMedium +from kwave.ksensor import kSensor +from kwave.ksource import kSource +from kwave.kspaceFirstOrder import kspaceFirstOrder +from kwave.utils.mapgen import make_disc + + +def _invoke(runner, args, session_dir): + result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False) + assert result.exit_code == 0, f"Command failed: {args}\n{result.output}" + # The run command emits progress events (single-line JSON) before the final + # multi-line JSON response. Find the last complete JSON object. + output = result.output.strip() + # Try parsing the full output first (non-run commands produce a single JSON object) + try: + return json.loads(output) + except json.JSONDecodeError: + pass + # For run command: find the last '{' that starts a top-level JSON object + depth = 0 + last_start = None + for i, ch in enumerate(output): + if ch == "{" and depth == 0: + last_start = i + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if last_start is not None: + return json.loads(output[last_start:]) + raise ValueError(f"Could not parse JSON from output: {output[:200]}") + + +@pytest.fixture +def session_dir(tmp_path): + return tmp_path / "kwave_test_session" + + +@pytest.fixture +def runner(): + return CliRunner() + + +class TestSessionLifecycle: + def test_init(self, runner, session_dir): + resp = _invoke(runner, ["session", "init"], session_dir) + assert resp["status"] == "ok" + assert "session_id" in resp["result"] + + def test_status_without_init_fails(self, runner, session_dir): + result = runner.invoke(cli, ["--session-dir", str(session_dir), "session", "status"]) + assert result.exit_code != 0 + + def test_reset(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + resp = _invoke(runner, ["session", "reset"], session_dir) + assert resp["result"]["reset"] is True + + +class TestPhantomGenerate: + def test_disc_phantom(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + resp = _invoke( + runner, + [ + "phantom", + "generate", + "--type", + "disc", + "--grid-size", + "64,64", + "--spacing", + "0.1e-3", + "--sound-speed", + "1500", + "--density", + "1000", + "--disc-radius", + "5", + ], + session_dir, + ) + assert resp["status"] == "ok" + assert resp["result"]["grid_size"] == [64, 64] + assert resp["result"]["p0_max"] == 1.0 + + def test_disc_requires_2d(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + result = runner.invoke( + cli, + [ + "--session-dir", + str(session_dir), + "phantom", + "generate", + "--type", + "disc", + "--grid-size", + "64,64,64", + "--spacing", + "0.1e-3", + ], + ) + assert result.exit_code != 0 + resp = json.loads(result.output) + assert resp["errors"][0]["code"] == "DISC_REQUIRES_2D" + + +class TestSensorDefine: + def test_full_grid(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + resp = _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + assert resp["result"]["mask_type"] == "full-grid" + assert resp["result"]["record"] == ["p", "p_final"] + + +class TestPlan: + def test_plan_incomplete_session(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + result = runner.invoke(cli, ["--session-dir", str(session_dir), "plan"]) + assert result.exit_code != 0 + + def test_plan_complete_session(self, runner, session_dir): + _invoke(runner, ["session", "init"], session_dir) + _invoke( + runner, + [ + "phantom", + "generate", + "--type", + "disc", + "--grid-size", + "64,64", + "--spacing", + "0.1e-3", + "--sound-speed", + "1500", + "--density", + "1000", + ], + session_dir, + ) + _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + resp = _invoke(runner, ["plan"], session_dir) + assert resp["status"] == "ok" + assert resp["result"]["grid"]["Nt"] > 0 + assert resp["derived"]["cfl"] > 0 + + +class TestEndToEnd: + """Replicate new_api_ivp_2D.py via CLI and compare results.""" + + def test_cli_matches_python_api(self, runner, session_dir): + # Run via CLI + _invoke(runner, ["session", "init"], session_dir) + _invoke( + runner, + [ + "phantom", + "generate", + "--type", + "disc", + "--grid-size", + "128,128", + "--spacing", + "0.1e-3", + "--sound-speed", + "1500", + "--density", + "1000", + "--disc-center", + "64,64", + "--disc-radius", + "5", + ], + session_dir, + ) + _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + resp = _invoke(runner, ["run"], session_dir) + assert resp["status"] == "ok" + + # Load CLI results + cli_p = np.load(resp["result"]["outputs"]["p"]["path"]) + cli_p_final = np.load(resp["result"]["outputs"]["p_final"]["path"]) + + # Run directly via Python API (the example) + kgrid = kWaveGrid([128, 128], [0.1e-3, 0.1e-3]) + kgrid.makeTime(1500) + medium = kWaveMedium(sound_speed=1500, density=1000) + source = kSource() + source.p0 = make_disc(Vector([128, 128]), Vector([64, 64]), 5).astype(float) + sensor = kSensor(mask=np.ones((128, 128), dtype=bool)) + result = kspaceFirstOrder(kgrid, medium, source, sensor, quiet=True) + + # Compare + assert cli_p.shape == result["p"].shape + assert cli_p_final.shape == result["p_final"].shape + np.testing.assert_allclose(cli_p, result["p"], rtol=0, atol=0) + np.testing.assert_allclose(cli_p_final, result["p_final"], rtol=0, atol=0) diff --git a/uv.lock b/uv.lock index 6491700c..95d77769 100644 --- a/uv.lock +++ b/uv.lock @@ -236,6 +236,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/68/687187c7e26cb24ccbd88e5069f5ef00eba804d36dde11d99aad0838ab45/charset_normalizer-3.4.6-py3-none-any.whl", hash = "sha256:947cf925bc916d90adba35a64c82aace04fa39b46b52d4630ece166655905a69", size = 61455 }, ] +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -773,6 +785,9 @@ dependencies = [ ] [package.optional-dependencies] +cli = [ + { name = "click" }, +] dev = [ { name = "pre-commit" }, ] @@ -803,6 +818,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "beartype", specifier = "==0.22.9" }, + { name = "click", marker = "extra == 'cli'", specifier = ">=8.0" }, { name = "coverage", marker = "extra == 'test'", specifier = "==7.10.6" }, { name = "deepdiff", specifier = "==8.6.1" }, { name = "deprecated", specifier = ">=1.2.14" }, From 945a042210a9c709bfbccc474696f2ee622ef5dd Mon Sep 17 00:00:00 2001 From: Walter Simson Date: Fri, 27 Mar 2026 10:41:25 -0700 Subject: [PATCH 2/4] Add phantom load, source define, and 1D IVP support - phantom load: accepts .npy files for heterogeneous medium properties - source define: loads custom p0 from .npy file - sensor define: supports file-based masks (sparse sensors) - Grid materializer passes full sound speed array to makeTime (not just max) - Custom CFL support via --cfl flag - Test: ivp_1D_simulation.py replicated via CLI with exact parity Co-Authored-By: Claude Opus 4.6 (1M context) --- kwave/cli/commands/phantom.py | 72 +++++++++++++- kwave/cli/commands/source.py | 43 +++++++++ kwave/cli/main.py | 2 + kwave/cli/session.py | 31 +++++- tests/test_cli/test_1d_ivp.py | 176 ++++++++++++++++++++++++++++++++++ 5 files changed, 318 insertions(+), 6 deletions(-) create mode 100644 kwave/cli/commands/source.py create mode 100644 tests/test_cli/test_1d_ivp.py diff --git a/kwave/cli/commands/phantom.py b/kwave/cli/commands/phantom.py index f1fc2c72..7f8a97df 100644 --- a/kwave/cli/commands/phantom.py +++ b/kwave/cli/commands/phantom.py @@ -13,6 +13,75 @@ def phantom(): pass +@phantom.command("load") +@click.option("--grid-size", required=True, help="Grid dimensions, e.g. 512 or 128,128") +@click.option("--spacing", required=True, type=float, help="Grid spacing in meters") +@click.option("--sound-speed", required=True, help="Scalar value (m/s) or path to .npy file") +@click.option("--density", default=None, help="Scalar value (kg/m^3) or path to .npy file") +@click.option("--cfl", type=float, default=None, help="CFL number for time step calculation") +@pass_session +@json_command("phantom.load") +def load(sess, grid_size, spacing, sound_speed, density, cfl): + """Load medium properties from scalar values or .npy files.""" + sess.load() + + grid_n = tuple(int(x) for x in grid_size.split(",")) + ndim = len(grid_n) + grid_spacing = (spacing,) * ndim + + medium_state = {} + + # Sound speed: scalar or .npy path + if sound_speed.endswith(".npy"): + arr = np.load(sound_speed) + path = sess.save_array("sound_speed", arr) + medium_state["sound_speed_path"] = path + medium_state["sound_speed_scalar"] = None + # Store path so makeTime gets the full array (not just max) + sound_speed_for_time_path = path + sound_speed_for_time_scalar = None + else: + medium_state["sound_speed_scalar"] = float(sound_speed) + medium_state["sound_speed_path"] = None + sound_speed_for_time_path = None + sound_speed_for_time_scalar = float(sound_speed) + + # Density: scalar, .npy path, or None + if density is not None: + if density.endswith(".npy"): + arr = np.load(density) + path = sess.save_array("density", arr) + medium_state["density_path"] = path + medium_state["density_scalar"] = None + else: + medium_state["density_scalar"] = float(density) + medium_state["density_path"] = None + + grid_state = { + "N": list(grid_n), + "spacing": list(grid_spacing), + "sound_speed_for_time_path": sound_speed_for_time_path, + "sound_speed_for_time_scalar": sound_speed_for_time_scalar, + } + if cfl is not None: + grid_state["cfl"] = cfl + + sess.update("grid", grid_state) + sess.update("medium", medium_state) + + return CLIResponse( + result={ + "grid_size": list(grid_n), + "spacing": list(grid_spacing), + "medium": medium_state, + }, + derived={ + "ndim": ndim, + "grid_points": int(np.prod(grid_n)), + }, + ) + + @phantom.command() @click.option("--type", "phantom_type", required=True, type=click.Choice(["disc", "spherical", "layered"])) @click.option("--grid-size", required=True, help="Grid dimensions, e.g. 128,128") @@ -73,7 +142,8 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_ { "N": list(grid_n), "spacing": list(grid_spacing), - "sound_speed_for_time": sound_speed, + "sound_speed_for_time_scalar": sound_speed, + "sound_speed_for_time_path": None, }, ) sess.update( diff --git a/kwave/cli/commands/source.py b/kwave/cli/commands/source.py new file mode 100644 index 00000000..4d21e8cc --- /dev/null +++ b/kwave/cli/commands/source.py @@ -0,0 +1,43 @@ +"""Source definition command.""" + +import click +import numpy as np + +from kwave.cli.main import pass_session +from kwave.cli.schema import CLIResponse, json_command + + +@click.group("source") +def source(): + """Define simulation source.""" + pass + + +@source.command() +@click.option("--type", "source_type", required=True, type=click.Choice(["initial-pressure"])) +@click.option("--p0-file", required=True, type=click.Path(exists=True), help="Path to .npy file with initial pressure distribution") +@pass_session +@json_command("source.define") +def define(sess, source_type, p0_file): + """Define source from file.""" + sess.load() + + p0 = np.load(p0_file) + p0_path = sess.save_array("p0", p0) + + sess.update( + "source", + { + "type": source_type, + "p0_path": p0_path, + }, + ) + + return CLIResponse( + result={ + "type": source_type, + "p0_shape": list(p0.shape), + "p0_max": float(p0.max()), + "p0_min": float(p0.min()), + } + ) diff --git a/kwave/cli/main.py b/kwave/cli/main.py index c8e4741a..81266411 100644 --- a/kwave/cli/main.py +++ b/kwave/cli/main.py @@ -25,9 +25,11 @@ def cli(ctx, session_dir): from kwave.cli.commands.run import run # noqa: E402 from kwave.cli.commands.sensor import sensor # noqa: E402 from kwave.cli.commands.session_cmd import session # noqa: E402 +from kwave.cli.commands.source import source # noqa: E402 cli.add_command(session) cli.add_command(phantom) cli.add_command(sensor) +cli.add_command(source) cli.add_command(plan) cli.add_command(run) diff --git a/kwave/cli/session.py b/kwave/cli/session.py index c35c8a82..c6a1ba72 100644 --- a/kwave/cli/session.py +++ b/kwave/cli/session.py @@ -123,12 +123,24 @@ def make_grid(self): g = self.state["grid"] if g is None: - raise SessionError("Grid not defined. Run 'kwave phantom generate' first.") + raise SessionError("Grid not defined. Run 'kwave phantom generate' or 'kwave phantom load' first.") grid_size = tuple(g["N"]) grid_spacing = tuple(g["spacing"]) kgrid = kWaveGrid(grid_size, grid_spacing) - if g.get("sound_speed_for_time") is not None: - kgrid.makeTime(g["sound_speed_for_time"]) + + # Resolve sound speed for time stepping (array or scalar) + sound_speed = None + if g.get("sound_speed_for_time_path") is not None: + sound_speed = np.load(g["sound_speed_for_time_path"]) + elif g.get("sound_speed_for_time_scalar") is not None: + sound_speed = g["sound_speed_for_time_scalar"] + + if sound_speed is not None: + cfl = g.get("cfl") + if cfl is not None: + kgrid.makeTime(sound_speed, cfl=cfl) + else: + kgrid.makeTime(sound_speed) return kgrid def make_medium(self): @@ -137,10 +149,19 @@ def make_medium(self): m = self.state["medium"] if m is None: - raise SessionError("Medium not defined. Run 'kwave phantom generate' first.") + raise SessionError("Medium not defined. Run 'kwave phantom generate' or 'kwave phantom load' first.") kwargs = {} + # Handle fields that can be scalars or .npy paths for field in ("sound_speed", "density", "alpha_coeff", "alpha_power", "BonA"): - if field in m and m[field] is not None: + # Check for path-based storage (from phantom load) + path_key = f"{field}_path" + scalar_key = f"{field}_scalar" + if path_key in m and m[path_key] is not None: + kwargs[field] = np.load(m[path_key]) + elif scalar_key in m and m[scalar_key] is not None: + kwargs[field] = m[scalar_key] + # Fallback: direct scalar value (from phantom generate) + elif field in m and m[field] is not None: kwargs[field] = m[field] return kWaveMedium(**kwargs) diff --git a/tests/test_cli/test_1d_ivp.py b/tests/test_cli/test_1d_ivp.py new file mode 100644 index 00000000..6a5d9ccd --- /dev/null +++ b/tests/test_cli/test_1d_ivp.py @@ -0,0 +1,176 @@ +"""Test: replicate ivp_1D_simulation.py via CLI commands. + +Exercises: 1D grid, heterogeneous medium (array .npy files), +custom p0, sparse sensor mask, custom CFL. +""" + +import json + +import numpy as np +import pytest +from click.testing import CliRunner + +from kwave.cli.main import cli +from kwave.data import Vector +from kwave.kgrid import kWaveGrid +from kwave.kmedium import kWaveMedium +from kwave.ksensor import kSensor +from kwave.ksource import kSource +from kwave.kspaceFirstOrder import kspaceFirstOrder + + +def _invoke(runner, args, session_dir): + result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False) + assert result.exit_code == 0, f"Command failed: {args}\n{result.output}" + output = result.output.strip() + try: + return json.loads(output) + except json.JSONDecodeError: + pass + # For run command: find the last top-level JSON object + depth = 0 + last_start = None + for i, ch in enumerate(output): + if ch == "{" and depth == 0: + last_start = i + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if last_start is not None: + return json.loads(output[last_start:]) + raise ValueError(f"Could not parse JSON from output: {output[:200]}") + + +# --- Build the 1D IVP arrays (same as ivp_1D_simulation.py) --- + +Nx = 512 +dx = 0.05e-3 + + +def _make_sound_speed(): + c = 1500 * np.ones(Nx) + c[: Nx // 3] = 2000 + return c + + +def _make_density(): + rho = 1000 * np.ones(Nx) + rho[4 * Nx // 5 :] = 1500 + return rho + + +def _make_p0(): + p0 = np.zeros(Nx) + x0, width = 280, 100 + pulse = 0.5 * (np.sin(np.arange(width + 1) * np.pi / width - np.pi / 2) + 1) + p0[x0 : x0 + width + 1] = pulse + return p0 + + +def _make_sensor_mask(): + mask = np.zeros(Nx) + mask[Nx // 4] = 1 + mask[3 * Nx // 4] = 1 + return mask + + +@pytest.fixture +def session_dir(tmp_path): + return tmp_path / "kwave_test_session" + + +@pytest.fixture +def data_dir(tmp_path): + """Directory for pre-built .npy files (simulating what an agent would prepare).""" + d = tmp_path / "arrays" + d.mkdir() + np.save(d / "sound_speed.npy", _make_sound_speed()) + np.save(d / "density.npy", _make_density()) + np.save(d / "p0.npy", _make_p0()) + np.save(d / "sensor_mask.npy", _make_sensor_mask()) + return d + + +@pytest.fixture +def runner(): + return CliRunner() + + +class TestCLI1DIVP: + """Replicate ivp_1D_simulation.py end-to-end via CLI.""" + + def test_cli_matches_python_api(self, runner, session_dir, data_dir): + # -- CLI flow -- + _invoke(runner, ["session", "init"], session_dir) + + _invoke( + runner, + [ + "phantom", + "load", + "--grid-size", + "512", + "--spacing", + "0.05e-3", + "--sound-speed", + str(data_dir / "sound_speed.npy"), + "--density", + str(data_dir / "density.npy"), + "--cfl", + "0.3", + ], + session_dir, + ) + + _invoke( + runner, + [ + "source", + "define", + "--type", + "initial-pressure", + "--p0-file", + str(data_dir / "p0.npy"), + ], + session_dir, + ) + + _invoke( + runner, + [ + "sensor", + "define", + "--mask", + str(data_dir / "sensor_mask.npy"), + "--record", + "p", + ], + session_dir, + ) + + plan_resp = _invoke(runner, ["plan"], session_dir) + assert plan_resp["status"] == "ok" + assert plan_resp["result"]["grid"]["N"] == [512] + assert plan_resp["result"]["grid"]["Nt"] > 0 + + run_resp = _invoke(runner, ["run"], session_dir) + assert run_resp["status"] == "ok" + + # Load CLI results + cli_p = np.load(run_resp["result"]["outputs"]["p"]["path"]) + + # -- Direct Python API (the example) -- + sound_speed = _make_sound_speed() + density = _make_density() + kgrid = kWaveGrid(Vector([Nx]), Vector([dx])) + kgrid.makeTime(sound_speed, cfl=0.3) + medium = kWaveMedium(sound_speed=sound_speed, density=density) + source = kSource() + source.p0 = _make_p0() + sensor = kSensor(mask=_make_sensor_mask()) + result = kspaceFirstOrder(kgrid, medium, source, sensor, backend="python", quiet=True) + + # -- Compare -- + assert cli_p.shape == result["p"].shape, f"Shape mismatch: {cli_p.shape} vs {result['p'].shape}" + np.testing.assert_allclose(cli_p, result["p"], rtol=0, atol=0) From d26ab3080de8be05e2bbc2606dee3659c41a7091 Mon Sep 17 00:00:00 2001 From: Walter Simson Date: Fri, 27 Mar 2026 10:52:40 -0700 Subject: [PATCH 3/4] Simplify CLI: normalize schema, deduplicate, remove dead code - Normalize medium state to unified {field}_scalar/{field}_path schema (phantom generate now uses same format as phantom load) - Remove bogus PPW check from plan (Nyquist-based PPW was a false positive for all valid CFL values) - Remove dead code: unused min_wavelength, unused sys import, unused Optional - Add Session.assert_ready() and Session.update_many() - Rename _completeness() to completeness() (public API) - Extract _invoke helper to tests/test_cli/conftest.py - Shrink e2e test grid from 128x128 to 48x48 - Extract _parse_int_tuple and _resolve_scalar_or_path in phantom.py Co-Authored-By: Claude Opus 4.6 (1M context) --- kwave/cli/commands/phantom.py | 104 ++++++++++------------------------ kwave/cli/commands/plan.py | 45 +++------------ kwave/cli/commands/run.py | 10 +--- kwave/cli/main.py | 1 - kwave/cli/session.py | 82 +++++++++++++++------------ tests/test_cli/conftest.py | 43 ++++++++++++++ tests/test_cli/test_1d_ivp.py | 65 ++++----------------- tests/test_cli/test_e2e.py | 103 +++++++++------------------------ 8 files changed, 166 insertions(+), 287 deletions(-) create mode 100644 tests/test_cli/conftest.py diff --git a/kwave/cli/commands/phantom.py b/kwave/cli/commands/phantom.py index 7f8a97df..cafdc4c4 100644 --- a/kwave/cli/commands/phantom.py +++ b/kwave/cli/commands/phantom.py @@ -7,6 +7,19 @@ from kwave.cli.schema import CLIError, CLIResponse, ValidationError, json_command +def _parse_int_tuple(s: str) -> tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + +def _resolve_scalar_or_path(value: str, name: str, sess) -> dict: + """Parse a CLI value as scalar float or .npy path. Returns {name_scalar, name_path} dict.""" + if value.endswith(".npy"): + arr = np.load(value) + path = sess.save_array(name, arr) + return {f"{name}_scalar": None, f"{name}_path": path} + return {f"{name}_scalar": float(value), f"{name}_path": None} + + @click.group("phantom") def phantom(): """Define the simulation phantom (medium + initial pressure).""" @@ -25,60 +38,23 @@ def load(sess, grid_size, spacing, sound_speed, density, cfl): """Load medium properties from scalar values or .npy files.""" sess.load() - grid_n = tuple(int(x) for x in grid_size.split(",")) + grid_n = _parse_int_tuple(grid_size) ndim = len(grid_n) grid_spacing = (spacing,) * ndim - medium_state = {} - - # Sound speed: scalar or .npy path - if sound_speed.endswith(".npy"): - arr = np.load(sound_speed) - path = sess.save_array("sound_speed", arr) - medium_state["sound_speed_path"] = path - medium_state["sound_speed_scalar"] = None - # Store path so makeTime gets the full array (not just max) - sound_speed_for_time_path = path - sound_speed_for_time_scalar = None - else: - medium_state["sound_speed_scalar"] = float(sound_speed) - medium_state["sound_speed_path"] = None - sound_speed_for_time_path = None - sound_speed_for_time_scalar = float(sound_speed) - - # Density: scalar, .npy path, or None + medium_state = _resolve_scalar_or_path(sound_speed, "sound_speed", sess) if density is not None: - if density.endswith(".npy"): - arr = np.load(density) - path = sess.save_array("density", arr) - medium_state["density_path"] = path - medium_state["density_scalar"] = None - else: - medium_state["density_scalar"] = float(density) - medium_state["density_path"] = None - - grid_state = { - "N": list(grid_n), - "spacing": list(grid_spacing), - "sound_speed_for_time_path": sound_speed_for_time_path, - "sound_speed_for_time_scalar": sound_speed_for_time_scalar, - } + medium_state.update(_resolve_scalar_or_path(density, "density", sess)) + + grid_state = {"N": list(grid_n), "spacing": list(grid_spacing)} if cfl is not None: grid_state["cfl"] = cfl - sess.update("grid", grid_state) - sess.update("medium", medium_state) + sess.update_many({"grid": grid_state, "medium": medium_state}) return CLIResponse( - result={ - "grid_size": list(grid_n), - "spacing": list(grid_spacing), - "medium": medium_state, - }, - derived={ - "ndim": ndim, - "grid_points": int(np.prod(grid_n)), - }, + result={"grid_size": list(grid_n), "spacing": list(grid_spacing), "medium": medium_state}, + derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))}, ) @@ -96,7 +72,7 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_ """Generate an analytical phantom.""" sess.load() - grid_n = tuple(int(x) for x in grid_size.split(",")) + grid_n = _parse_int_tuple(grid_size) ndim = len(grid_n) grid_spacing = (spacing,) * ndim @@ -117,12 +93,11 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_ if disc_center is None: center = Vector([n // 2 for n in grid_n]) else: - center = Vector([int(x) for x in disc_center.split(",")]) + center = Vector(_parse_int_tuple(disc_center)) p0 = make_disc(Vector(list(grid_n)), center, disc_radius).astype(float) elif phantom_type == "spherical": - # Simple spherical inclusion centered in grid center = np.array([n // 2 for n in grid_n]) coords = np.mgrid[tuple(slice(0, n) for n in grid_n)] dist = np.sqrt(sum((c - cn) ** 2 for c, cn in zip(coords, center))) @@ -133,32 +108,14 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_ layer_pos = grid_n[0] // 4 p0[layer_pos, ...] = 1.0 - # Save arrays p0_path = sess.save_array("p0", p0) - # Update session - sess.update( - "grid", + sess.update_many( { - "N": list(grid_n), - "spacing": list(grid_spacing), - "sound_speed_for_time_scalar": sound_speed, - "sound_speed_for_time_path": None, - }, - ) - sess.update( - "medium", - { - "sound_speed": sound_speed, - "density": density, - }, - ) - sess.update( - "source", - { - "type": "initial-pressure", - "p0_path": p0_path, - }, + "grid": {"N": list(grid_n), "spacing": list(grid_spacing)}, + "medium": {"sound_speed_scalar": sound_speed, "sound_speed_path": None, "density_scalar": density, "density_path": None}, + "source": {"type": "initial-pressure", "p0_path": p0_path}, + } ) return CLIResponse( @@ -171,8 +128,5 @@ def generate(sess, phantom_type, grid_size, spacing, sound_speed, density, disc_ "sound_speed": sound_speed, "density": density, }, - derived={ - "ndim": ndim, - "grid_points": int(np.prod(grid_n)), - }, + derived={"ndim": ndim, "grid_points": int(np.prod(grid_n))}, ) diff --git a/kwave/cli/commands/plan.py b/kwave/cli/commands/plan.py index 59c2b630..8b6f7543 100644 --- a/kwave/cli/commands/plan.py +++ b/kwave/cli/commands/plan.py @@ -4,7 +4,7 @@ import numpy as np from kwave.cli.main import pass_session -from kwave.cli.schema import CLIResponse, SessionError, json_command +from kwave.cli.schema import CLIResponse, json_command @click.command("plan") @@ -13,59 +13,29 @@ def plan(sess): """Derive full simulation config and validate before running.""" sess.load() - - # Check completeness - comp = sess._completeness() - missing = [k for k, v in comp.items() if not v] - if missing: - raise SessionError(f"Cannot plan: missing {', '.join(missing)}. Complete setup first.") + sess.assert_ready("plan") kgrid = sess.make_grid() medium = sess.make_medium() - g = sess.state["grid"] - grid_n = tuple(g["N"]) - spacing = tuple(g["spacing"]) + grid_n = tuple(int(n) for n in kgrid.N) + spacing = tuple(float(d) for d in kgrid.spacing) ndim = len(grid_n) grid_points = int(np.prod(grid_n)) - - # Time stepping dt = float(kgrid.dt) Nt = int(kgrid.Nt) - # PPW check - c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) - max_spacing = max(spacing) - # For IVP problems, use grid-based wavelength estimate - min_wavelength = c_min * dt * Nt / 2 # rough estimate - ppw = c_min / (max_spacing * (1 / (2 * dt))) # Nyquist-based PPW - - # CFL c_max = float(np.max(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) + c_min = float(np.min(medium.sound_speed)) if hasattr(medium.sound_speed, "__len__") else float(medium.sound_speed) cfl = c_max * dt / min(spacing) - # Memory estimate: ~(3 + 2*ndim) fields of float64 n_fields = 3 + 2 * ndim - memory_bytes = grid_points * n_fields * 8 - memory_mb = memory_bytes / (1024 * 1024) + memory_mb = grid_points * n_fields * 8 / (1024 * 1024) + estimated_runtime_s = grid_points * Nt * 50e-9 # ~50ns per grid point per step on CPU - # Runtime estimate - cost_per_point_step_ns = 50 # ~50ns per grid point per step on CPU - estimated_runtime_s = grid_points * Nt * cost_per_point_step_ns / 1e9 - - # PML pml_size = 20 - # Warnings warnings = [] - if ppw < 4: - warnings.append( - { - "code": "LOW_PPW", - "detail": f"PPW={ppw:.1f} is below recommended minimum of 4", - "suggestion": "Increase grid resolution or reduce maximum frequency", - } - ) if cfl > 0.5: warnings.append( { @@ -94,7 +64,6 @@ def plan(sess): } derived = { - "ppw": round(ppw, 2), "cfl": round(cfl, 4), "grid_points": grid_points, "estimated_memory_mb": round(memory_mb, 1), diff --git a/kwave/cli/commands/run.py b/kwave/cli/commands/run.py index 3aa9b8f5..c19f2fc5 100644 --- a/kwave/cli/commands/run.py +++ b/kwave/cli/commands/run.py @@ -1,14 +1,13 @@ """Run command: execute simulation with structured JSON progress.""" import json -import sys import time import click import numpy as np from kwave.cli.main import pass_session -from kwave.cli.schema import CLIResponse, SessionError, json_command +from kwave.cli.schema import CLIResponse, json_command def _emit_event(event: dict): @@ -24,12 +23,7 @@ def _emit_event(event: dict): def run(sess, backend, device): """Execute the simulation.""" sess.load() - - # Check completeness - comp = sess._completeness() - missing = [k for k, v in comp.items() if not v] - if missing: - raise SessionError(f"Cannot run: missing {', '.join(missing)}. Complete setup first.") + sess.assert_ready("run") kgrid = sess.make_grid() medium = sess.make_medium() diff --git a/kwave/cli/main.py b/kwave/cli/main.py index 81266411..0a924d20 100644 --- a/kwave/cli/main.py +++ b/kwave/cli/main.py @@ -1,7 +1,6 @@ """Agent-first CLI for k-Wave simulations. All commands return structured JSON.""" from pathlib import Path -from typing import Optional import click diff --git a/kwave/cli/session.py b/kwave/cli/session.py index c6a1ba72..b13a6e20 100644 --- a/kwave/cli/session.py +++ b/kwave/cli/session.py @@ -83,7 +83,6 @@ def reset(self) -> dict: """Clear session state, keep the session ID.""" self.load() self._state = _default_state() - # Clean up data files if self.data_dir.exists(): for f in self.data_dir.iterdir(): f.unlink() @@ -96,14 +95,36 @@ def status(self) -> dict: "session_id": self.id, "created_at": self._created_at, "state": self.state, - "completeness": self._completeness(), + "completeness": self.completeness(), } + def completeness(self) -> dict: + """Which steps have been completed.""" + s = self.state + return { + "grid": s["grid"] is not None, + "medium": s["medium"] is not None, + "source": s["source"] is not None, + "sensor": s["sensor"] is not None, + } + + def assert_ready(self, verb: str = "proceed") -> None: + """Raise SessionError if any required step is incomplete.""" + missing = [k for k, v in self.completeness().items() if not v] + if missing: + raise SessionError(f"Cannot {verb}: missing {', '.join(missing)}. Complete setup first.") + def update(self, key: str, value) -> None: """Update a state field and persist.""" self.state[key] = value self._save() + def update_many(self, updates: dict) -> None: + """Update multiple state fields and persist once.""" + for key, value in updates.items(): + self.state[key] = value + self._save() + def save_array(self, name: str, arr: np.ndarray) -> str: """Save an array to the session data dir, return the path.""" self.data_dir.mkdir(parents=True, exist_ok=True) @@ -117,6 +138,16 @@ def load_array(self, path: str) -> np.ndarray: # --- Materializers: session state -> kWave objects --- + def _resolve_field(self, state: dict, field: str): + """Load a medium field from session state (path or scalar).""" + path = state.get(f"{field}_path") + if path is not None: + return np.load(path) + scalar = state.get(f"{field}_scalar") + if scalar is not None: + return scalar + return None + def make_grid(self): """Construct a kWaveGrid from session state.""" from kwave.kgrid import kWaveGrid @@ -128,19 +159,16 @@ def make_grid(self): grid_spacing = tuple(g["spacing"]) kgrid = kWaveGrid(grid_size, grid_spacing) - # Resolve sound speed for time stepping (array or scalar) - sound_speed = None - if g.get("sound_speed_for_time_path") is not None: - sound_speed = np.load(g["sound_speed_for_time_path"]) - elif g.get("sound_speed_for_time_scalar") is not None: - sound_speed = g["sound_speed_for_time_scalar"] - - if sound_speed is not None: - cfl = g.get("cfl") - if cfl is not None: - kgrid.makeTime(sound_speed, cfl=cfl) - else: - kgrid.makeTime(sound_speed) + # Resolve sound speed for time stepping from medium state + m = self.state.get("medium") + if m is not None: + sound_speed = self._resolve_field(m, "sound_speed") + if sound_speed is not None: + cfl = g.get("cfl") + if cfl is not None: + kgrid.makeTime(sound_speed, cfl=cfl) + else: + kgrid.makeTime(sound_speed) return kgrid def make_medium(self): @@ -151,18 +179,10 @@ def make_medium(self): if m is None: raise SessionError("Medium not defined. Run 'kwave phantom generate' or 'kwave phantom load' first.") kwargs = {} - # Handle fields that can be scalars or .npy paths for field in ("sound_speed", "density", "alpha_coeff", "alpha_power", "BonA"): - # Check for path-based storage (from phantom load) - path_key = f"{field}_path" - scalar_key = f"{field}_scalar" - if path_key in m and m[path_key] is not None: - kwargs[field] = np.load(m[path_key]) - elif scalar_key in m and m[scalar_key] is not None: - kwargs[field] = m[scalar_key] - # Fallback: direct scalar value (from phantom generate) - elif field in m and m[field] is not None: - kwargs[field] = m[field] + val = self._resolve_field(m, field) + if val is not None: + kwargs[field] = val return kWaveMedium(**kwargs) def make_source(self): @@ -197,16 +217,6 @@ def make_sensor(self): # --- Private --- - def _completeness(self) -> dict: - """Which steps have been completed.""" - s = self.state - return { - "grid": s["grid"] is not None, - "medium": s["medium"] is not None, - "source": s["source"] is not None, - "sensor": s["sensor"] is not None, - } - def _save(self): raw = { "id": self._id, diff --git a/tests/test_cli/conftest.py b/tests/test_cli/conftest.py new file mode 100644 index 00000000..d159e678 --- /dev/null +++ b/tests/test_cli/conftest.py @@ -0,0 +1,43 @@ +"""Shared fixtures and helpers for CLI tests.""" + +import json + +import pytest +from click.testing import CliRunner + +from kwave.cli.main import cli + + +def invoke(runner, args, session_dir): + """Invoke a CLI command and parse the final JSON response.""" + result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False) + assert result.exit_code == 0, f"Command failed: {args}\n{result.output}" + output = result.output.strip() + try: + return json.loads(output) + except json.JSONDecodeError: + pass + # For run command: progress events precede the final JSON response. + # Find the last top-level JSON object. + depth = 0 + last_start = None + for i, ch in enumerate(output): + if ch == "{" and depth == 0: + last_start = i + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if last_start is not None: + return json.loads(output[last_start:]) + raise ValueError(f"Could not parse JSON from output: {output[:200]}") + + +@pytest.fixture +def session_dir(tmp_path): + return tmp_path / "kwave_test_session" + + +@pytest.fixture +def runner(): + return CliRunner() diff --git a/tests/test_cli/test_1d_ivp.py b/tests/test_cli/test_1d_ivp.py index 6a5d9ccd..eb6415f6 100644 --- a/tests/test_cli/test_1d_ivp.py +++ b/tests/test_cli/test_1d_ivp.py @@ -4,45 +4,16 @@ custom p0, sparse sensor mask, custom CFL. """ -import json - import numpy as np import pytest -from click.testing import CliRunner -from kwave.cli.main import cli from kwave.data import Vector from kwave.kgrid import kWaveGrid from kwave.kmedium import kWaveMedium from kwave.ksensor import kSensor from kwave.ksource import kSource from kwave.kspaceFirstOrder import kspaceFirstOrder - - -def _invoke(runner, args, session_dir): - result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False) - assert result.exit_code == 0, f"Command failed: {args}\n{result.output}" - output = result.output.strip() - try: - return json.loads(output) - except json.JSONDecodeError: - pass - # For run command: find the last top-level JSON object - depth = 0 - last_start = None - for i, ch in enumerate(output): - if ch == "{" and depth == 0: - last_start = i - if ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if last_start is not None: - return json.loads(output[last_start:]) - raise ValueError(f"Could not parse JSON from output: {output[:200]}") - - -# --- Build the 1D IVP arrays (same as ivp_1D_simulation.py) --- +from tests.test_cli.conftest import invoke Nx = 512 dx = 0.05e-3 @@ -75,14 +46,9 @@ def _make_sensor_mask(): return mask -@pytest.fixture -def session_dir(tmp_path): - return tmp_path / "kwave_test_session" - - @pytest.fixture def data_dir(tmp_path): - """Directory for pre-built .npy files (simulating what an agent would prepare).""" + """Directory with pre-built .npy files (simulating agent-prepared arrays).""" d = tmp_path / "arrays" d.mkdir() np.save(d / "sound_speed.npy", _make_sound_speed()) @@ -92,19 +58,13 @@ def data_dir(tmp_path): return d -@pytest.fixture -def runner(): - return CliRunner() - - class TestCLI1DIVP: """Replicate ivp_1D_simulation.py end-to-end via CLI.""" def test_cli_matches_python_api(self, runner, session_dir, data_dir): - # -- CLI flow -- - _invoke(runner, ["session", "init"], session_dir) + invoke(runner, ["session", "init"], session_dir) - _invoke( + invoke( runner, [ "phantom", @@ -123,7 +83,7 @@ def test_cli_matches_python_api(self, runner, session_dir, data_dir): session_dir, ) - _invoke( + invoke( runner, [ "source", @@ -136,7 +96,7 @@ def test_cli_matches_python_api(self, runner, session_dir, data_dir): session_dir, ) - _invoke( + invoke( runner, [ "sensor", @@ -149,28 +109,25 @@ def test_cli_matches_python_api(self, runner, session_dir, data_dir): session_dir, ) - plan_resp = _invoke(runner, ["plan"], session_dir) + plan_resp = invoke(runner, ["plan"], session_dir) assert plan_resp["status"] == "ok" assert plan_resp["result"]["grid"]["N"] == [512] assert plan_resp["result"]["grid"]["Nt"] > 0 - run_resp = _invoke(runner, ["run"], session_dir) + run_resp = invoke(runner, ["run"], session_dir) assert run_resp["status"] == "ok" - # Load CLI results cli_p = np.load(run_resp["result"]["outputs"]["p"]["path"]) - # -- Direct Python API (the example) -- + # Direct Python API sound_speed = _make_sound_speed() - density = _make_density() kgrid = kWaveGrid(Vector([Nx]), Vector([dx])) kgrid.makeTime(sound_speed, cfl=0.3) - medium = kWaveMedium(sound_speed=sound_speed, density=density) + medium = kWaveMedium(sound_speed=sound_speed, density=_make_density()) source = kSource() source.p0 = _make_p0() sensor = kSensor(mask=_make_sensor_mask()) result = kspaceFirstOrder(kgrid, medium, source, sensor, backend="python", quiet=True) - # -- Compare -- - assert cli_p.shape == result["p"].shape, f"Shape mismatch: {cli_p.shape} vs {result['p'].shape}" + assert cli_p.shape == result["p"].shape np.testing.assert_allclose(cli_p, result["p"], rtol=0, atol=0) diff --git a/tests/test_cli/test_e2e.py b/tests/test_cli/test_e2e.py index 7fefa901..99402c72 100644 --- a/tests/test_cli/test_e2e.py +++ b/tests/test_cli/test_e2e.py @@ -14,47 +14,12 @@ from kwave.ksource import kSource from kwave.kspaceFirstOrder import kspaceFirstOrder from kwave.utils.mapgen import make_disc - - -def _invoke(runner, args, session_dir): - result = runner.invoke(cli, ["--session-dir", str(session_dir)] + args, catch_exceptions=False) - assert result.exit_code == 0, f"Command failed: {args}\n{result.output}" - # The run command emits progress events (single-line JSON) before the final - # multi-line JSON response. Find the last complete JSON object. - output = result.output.strip() - # Try parsing the full output first (non-run commands produce a single JSON object) - try: - return json.loads(output) - except json.JSONDecodeError: - pass - # For run command: find the last '{' that starts a top-level JSON object - depth = 0 - last_start = None - for i, ch in enumerate(output): - if ch == "{" and depth == 0: - last_start = i - if ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if last_start is not None: - return json.loads(output[last_start:]) - raise ValueError(f"Could not parse JSON from output: {output[:200]}") - - -@pytest.fixture -def session_dir(tmp_path): - return tmp_path / "kwave_test_session" - - -@pytest.fixture -def runner(): - return CliRunner() +from tests.test_cli.conftest import invoke class TestSessionLifecycle: def test_init(self, runner, session_dir): - resp = _invoke(runner, ["session", "init"], session_dir) + resp = invoke(runner, ["session", "init"], session_dir) assert resp["status"] == "ok" assert "session_id" in resp["result"] @@ -63,15 +28,15 @@ def test_status_without_init_fails(self, runner, session_dir): assert result.exit_code != 0 def test_reset(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) - resp = _invoke(runner, ["session", "reset"], session_dir) + invoke(runner, ["session", "init"], session_dir) + resp = invoke(runner, ["session", "reset"], session_dir) assert resp["result"]["reset"] is True class TestPhantomGenerate: def test_disc_phantom(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) - resp = _invoke( + invoke(runner, ["session", "init"], session_dir) + resp = invoke( runner, [ "phantom", @@ -96,21 +61,10 @@ def test_disc_phantom(self, runner, session_dir): assert resp["result"]["p0_max"] == 1.0 def test_disc_requires_2d(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) + invoke(runner, ["session", "init"], session_dir) result = runner.invoke( cli, - [ - "--session-dir", - str(session_dir), - "phantom", - "generate", - "--type", - "disc", - "--grid-size", - "64,64,64", - "--spacing", - "0.1e-3", - ], + ["--session-dir", str(session_dir), "phantom", "generate", "--type", "disc", "--grid-size", "64,64,64", "--spacing", "0.1e-3"], ) assert result.exit_code != 0 resp = json.loads(result.output) @@ -119,21 +73,21 @@ def test_disc_requires_2d(self, runner, session_dir): class TestSensorDefine: def test_full_grid(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) - resp = _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + invoke(runner, ["session", "init"], session_dir) + resp = invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) assert resp["result"]["mask_type"] == "full-grid" assert resp["result"]["record"] == ["p", "p_final"] class TestPlan: def test_plan_incomplete_session(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) + invoke(runner, ["session", "init"], session_dir) result = runner.invoke(cli, ["--session-dir", str(session_dir), "plan"]) assert result.exit_code != 0 def test_plan_complete_session(self, runner, session_dir): - _invoke(runner, ["session", "init"], session_dir) - _invoke( + invoke(runner, ["session", "init"], session_dir) + invoke( runner, [ "phantom", @@ -151,8 +105,8 @@ def test_plan_complete_session(self, runner, session_dir): ], session_dir, ) - _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) - resp = _invoke(runner, ["plan"], session_dir) + invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + resp = invoke(runner, ["plan"], session_dir) assert resp["status"] == "ok" assert resp["result"]["grid"]["Nt"] > 0 assert resp["derived"]["cfl"] > 0 @@ -161,10 +115,12 @@ def test_plan_complete_session(self, runner, session_dir): class TestEndToEnd: """Replicate new_api_ivp_2D.py via CLI and compare results.""" + N = 48 # small grid for fast CI (must be > 2*pml_size=40) + def test_cli_matches_python_api(self, runner, session_dir): - # Run via CLI - _invoke(runner, ["session", "init"], session_dir) - _invoke( + N = self.N + invoke(runner, ["session", "init"], session_dir) + invoke( runner, [ "phantom", @@ -172,7 +128,7 @@ def test_cli_matches_python_api(self, runner, session_dir): "--type", "disc", "--grid-size", - "128,128", + f"{N},{N}", "--spacing", "0.1e-3", "--sound-speed", @@ -180,30 +136,27 @@ def test_cli_matches_python_api(self, runner, session_dir): "--density", "1000", "--disc-center", - "64,64", + f"{N // 2},{N // 2}", "--disc-radius", - "5", + "3", ], session_dir, ) - _invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) - resp = _invoke(runner, ["run"], session_dir) + invoke(runner, ["sensor", "define", "--mask", "full-grid", "--record", "p,p_final"], session_dir) + resp = invoke(runner, ["run"], session_dir) assert resp["status"] == "ok" - # Load CLI results cli_p = np.load(resp["result"]["outputs"]["p"]["path"]) cli_p_final = np.load(resp["result"]["outputs"]["p_final"]["path"]) - # Run directly via Python API (the example) - kgrid = kWaveGrid([128, 128], [0.1e-3, 0.1e-3]) + kgrid = kWaveGrid([N, N], [0.1e-3, 0.1e-3]) kgrid.makeTime(1500) medium = kWaveMedium(sound_speed=1500, density=1000) source = kSource() - source.p0 = make_disc(Vector([128, 128]), Vector([64, 64]), 5).astype(float) - sensor = kSensor(mask=np.ones((128, 128), dtype=bool)) + source.p0 = make_disc(Vector([N, N]), Vector([N // 2, N // 2]), 3).astype(float) + sensor = kSensor(mask=np.ones((N, N), dtype=bool)) result = kspaceFirstOrder(kgrid, medium, source, sensor, quiet=True) - # Compare assert cli_p.shape == result["p"].shape assert cli_p_final.shape == result["p_final"].shape np.testing.assert_allclose(cli_p, result["p"], rtol=0, atol=0) From 9ec566e47ef7ef249ad66588d1e94de4c597187a Mon Sep 17 00:00:00 2001 From: Walter Simson Date: Fri, 27 Mar 2026 10:57:24 -0700 Subject: [PATCH 4/4] Skip CLI tests when click is not installed Uses pytest.importorskip so test collection succeeds in CI environments that don't install the cli extra. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_cli/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_cli/conftest.py b/tests/test_cli/conftest.py index d159e678..3b4d01b4 100644 --- a/tests/test_cli/conftest.py +++ b/tests/test_cli/conftest.py @@ -3,9 +3,11 @@ import json import pytest -from click.testing import CliRunner -from kwave.cli.main import cli +click = pytest.importorskip("click", reason="CLI tests require click: pip install k-wave-python[cli]") +from click.testing import CliRunner # noqa: E402 + +from kwave.cli.main import cli # noqa: E402 def invoke(runner, args, session_dir):