Skip to content
117 changes: 109 additions & 8 deletions src/electrai/entrypoints/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from types import SimpleNamespace

Expand All @@ -10,6 +12,47 @@

from electrai.lightning import LightningGenerator

logger = logging.getLogger(__name__)


def _resolve_checkpoint(cfg) -> Path:
    """Find the best available checkpoint from config.

    Resolution order:
      1. cfg.ckpt_file — explicit path to a specific .ckpt file
      2. cfg.ckpt_path itself, when it points at a file
      3. cfg.ckpt_path / "last.ckpt"
      4. cfg.ckpt_path / "best.ckpt"
      5. Latest ckpt_*.ckpt in cfg.ckpt_path, chosen by natural ordering so
         that embedded numbers compare numerically (ckpt_10 > ckpt_9)

    Args:
        cfg: Config object. ``ckpt_file`` and ``ckpt_path`` are read with
            ``getattr``; ``ckpt_path`` defaults to ``"./checkpoints"``.

    Returns:
        Path of the checkpoint to load.

    Raises:
        FileNotFoundError: If ``ckpt_file`` is set but does not exist, or if
            no checkpoint can be found under ``ckpt_path``.
    """
    import re  # local import: only needed for the numbered-checkpoint fallback

    def _natural_key(path):
        # re.split with a capture group alternates text/digit tokens, so odd
        # positions are always digit runs. Comparing those as ints avoids the
        # lexicographic trap where "ckpt_9" sorts after "ckpt_10". The raw
        # name breaks ties (e.g. "ckpt_09" vs "ckpt_9") deterministically.
        tokens = re.split(r"(\d+)", path.name)
        numeric_aware = [
            int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)
        ]
        return numeric_aware, path.name

    ckpt_file = getattr(cfg, "ckpt_file", None)
    if ckpt_file is not None:
        # An explicit file is authoritative: never fall back silently.
        ckpt = Path(ckpt_file)
        if ckpt.exists():
            return ckpt
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))

    # If ckpt_path is itself a file, use it directly
    if ckpt_path.is_file():
        return ckpt_path

    for name in ("last.ckpt", "best.ckpt"):
        candidate = ckpt_path / name
        if candidate.exists():
            return candidate

    # Glob for ckpt_*.ckpt and pick the latest by numeric-aware ordering
    candidates = list(ckpt_path.glob("ckpt_*.ckpt"))
    if candidates:
        return max(candidates, key=_natural_key)

    raise FileNotFoundError(
        f"No checkpoint found in {ckpt_path}. "
        "Set ckpt_file to an explicit path, or ensure ckpt_path contains "
        "last.ckpt, best.ckpt, or ckpt_*.ckpt files."
    )


def test(args):
# -----------------------------
Expand All @@ -29,12 +72,22 @@ def test(args):
# Model (LightningModule handles architecture + loss + optimizer)
# -----------------------------
lit_model = LightningGenerator(cfg)
lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)

# -----------------------------
# Callback
# W&B (optional)
# -----------------------------
ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
wandb_mode = getattr(cfg, "wandb_mode", "disabled").lower()
os.environ["WANDB_MODE"] = wandb_mode
if wandb_mode != "disabled":
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
project=getattr(cfg, "wb_pname", "electrai"),
entity=getattr(cfg, "entity", None),
config=vars(cfg),
)
else:
wandb_logger = None

# -----------------------------
# Trainer
Expand All @@ -49,7 +102,7 @@ def test(args):
for directory in [log_dir, tmp_dir]:
directory.mkdir(exist_ok=True, parents=True)
trainer = Trainer(
logger=None,
logger=wandb_logger,
callbacks=None,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=1,
Expand All @@ -61,10 +114,58 @@ def test(args):
)

# -----------------------------
# Train
# Resolve checkpoint and run test
# -----------------------------
ckpt = ckpt_path / "last.ckpt"
if not ckpt.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
ckpt = _resolve_checkpoint(cfg)
logger.info("Using checkpoint: %s", ckpt)

trainer.test(model=lit_model, datamodule=datamodule, ckpt_path=ckpt)

# -----------------------------
# Post-test analysis
# -----------------------------
metrics_csv = log_dir / "metrics.csv"
if metrics_csv.exists():
from electrai.scripts.analyze.summarize import plot_distribution, summarize

summary_text = summarize(metrics_csv, output_dir=log_dir)
logger.info("\n%s", summary_text)
plot_distribution(metrics_csv, output_dir=log_dir)

