Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ docs = [
"pillow>=10.0.0",
"cairosvg>=2.7.1"
]
hf = ["huggingface-hub>=0.20.0"]
zarr_conversion = [
"fire>=0.5.0",
"numcodecs>=0.16.3",
Expand Down
5 changes: 5 additions & 0 deletions src/electrai/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from electrai.callbacks.hf_upload import HuggingFaceCallback

__all__ = ["HuggingFaceCallback"]
182 changes: 182 additions & 0 deletions src/electrai/callbacks/hf_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING

from lightning.pytorch.callbacks import Callback

if TYPE_CHECKING:
from types import SimpleNamespace

# Module-level logger; callback code logs failures instead of raising.
logger = logging.getLogger(__name__)

# JSON file (kept in the checkpoint directory) recording checkpoints queued for upload.
MANIFEST_FILENAME = "hf_upload_manifest.json"


class HuggingFaceCallback(Callback):
    """Tracks saved checkpoints for deferred upload to HuggingFace Hub.

    On clusters without internet (e.g. Princeton Della), checkpoints are
    queued in a JSON manifest and uploaded later via ``electrai hf-push``.
    When ``hf.upload_immediate`` is True, uploads are attempted inline
    (failures are logged but never crash training).
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        """Read the ``hf`` section of *cfg* and resume any existing manifest.

        Parameters
        ----------
        cfg : SimpleNamespace
            Training config. ``cfg.hf`` must be a mapping containing
            ``repo_id``; ``upload_every_n_epochs`` (default 5) and
            ``upload_immediate`` (default False) are optional.
            ``cfg.ckpt_path`` defaults to ``./checkpoints``.
        """
        super().__init__()
        hf = cfg.hf
        self.repo_id: str = hf["repo_id"]
        self.every_n_epochs: int = hf.get("upload_every_n_epochs", 5)
        self.upload_immediate: bool = hf.get("upload_immediate", False)
        self.ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
        self.manifest_path = self.ckpt_path / MANIFEST_FILENAME
        self._manifest: list[dict] = []
        self._load_existing_manifest()

    def _load_existing_manifest(self) -> None:
        """Resume queue state from a previous run; tolerate a corrupt file."""
        if not self.manifest_path.exists():
            return
        try:
            with self.manifest_path.open(encoding="utf-8") as f:
                self._manifest = json.load(f)
        except (json.JSONDecodeError, OSError):
            # A half-written manifest (e.g. a job killed mid-save) must not
            # crash training at callback construction; start a fresh queue.
            logger.warning(
                "Could not read manifest %s; starting a new one.",
                self.manifest_path,
                exc_info=True,
            )
            self._manifest = []

    def _save_manifest(self) -> None:
        """Write the in-memory queue to disk, creating the directory if needed."""
        self.ckpt_path.mkdir(parents=True, exist_ok=True)
        with self.manifest_path.open("w", encoding="utf-8") as f:
            json.dump(self._manifest, f, indent=2)

    def _queue_checkpoint(
        self, ckpt_file: Path, epoch: int | None, *, path_in_repo: str | None = None
    ) -> None:
        """Append *ckpt_file* to the upload queue and persist the manifest.

        ``epoch`` may be None for checkpoints discovered after training
        (e.g. best-model files queued in ``on_train_end``).
        """
        entry = {
            "path": str(ckpt_file),
            "path_in_repo": path_in_repo or ckpt_file.name,
            "epoch": epoch,
            "repo_id": self.repo_id,
            "uploaded": False,
        }
        self._manifest.append(entry)
        self._save_manifest()
        logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name)

    def on_validation_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Every N epochs, snapshot the current state and queue it for upload."""
        if trainer.sanity_checking:
            return
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs != 0:
            return

        # Save the current state to a stable epoch-specific file. This is
        # independent of ModelCheckpoint's last.ckpt (which Lightning reorders
        # to run after us in this hook, so last.ckpt would still be stale).
        # NOTE: save_checkpoint must run on *all* ranks — it performs
        # collective ops under distributed strategies, so gating it behind
        # rank 0 can deadlock. Only the bookkeeping below is rank-0-only.
        stable_name = f"last_epoch{epoch + 1:03d}.ckpt"
        stable_path = self.ckpt_path / stable_name
        trainer.save_checkpoint(stable_path)

        if trainer.global_rank != 0:
            return

        self._queue_checkpoint(stable_path, epoch, path_in_repo=stable_name)

        if self.upload_immediate:
            _upload_single(self._manifest[-1])
            if self._manifest[-1]["uploaded"]:
                # The copy now lives on the Hub; reclaim local disk space.
                stable_path.unlink(missing_ok=True)
            self._save_manifest()

    def on_train_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Queue any best-model checkpoints not yet tracked; report backlog."""
        if trainer.global_rank != 0:
            return
        # Queue best checkpoints that haven't been queued yet
        queued_paths = {e["path"] for e in self._manifest}
        had_immediate = False
        for ckpt_file in self.ckpt_path.glob("ckpt_*.ckpt"):
            if str(ckpt_file) not in queued_paths:
                self._queue_checkpoint(ckpt_file, epoch=None)
                if self.upload_immediate:
                    _upload_single(self._manifest[-1])
                    had_immediate = True
        if had_immediate:
            # _upload_single mutates entries in place; persist the results.
            self._save_manifest()

        pending = sum(1 for e in self._manifest if not e["uploaded"])
        if pending:
            logger.info(
                "%d checkpoint(s) pending upload. "
                "Run 'electrai hf-push --ckpt-path %s' from a node with "
                "internet access.",
                pending,
                self.ckpt_path,
            )


def _upload_single(entry: dict) -> None:
"""Attempt to upload a single checkpoint. Logs errors, never raises."""
path = Path(entry["path"])
try:
from huggingface_hub import upload_file
except ImportError:
logger.warning(
"huggingface-hub is not installed. "
"Run 'pip install huggingface-hub' to enable uploads."
)
return
try:
if not path.exists():
logger.warning("Checkpoint file not found, skipping: %s", path)
return
path_in_repo = entry.get("path_in_repo", path.name)
upload_file(
path_or_fileobj=str(path),
path_in_repo=path_in_repo,
repo_id=entry["repo_id"],
)
entry["uploaded"] = True
logger.info("Uploaded %s to %s", path.name, entry["repo_id"])
except Exception:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bare Exception catch makes it hard to distinguish between errors (e.g. network issue, wrong repo_id or missing authentication token). Would it be worth catching HfHubHTTPError so those surface more clearly?

logger.warning(
"HF upload failed for %s (will retry with hf-push)",
path.name,
exc_info=True,
)


def hf_push(ckpt_path: str, *, clean: bool = False) -> None:
    """Upload pending checkpoints recorded in the manifest under *ckpt_path*.

    Run this from a login node or machine with internet access.

    Parameters
    ----------
    ckpt_path : str
        Directory containing the checkpoints and the upload manifest.
    clean : bool, keyword-only
        If True, delete each local checkpoint file after it uploads.

    Raises
    ------
    SystemExit
        If huggingface-hub is not installed or no manifest exists.
    """
    try:
        import huggingface_hub  # noqa: F401
    except ImportError as e:
        raise SystemExit(
            "huggingface-hub is not installed. "
            "Run 'uv sync --extra hf' to enable uploads."
        ) from e

    ckpt_dir = Path(ckpt_path)
    manifest_path = ckpt_dir / MANIFEST_FILENAME
    if not manifest_path.exists():
        raise SystemExit(f"No manifest found at {manifest_path}")

    with manifest_path.open(encoding="utf-8") as f:
        manifest = json.load(f)

    # .get() tolerates hand-edited or older manifests missing the flag.
    pending = [e for e in manifest if not e.get("uploaded")]
    if not pending:
        logger.info("All checkpoints already uploaded.")
        return

    logger.info("Uploading %d pending checkpoint(s)...", len(pending))
    for entry in pending:
        _upload_single(entry)
        if clean and entry["uploaded"]:
            Path(entry["path"]).unlink(missing_ok=True)
        # Persist after every attempt so progress survives an interruption
        # (Ctrl-C, dropped SSH session) and completed files aren't re-uploaded.
        with manifest_path.open("w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2)

    still_pending = sum(1 for e in manifest if not e.get("uploaded"))
    if still_pending:
        logger.warning("%d checkpoint(s) still failed to upload.", still_pending)
    else:
        logger.info("All checkpoints uploaded successfully.")
6 changes: 6 additions & 0 deletions src/electrai/configs/MP/config_resnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ wb_pname: mp-experiment
# checkpoints
ckpt_path: ./checkpoints

# HuggingFace Hub (optional — install with `uv sync --extra hf`)
# hf:
# repo_id: your-username/your-repo # must already exist on HF
# upload_every_n_epochs: 5
# upload_immediate: false # set true on nodes with internet access

# test the model
# save_pred: true
# log_dir: ./logs
Expand Down
6 changes: 6 additions & 0 deletions src/electrai/configs/MP/config_resunet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ wb_pname: mp-experiment
# checkpoints
ckpt_path: ./checkpoints

# HuggingFace Hub (optional — install with `uv sync --extra hf`)
# hf:
# repo_id: your-username/your-repo # must already exist on HF
# upload_every_n_epochs: 5
# upload_immediate: false # set true on nodes with internet access

# test the model
# save_pred: true
# log_dir: ./logs
Expand Down
20 changes: 18 additions & 2 deletions src/electrai/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import logging

import torch

Expand All @@ -23,6 +24,7 @@ def main() -> None:
RuntimeError
if no command was input
"""
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="Electrai Entry Point")
subparsers = parser.add_subparsers(dest="command", required=True)