if wandb_logger is not None:
from electrai.scripts.analyze.summarize import log_to_wandb

log_to_wandb(metrics_csv, output_dir=log_dir)

# Optional: saturation analysis (always possible with enriched CSV)
analyze_cfg = getattr(cfg, "analyze", None)
run_analysis = analyze_cfg is None or getattr(analyze_cfg, "enabled", True)

if run_analysis:
from electrai.scripts.analyze.analyze_saturation import analyze_metrics

saturation_dir = log_dir / "saturation"
saturation_dir.mkdir(exist_ok=True, parents=True)
try:
analyze_metrics(metrics_csv, saturation_dir)
except (KeyError, ValueError) as e:
logger.warning("Saturation analysis skipped: %s", e)

# Tail analysis requires metadata CSV
metadata_path = (
getattr(analyze_cfg, "metadata", None) if analyze_cfg else None
)
if metadata_path is not None:
from electrai.scripts.analyze.analyze_tail import main as tail_main

tail_dir = log_dir / "tail"
tail_main(
[
"--metrics",
str(metrics_csv),
"--metadata",
str(metadata_path),
"--output-dir",
str(tail_dir),
]
)
46 changes: 39 additions & 7 deletions src/electrai/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,26 @@ def test_step(self, batch):

self.log("test_loss", loss, prog_bar=True, sync_dist=True)

# Per-sample statistics over spatial dims (keep batch dim)
spatial_dims = tuple(range(1, preds.ndim)) # all dims except batch
out = {
"target": y.detach().cpu(),
"index": indices,
"nmae": loss.detach().cpu(),
"duration": elapsed,
"loss": loss.detach().cpu(),
"max_pred": preds.amax(dim=spatial_dims).detach().cpu(),
"max_target": y.amax(dim=spatial_dims).detach().cpu(),
"mean_pred": preds.mean(dim=spatial_dims).detach().cpu(),
"mean_target": y.mean(dim=spatial_dims).detach().cpu(),
"num_electrons": y.sum(dim=spatial_dims).detach().cpu(),
"duration_ms": elapsed,
}
if self.save_pred:
out["pred"] = preds.detach().cpu()
return out

def on_test_batch_end(self, outputs, _batch, batch_idx):
indices = outputs["index"]
nmae = outputs["nmae"]

if self.save_pred:
preds = outputs["pred"]
Expand All @@ -126,14 +133,36 @@ def on_test_batch_end(self, outputs, _batch, batch_idx):
preds[i].squeeze(0).cpu().numpy(),
)

if isinstance(nmae, torch.Tensor) and nmae.ndim == 0:
nmae = nmae.unsqueeze(0)
# Ensure scalar tensors are iterable (batch_size=1 produces 0-d tensors)
per_sample_keys = (
"nmae",
"loss",
"max_pred",
"max_target",
"mean_pred",
"mean_target",
"num_electrons",
)
for key in per_sample_keys:
val = outputs[key]
if isinstance(val, torch.Tensor) and val.ndim == 0:
outputs[key] = val.unsqueeze(0)

n_samples = len(indices)
duration_per_sample = outputs["duration_ms"] / n_samples

tmp_csv = (
self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv"
)
with tmp_csv.open("w") as f:
for idx, n in zip(indices, nmae, strict=True):
f.write(f"rank_{self.global_rank},{idx},{n.item()}\n")
for i, idx in enumerate(indices):
f.write(
f"rank_{self.global_rank},{idx},"
f"{outputs['nmae'][i].item()},{outputs['loss'][i].item()},"
f"{outputs['max_pred'][i].item()},{outputs['max_target'][i].item()},"
f"{outputs['mean_pred'][i].item()},{outputs['mean_target'][i].item()},"
f"{outputs['num_electrons'][i].item()},{duration_per_sample}\n"
)

def on_test_epoch_end(self):
is_dist = dist.is_available() and dist.is_initialized()
Expand Down Expand Up @@ -168,7 +197,10 @@ def on_test_epoch_end(self):
)

with final_csv.open("w") as f_out:
f_out.write("rank,index,nmae\n")
f_out.write(
"rank,index,nmae,loss,max_pred,max_target,"
"mean_pred,mean_target,num_electrons,duration_ms\n"
)
for tmp_csv in all_tmp_csvs:
with tmp_csv.open() as f_in:
for line in f_in:
Expand Down
Loading
Loading