Expand All @@ -32,14 +34,28 @@ def main() -> None:
test_parser = subparsers.add_parser("test", help="Evaluate the model")
test_parser.add_argument("--config", type=str, required=True)

hf_push_parser = subparsers.add_parser(
"hf-push", help="Upload pending checkpoints to HuggingFace Hub"
)
hf_push_parser.add_argument(
"--ckpt-path", type=str, required=True, help="Path to checkpoint directory"
)
hf_push_parser.add_argument(
"--clean",
action="store_true",
help="Delete local checkpoint files after successful upload (includes best-model checkpoints)",
)

args = parser.parse_args()

if args.command == "train":
train(args)
elif args.command == "test":
test(args)
else:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because argparse with required=True subparsers already handles unknown commands

raise ValueError(f"Unknown command: {args.command}")
elif args.command == "hf-push":
from electrai.callbacks.hf_upload import hf_push

hf_push(args.ckpt_path, clean=args.clean)


if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion src/electrai/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def train(args):

lr_monitor = LearningRateMonitor(logging_interval="epoch")

callbacks = [checkpoint_cb, lr_monitor]

hf_cfg = getattr(cfg, "hf", None)
if hf_cfg and hf_cfg.get("repo_id"):
from electrai.callbacks.hf_upload import HuggingFaceCallback

callbacks.append(HuggingFaceCallback(cfg))

# -----------------------------
# Trainer
# -----------------------------
Expand All @@ -69,7 +77,7 @@ def train(args):
trainer = Trainer(
max_epochs=int(cfg.epochs),
logger=wandb_logger,
callbacks=[checkpoint_cb, lr_monitor],
callbacks=callbacks,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
precision=cfg.precision,
devices="auto",
Expand Down
Loading
Loading