From 88027785cbe03e16aaeb751c1ae28d9bf99cd1c2 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 10:25:55 -0500 Subject: [PATCH 01/11] Add optional HuggingFace Hub checkpoint uploading Supports deferred uploads for clusters without internet (e.g. Della) via a JSON manifest and `electrai hf-push` CLI command. Immediate upload mode available for nodes with connectivity. Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 1 + src/electrai/callbacks/__init__.py | 1 + src/electrai/callbacks/hf_upload.py | 152 ++++++++++++++ src/electrai/configs/MP/config_resnet.yaml | 7 + src/electrai/configs/MP/config_resunet.yaml | 7 + src/electrai/entrypoints/main.py | 11 + src/electrai/entrypoints/train.py | 10 +- uv.lock | 210 +++++++++++++++++++- 8 files changed, 397 insertions(+), 2 deletions(-) create mode 100644 src/electrai/callbacks/__init__.py create mode 100644 src/electrai/callbacks/hf_upload.py diff --git a/pyproject.toml b/pyproject.toml index d8920e88..c854d6d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ docs = [ "pillow>=10.0.0", "cairosvg>=2.7.1" ] +hf = ["huggingface-hub>=0.20.0"] zarr_conversion = [ "fire>=0.5.0", "numcodecs>=0.16.3", diff --git a/src/electrai/callbacks/__init__.py b/src/electrai/callbacks/__init__.py new file mode 100644 index 00000000..9d48db4f --- /dev/null +++ b/src/electrai/callbacks/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py new file mode 100644 index 00000000..0934dcb9 --- /dev/null +++ b/src/electrai/callbacks/hf_upload.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +from lightning.pytorch.callbacks import Callback + +if TYPE_CHECKING: + from types import SimpleNamespace + +logger = logging.getLogger(__name__) + +MANIFEST_FILENAME = "hf_upload_manifest.json" + + +class HuggingFaceCallback(Callback): + """Tracks saved checkpoints for deferred upload to HuggingFace Hub. + + On clusters without internet (e.g. Princeton Della), checkpoints are + queued in a JSON manifest and uploaded later via ``electrai hf-push``. + When ``hf_upload_immediate`` is True, uploads are attempted inline + (failures are logged but never crash training). + """ + + def __init__(self, cfg: SimpleNamespace) -> None: + super().__init__() + hf = cfg.hf + self.repo_id: str = hf["repo_id"] + self.every_n_epochs: int = hf.get("upload_every_n_epochs", 5) + self.upload_immediate: bool = hf.get("upload_immediate", False) + self.private: bool = hf.get("private", True) + self.ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints")) + self.manifest_path = self.ckpt_path / MANIFEST_FILENAME + self._manifest: list[dict] = [] + self._load_existing_manifest() + + def _load_existing_manifest(self) -> None: + if self.manifest_path.exists(): + with Path.open(self.manifest_path) as f: + self._manifest = json.load(f) + + def _save_manifest(self) -> None: + self.ckpt_path.mkdir(parents=True, exist_ok=True) + with Path.open(self.manifest_path, "w") as f: + json.dump(self._manifest, f, indent=2) + + def _queue_checkpoint(self, ckpt_file: Path, epoch: int) -> None: + entry = { + "path": str(ckpt_file), + "epoch": epoch, + "repo_id": self.repo_id, + "uploaded": False, + } + self._manifest.append(entry) + self._save_manifest() + logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name) + + def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 + epoch = trainer.current_epoch + if (epoch + 1) % self.every_n_epochs != 0: + return + if trainer.global_rank != 0: + return + + last_ckpt = self.ckpt_path / "last.ckpt" + if not last_ckpt.exists(): + return + + self._queue_checkpoint(last_ckpt, epoch) + + if self.upload_immediate: + _upload_single(self._manifest[-1]) + self._save_manifest() + + def on_train_end(self, trainer, pl_module) -> None: # noqa: ARG002 + if trainer.global_rank != 0: + return + # Queue best checkpoints that haven't been queued yet + queued_paths = {e["path"] for e in self._manifest} + for ckpt_file in self.ckpt_path.glob("ckpt_*.ckpt"): + if str(ckpt_file) not in queued_paths: + self._queue_checkpoint(ckpt_file, epoch=-1) + if self.upload_immediate: + _upload_single(self._manifest[-1]) + self._save_manifest() + + pending = sum(1 for e in self._manifest if not e["uploaded"]) + if pending: + logger.info( + "%d checkpoint(s) pending upload. " + "Run 'electrai hf-push --ckpt-path %s' from a node with " + "internet access.", + pending, + self.ckpt_path, + ) + + +def _upload_single(entry: dict) -> None: + """Attempt to upload a single checkpoint. Logs errors, never raises.""" + try: + from huggingface_hub import upload_file + + path = Path(entry["path"]) + if not path.exists(): + logger.warning("Checkpoint file not found, skipping: %s", path) + return + upload_file( + path_or_fileobj=str(path), path_in_repo=path.name, repo_id=entry["repo_id"] + ) + entry["uploaded"] = True + logger.info("Uploaded %s to %s", path.name, entry["repo_id"]) + except Exception: + logger.warning( + "HF upload failed for %s (will retry with hf-push)", + path.name, + exc_info=True, + ) + + +def hf_push(ckpt_path: str) -> None: + """Upload pending checkpoints from a manifest file. + + Run this from a login node or machine with internet access. + """ + ckpt_dir = Path(ckpt_path) + manifest_path = ckpt_dir / MANIFEST_FILENAME + if not manifest_path.exists(): + logger.error("No manifest found at %s", manifest_path) + return + + with Path.open(manifest_path) as f: + manifest = json.load(f) + + pending = [e for e in manifest if not e["uploaded"]] + if not pending: + logger.info("All checkpoints already uploaded.") + return + + logger.info("Uploading %d pending checkpoint(s)...", len(pending)) + for entry in pending: + _upload_single(entry) + + with Path.open(manifest_path, "w") as f: + json.dump(manifest, f, indent=2) + + still_pending = sum(1 for e in manifest if not e["uploaded"]) + if still_pending: + logger.warning("%d checkpoint(s) still failed to upload.", still_pending) + else: + logger.info("All checkpoints uploaded successfully.") diff --git a/src/electrai/configs/MP/config_resnet.yaml b/src/electrai/configs/MP/config_resnet.yaml index 58938fa3..343d2c44 100644 --- a/src/electrai/configs/MP/config_resnet.yaml +++ b/src/electrai/configs/MP/config_resnet.yaml @@ -43,6 +43,13 @@ wb_pname: mp-experiment # checkpoints ckpt_path: ./checkpoints +# HuggingFace Hub (optional — install with `uv sync --extra hf`) +# hf: +# repo_id: your-username/your-repo # must already exist on HF +# upload_every_n_epochs: 5 +# upload_immediate: false # set true on nodes with internet access +# private: true + # test the model # save_pred: true # log_dir: ./logs diff --git a/src/electrai/configs/MP/config_resunet.yaml b/src/electrai/configs/MP/config_resunet.yaml index a16049cc..73b08bf9 100644 --- a/src/electrai/configs/MP/config_resunet.yaml +++ b/src/electrai/configs/MP/config_resunet.yaml @@ -43,6 +43,13 @@ wb_pname: mp-experiment # checkpoints ckpt_path: ./checkpoints +# HuggingFace Hub (optional — install with `uv sync --extra hf`) +# hf: +# repo_id: your-username/your-repo # must already exist on HF +# upload_every_n_epochs: 5 +# upload_immediate: false # set true on nodes with internet access +# private: true + # test the model # save_pred: true # log_dir: ./logs diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py index 6ccc4725..8c61b7cb 100644 --- a/src/electrai/entrypoints/main.py +++ b/src/electrai/entrypoints/main.py @@ -32,12 +32,23 @@ def main() -> None: test_parser = subparsers.add_parser("test", help="Evaluate the model") test_parser.add_argument("--config", type=str, required=True) + hf_push_parser = subparsers.add_parser( + "hf-push", help="Upload pending checkpoints to HuggingFace Hub" + ) + hf_push_parser.add_argument( + "--ckpt-path", type=str, required=True, help="Path to checkpoint directory" + ) + args = parser.parse_args() if args.command == "train": train(args) elif args.command == "test": test(args) + elif args.command == "hf-push": + from electrai.callbacks.hf_upload import hf_push + + hf_push(args.ckpt_path) else: raise ValueError(f"Unknown command: {args.command}") diff --git a/src/electrai/entrypoints/train.py b/src/electrai/entrypoints/train.py index d51d6777..d841519e 100644 --- a/src/electrai/entrypoints/train.py +++ b/src/electrai/entrypoints/train.py @@ -58,6 +58,14 @@ def train(args): lr_monitor = LearningRateMonitor(logging_interval="epoch") + callbacks = [checkpoint_cb, lr_monitor] + + hf_cfg = getattr(cfg, "hf", None) + if hf_cfg and hf_cfg.get("repo_id"): + from electrai.callbacks.hf_upload import HuggingFaceCallback + + callbacks.append(HuggingFaceCallback(cfg)) + # ----------------------------- # Trainer # ----------------------------- @@ -69,7 +77,7 @@ def train(args): trainer = Trainer( max_epochs=int(cfg.epochs), logger=wandb_logger, - callbacks=[checkpoint_cb, lr_monitor], + callbacks=callbacks, accelerator="gpu" if torch.cuda.is_available() else "cpu", precision=cfg.precision, devices="auto", diff --git a/uv.lock b/uv.lock index fbe3c234..038c4f70 100644 --- a/uv.lock +++ b/uv.lock @@ -161,6 +161,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -170,6 +179,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.9.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } + +[[package]] +name = "anyio" +version = "4.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -715,6 +743,7 @@ name = "electrai" version = "0.0.1" source = { editable = "." } dependencies = [ + { name = "hydra-core" }, { name = "lightning" }, { name = "numpy" }, { name = "pymatgen" }, @@ -740,6 +769,9 @@ docs = [ { name = "mkdocstrings", extra = ["python"] }, { name = "pillow" }, ] +hf = [ + { name = "huggingface-hub" }, +] zarr-conversion = [ { name = "fire" }, { name = "numcodecs" }, @@ -759,6 +791,8 @@ dev = [ requires-dist = [ { name = "cairosvg", marker = "extra == 'docs'", specifier = ">=2.7.1" }, { name = "fire", marker = "extra == 'zarr-conversion'", specifier = ">=0.5.0" }, + { name = "huggingface-hub", marker = "extra == 'hf'", specifier = ">=0.20.0" }, + { name = "hydra-core", specifier = ">=1.3.2" }, { name = "lightning", specifier = "~=2.5.6" }, { name = "mkdocs-gen-files", marker = "extra == 'docs'", specifier = ">=0.5.0" }, { name = "mkdocs-literate-nav", marker = "extra == 'docs'", specifier = ">=0.6.0" }, @@ -781,7 +815,7 @@ requires-dist = [ { name = "zarr", specifier = ">=3.1.3" }, { name = "zarr", marker = "extra == 'zarr-conversion'", specifier = ">=3.1.3" }, ] -provides-extras = ["dev", "docs", "zarr-conversion"] +provides-extras = ["dev", "docs", "hf", "zarr-conversion"] [package.metadata.requires-dev] dev = [ @@ -1027,6 +1061,109 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/b1/9ff6578d789a89812ff21e4e0f80ffae20a65d5dd84e7a17873fe3b365be/griffe-1.14.0-py3-none-any.whl", hash = "sha256:0e9d52832cccf0f7188cfe585ba962d2674b241c01916d780925df34873bceb0", size = 144439, upload-time = "2025-09-05T15:02:27.511Z" }, ] +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/cb/9bb543bd987ffa1ee48202cc96a756951b734b79a542335c566148ade36c/hf_xet-1.3.2.tar.gz", hash = "sha256:e130ee08984783d12717444e538587fa2119385e5bd8fc2bb9f930419b73a7af", size = 643646, upload-time = "2026-02-27T17:26:08.051Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/75/462285971954269432aad2e7938c5c7ff9ec7d60129cec542ab37121e3d6/hf_xet-1.3.2-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:335a8f36c55fd35a92d0062f4e9201b4015057e62747b7e7001ffb203c0ee1d2", size = 3761019, upload-time = "2026-02-27T17:25:49.441Z" }, + { url = "https://files.pythonhosted.org/packages/35/56/987b0537ddaf88e17192ea09afa8eca853e55f39a4721578be436f8409df/hf_xet-1.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c1ae4d3a716afc774e66922f3cac8206bfa707db13f6a7e62dfff74bfc95c9a8", size = 3521565, upload-time = "2026-02-27T17:25:47.469Z" }, + { url = "https://files.pythonhosted.org/packages/a8/5c/7e4a33a3d689f77761156cc34558047569e54af92e4d15a8f493229f6767/hf_xet-1.3.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d6dbdf231efac0b9b39adcf12a07f0c030498f9212a18e8c50224d0e84ab803d", size = 4176494, upload-time = "2026-02-27T17:25:40.247Z" }, + { url = "https://files.pythonhosted.org/packages/6b/b3/71e856bf9d9a69b3931837e8bf22e095775f268c8edcd4a9e8c355f92484/hf_xet-1.3.2-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c1980abfb68ecf6c1c7983379ed7b1e2b49a1aaf1a5aca9acc7d48e5e2e0a961", size = 3955601, upload-time = "2026-02-27T17:25:38.376Z" }, + { url = "https://files.pythonhosted.org/packages/63/d7/aecf97b3f0a981600a67ff4db15e2d433389d698a284bb0ea5d8fcdd6f7f/hf_xet-1.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1c88fbd90ad0d27c46b77a445f0a436ebaa94e14965c581123b68b1c52f5fd30", size = 4154770, upload-time = "2026-02-27T17:25:56.756Z" }, + { url = "https://files.pythonhosted.org/packages/e2/e1/3af961f71a40e09bf5ee909842127b6b00f5ab4ee3817599dc0771b79893/hf_xet-1.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:35b855024ca37f2dd113ac1c08993e997fbe167b9d61f9ef66d3d4f84015e508", size = 4394161, upload-time = "2026-02-27T17:25:58.111Z" }, + { url = "https://files.pythonhosted.org/packages/a1/c3/859509bade9178e21b8b1db867b8e10e9f817ab9ac1de77cb9f461ced765/hf_xet-1.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:31612ba0629046e425ba50375685a2586e11fb9144270ebabd75878c3eaf6378", size = 3637377, upload-time = "2026-02-27T17:26:10.611Z" }, + { url = "https://files.pythonhosted.org/packages/05/7f/724cfbef4da92d577b71f68bf832961c8919f36c60d28d289a9fc9d024d4/hf_xet-1.3.2-cp313-cp313t-win_arm64.whl", hash = "sha256:433c77c9f4e132b562f37d66c9b22c05b5479f243a1f06a120c1c06ce8b1502a", size = 3497875, upload-time = "2026-02-27T17:26:09.034Z" }, + { url = "https://files.pythonhosted.org/packages/ba/75/9d54c1ae1d05fb704f977eca1671747babf1957f19f38ae75c5933bc2dc1/hf_xet-1.3.2-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:c34e2c7aefad15792d57067c1c89b2b02c1bbaeabd7f8456ae3d07b4bbaf4094", size = 3761076, upload-time = "2026-02-27T17:25:55.42Z" }, + { url = "https://files.pythonhosted.org/packages/f2/8a/08a24b6c6f52b5d26848c16e4b6d790bb810d1bf62c3505bed179f7032d3/hf_xet-1.3.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4bc995d6c41992831f762096020dc14a65fdf3963f86ffed580b596d04de32e3", size = 3521745, upload-time = "2026-02-27T17:25:54.217Z" }, + { url = "https://files.pythonhosted.org/packages/b5/db/a75cf400dd8a1a8acf226a12955ff6ee999f272dfc0505bafd8079a61267/hf_xet-1.3.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:959083c89dee30f7d6f890b36cdadda823386c4de63b1a30384a75bfd2ae995d", size = 4176301, upload-time = "2026-02-27T17:25:46.044Z" }, + { url = "https://files.pythonhosted.org/packages/01/40/6c4c798ffdd83e740dd3925c4e47793b07442a9efa3bc3866ba141a82365/hf_xet-1.3.2-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:cfa760888633b08c01b398d212ce7e8c0d7adac6c86e4b20dfb2397d8acd78ee", size = 3955437, upload-time = "2026-02-27T17:25:44.703Z" }, + { url = "https://files.pythonhosted.org/packages/0c/09/9a3aa7c5f07d3e5cc57bb750d12a124ffa72c273a87164bd848f9ac5cc14/hf_xet-1.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3155a02e083aa21fd733a7485c7c36025e49d5975c8d6bda0453d224dd0b0ac4", size = 4154535, upload-time = "2026-02-27T17:26:05.207Z" }, + { url = "https://files.pythonhosted.org/packages/ae/e0/831f7fa6d90cb47a230bc23284b502c700e1483bbe459437b3844cdc0776/hf_xet-1.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:91b1dc03c31cbf733d35dc03df7c5353686233d86af045e716f1e0ea4a2673cf", size = 4393891, upload-time = "2026-02-27T17:26:06.607Z" }, + { url = "https://files.pythonhosted.org/packages/ab/96/6ed472fdce7f8b70f5da6e3f05be76816a610063003bfd6d9cea0bbb58a3/hf_xet-1.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:211f30098512d95e85ad03ae63bd7dd2c4df476558a5095d09f9e38e78cbf674", size = 3637583, upload-time = "2026-02-27T17:26:17.349Z" }, + { url = "https://files.pythonhosted.org/packages/8b/e8/a069edc4570b3f8e123c0b80fadc94530f3d7b01394e1fc1bb223339366c/hf_xet-1.3.2-cp314-cp314t-win_arm64.whl", hash = "sha256:4a6817c41de7c48ed9270da0b02849347e089c5ece9a0e72ae4f4b3a57617f82", size = 3497977, upload-time = "2026-02-27T17:26:14.966Z" }, + { url = "https://files.pythonhosted.org/packages/d8/28/dbb024e2e3907f6f3052847ca7d1a2f7a3972fafcd53ff79018977fcb3e4/hf_xet-1.3.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f93b7595f1d8fefddfede775c18b5c9256757824f7f6832930b49858483cd56f", size = 3763961, upload-time = "2026-02-27T17:25:52.537Z" }, + { url = "https://files.pythonhosted.org/packages/e4/71/b99aed3823c9d1795e4865cf437d651097356a3f38c7d5877e4ac544b8e4/hf_xet-1.3.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a85d3d43743174393afe27835bde0cd146e652b5fcfdbcd624602daef2ef3259", size = 3526171, upload-time = "2026-02-27T17:25:50.968Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ca/907890ce6ef5598b5920514f255ed0a65f558f820515b18db75a51b2f878/hf_xet-1.3.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7c2a054a97c44e136b1f7f5a78f12b3efffdf2eed3abc6746fc5ea4b39511633", size = 4180750, upload-time = "2026-02-27T17:25:43.125Z" }, + { url = "https://files.pythonhosted.org/packages/8c/ad/bc7f41f87173d51d0bce497b171c4ee0cbde1eed2d7b4216db5d0ada9f50/hf_xet-1.3.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:06b724a361f670ae557836e57801b82c75b534812e351a87a2c739f77d1e0635", size = 3961035, upload-time = "2026-02-27T17:25:41.837Z" }, + { url = "https://files.pythonhosted.org/packages/73/38/600f4dda40c4a33133404d9fe644f1d35ff2d9babb4d0435c646c63dd107/hf_xet-1.3.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:305f5489d7241a47e0458ef49334be02411d1d0f480846363c1c8084ed9916f7", size = 4161378, upload-time = "2026-02-27T17:26:00.365Z" }, + { url = "https://files.pythonhosted.org/packages/00/b3/7bc1ff91d1ac18420b7ad1e169b618b27c00001b96310a89f8a9294fe509/hf_xet-1.3.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:06cdbde243c85f39a63b28e9034321399c507bcd5e7befdd17ed2ccc06dfe14e", size = 4398020, upload-time = "2026-02-27T17:26:03.977Z" }, + { url = "https://files.pythonhosted.org/packages/2b/0b/99bfd948a3ed3620ab709276df3ad3710dcea61976918cce8706502927af/hf_xet-1.3.2-cp37-abi3-win_amd64.whl", hash = "sha256:9298b47cce6037b7045ae41482e703c471ce36b52e73e49f71226d2e8e5685a1", size = 3641624, upload-time = "2026-02-27T17:26:13.542Z" }, + { url = "https://files.pythonhosted.org/packages/cc/02/9a6e4ca1f3f73a164c0cd48e41b3cc56585dcc37e809250de443d673266f/hf_xet-1.3.2-cp37-abi3-win_arm64.whl", hash = "sha256:83d8ec273136171431833a6957e8f3af496bee227a0fe47c7b8b39c106d1749a", size = 3503976, upload-time = "2026-02-27T17:26:12.123Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tqdm" }, + { name = "typer" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/76/b5efb3033d8499b17f9386beaf60f64c461798e1ee16d10bc9c0077beba5/huggingface_hub-1.5.0.tar.gz", hash = "sha256:f281838db29265880fb543de7a23b0f81d3504675de82044307ea3c6c62f799d", size = 695872, upload-time = "2026-02-26T15:35:32.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/74/2bc951622e2dbba1af9a460d93c51d15e458becd486e62c29cc0ccb08178/huggingface_hub-1.5.0-py3-none-any.whl", hash = "sha256:c9c0b3ab95a777fc91666111f3b3ede71c0cdced3614c553a64e98920585c4ee", size = 596261, upload-time = "2026-02-26T15:35:31.1Z" }, +] + +[[package]] +name = "hydra-core" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "omegaconf" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/8e/07e42bc434a847154083b315779b0a81d567154504624e181caf2c71cd98/hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", size = 3263494, upload-time = "2023-02-23T18:33:43.03Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, +] + [[package]] name = "idna" version = "3.11" @@ -1208,6 +1345,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/ae/44c4a6a4cbb496d93c6257954260fe3a6e91b7bed2240e5dad2a717f5111/markdown-3.9-py3-none-any.whl", hash = "sha256:9f4d91ed810864ea88a6f32c07ba8bee1346c0cc1f6b1f9f6c822f2a9667d280", size = 107441, upload-time = "2025-09-04T20:25:21.784Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -1346,6 +1495,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/cc/3fe688ff1355010937713164caacf9ed443675ac48a997bab6ed23b3f7c0/matplotlib-3.10.7-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91", size = 8693919, upload-time = "2025-10-09T00:27:58.41Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mergedeep" version = "1.3.4" @@ -1902,6 +2060,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, ] +[[package]] +name = "omegaconf" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, +] + [[package]] name = "orjson" version = "3.11.4" @@ -2642,6 +2813,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "14.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, +] + [[package]] name = "ruamel-yaml" version = "0.18.16" @@ -2885,6 +3069,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -3162,6 +3355,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/b5/b0d3d8b901b6a04ca38df5e24c27e53afb15b93624d7fd7d658c7cd9352a/triton-3.5.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bac7f7d959ad0f48c0e97d6643a1cc0fd5786fe61cb1f83b537c6b2d54776478", size = 170582192, upload-time = "2025-11-11T17:41:23.963Z" }, ] +[[package]] +name = "typer" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 4aea6509f401f91174d727036ce85e51c7efee52 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 13:52:28 -0500 Subject: [PATCH 02/11] Name HF uploads by epoch and fix callback hook ordering Upload last.ckpt as last_epoch{N}.ckpt so previous versions are preserved across epochs. Switch from on_train_epoch_end to on_validation_end so checkpoints exist when the upload runs. Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 0934dcb9..5579c47e 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -46,9 +46,12 @@ def _save_manifest(self) -> None: with Path.open(self.manifest_path, "w") as f: json.dump(self._manifest, f, indent=2) - def _queue_checkpoint(self, ckpt_file: Path, epoch: int) -> None: + def _queue_checkpoint( + self, ckpt_file: Path, epoch: int, *, path_in_repo: str | None = None + ) -> None: entry = { "path": str(ckpt_file), + "path_in_repo": path_in_repo or ckpt_file.name, "epoch": epoch, "repo_id": self.repo_id, "uploaded": False, @@ -57,7 +60,9 @@ def _queue_checkpoint(self, ckpt_file: Path, epoch: int) -> None: self._save_manifest() logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name) - def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 + def on_validation_end(self, trainer, pl_module) -> None: # noqa: ARG002 + if trainer.sanity_checking: + return epoch = trainer.current_epoch if (epoch + 1) % self.every_n_epochs != 0: return @@ -68,7 +73,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 if not last_ckpt.exists(): return - self._queue_checkpoint(last_ckpt, epoch) + self._queue_checkpoint( + last_ckpt, epoch, path_in_repo=f"last_epoch{epoch:03d}.ckpt" + ) if self.upload_immediate: _upload_single(self._manifest[-1]) @@ -106,8 +113,11 @@ def _upload_single(entry: dict) -> None: if not path.exists(): logger.warning("Checkpoint file not found, skipping: %s", path) return + path_in_repo = entry.get("path_in_repo", path.name) upload_file( - path_or_fileobj=str(path), path_in_repo=path.name, repo_id=entry["repo_id"] + path_or_fileobj=str(path), + path_in_repo=path_in_repo, + repo_id=entry["repo_id"], ) entry["uploaded"] = True logger.info("Uploaded %s to %s", path.name, entry["repo_id"]) From 13045941fcf219f58443d07e08c7fbae7009c9d5 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 15:30:24 -0500 Subject: [PATCH 03/11] Address comments - Fix potential NameError in _upload_single by moving path assignment above try - Remove unused `private` config key (repo must already exist) - Replace epoch=-1 magic value with None - Only call _save_manifest() in on_train_end when immediate uploads occurred - Raise SystemExit instead of silent logger.error in hf_push for missing manifest Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 15 ++++++++------- src/electrai/configs/MP/config_resnet.yaml | 1 - src/electrai/configs/MP/config_resunet.yaml | 1 - 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 5579c47e..2543e7d9 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -30,7 +30,6 @@ def __init__(self, cfg: SimpleNamespace) -> None: self.repo_id: str = hf["repo_id"] self.every_n_epochs: int = hf.get("upload_every_n_epochs", 5) self.upload_immediate: bool = hf.get("upload_immediate", False) - self.private: bool = hf.get("private", True) self.ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints")) self.manifest_path = self.ckpt_path / MANIFEST_FILENAME self._manifest: list[dict] = [] @@ -47,7 +46,7 @@ def _save_manifest(self) -> None: json.dump(self._manifest, f, indent=2) def _queue_checkpoint( - self, ckpt_file: Path, epoch: int, *, path_in_repo: str | None = None + self, ckpt_file: Path, epoch: int | None, *, path_in_repo: str | None = None ) -> None: entry = { "path": str(ckpt_file), @@ -86,12 +85,15 @@ def on_train_end(self, trainer, pl_module) -> None: # noqa: ARG002 return # Queue best checkpoints that haven't been queued yet queued_paths = {e["path"] for e in self._manifest} + had_immediate = False for ckpt_file in self.ckpt_path.glob("ckpt_*.ckpt"): if str(ckpt_file) not in queued_paths: - self._queue_checkpoint(ckpt_file, epoch=-1) + self._queue_checkpoint(ckpt_file, epoch=None) if self.upload_immediate: _upload_single(self._manifest[-1]) - self._save_manifest() + had_immediate = True + if had_immediate: + self._save_manifest() pending = sum(1 for e in self._manifest if not e["uploaded"]) if pending: @@ -106,10 +108,10 @@ def on_train_end(self, trainer, pl_module) -> None: # noqa: ARG002 def _upload_single(entry: dict) -> None: """Attempt to upload a single checkpoint. Logs errors, never raises.""" + path = Path(entry["path"]) try: from huggingface_hub import upload_file - path = Path(entry["path"]) if not path.exists(): logger.warning("Checkpoint file not found, skipping: %s", path) return @@ -137,8 +139,7 @@ def hf_push(ckpt_path: str) -> None: ckpt_dir = Path(ckpt_path) manifest_path = ckpt_dir / MANIFEST_FILENAME if not manifest_path.exists(): - logger.error("No manifest found at %s", manifest_path) - return + raise SystemExit(f"No manifest found at {manifest_path}") with Path.open(manifest_path) as f: manifest = json.load(f) diff --git a/src/electrai/configs/MP/config_resnet.yaml b/src/electrai/configs/MP/config_resnet.yaml index 343d2c44..834ba84d 100644 --- a/src/electrai/configs/MP/config_resnet.yaml +++ b/src/electrai/configs/MP/config_resnet.yaml @@ -48,7 +48,6 @@ ckpt_path: ./checkpoints # repo_id: your-username/your-repo # must already exist on HF # upload_every_n_epochs: 5 # upload_immediate: false # set true on nodes with internet access -# private: true # test the model # save_pred: true diff --git a/src/electrai/configs/MP/config_resunet.yaml b/src/electrai/configs/MP/config_resunet.yaml index 73b08bf9..7fef6ff4 100644 --- a/src/electrai/configs/MP/config_resunet.yaml +++ b/src/electrai/configs/MP/config_resunet.yaml @@ -48,7 +48,6 @@ ckpt_path: ./checkpoints # repo_id: your-username/your-repo # must already exist on HF # upload_every_n_epochs: 5 # upload_immediate: false # set true on nodes with internet access -# private: true # test the model # save_pred: true From 942636e677140b79dfd0624a59578d2b86f117d0 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 15:44:34 -0500 Subject: [PATCH 04/11] Update src/electrai/callbacks/hf_upload.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/electrai/callbacks/hf_upload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 2543e7d9..f4d95aff 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -20,7 +20,7 @@ class HuggingFaceCallback(Callback): On clusters without internet (e.g. Princeton Della), checkpoints are queued in a JSON manifest and uploaded later via ``electrai hf-push``. - When ``hf_upload_immediate`` is True, uploads are attempted inline + When ``hf.upload_immediate`` is True, uploads are attempted inline (failures are logged but never crash training). """ From a03a169a454c83dc06258900d1fbbec8e3e4245b Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 15:46:36 -0500 Subject: [PATCH 05/11] Update src/electrai/callbacks/hf_upload.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/electrai/callbacks/hf_upload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index f4d95aff..f70f3431 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -37,12 +37,12 @@ def __init__(self, cfg: SimpleNamespace) -> None: def _load_existing_manifest(self) -> None: if self.manifest_path.exists(): - with Path.open(self.manifest_path) as f: + with self.manifest_path.open(encoding="utf-8") as f: self._manifest = json.load(f) def _save_manifest(self) -> None: self.ckpt_path.mkdir(parents=True, exist_ok=True) - with Path.open(self.manifest_path, "w") as f: + with self.manifest_path.open("w", encoding="utf-8") as f: json.dump(self._manifest, f, indent=2) def _queue_checkpoint( From 6d3937b2492f63fbad00b6122c5a6249950a79a6 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Thu, 5 Mar 2026 15:56:15 -0500 Subject: [PATCH 06/11] Address comments - Use 1-indexed epoch in uploaded filename to match upload cadence - Copy last.ckpt to stable epoch-specific file at queue time so deferred hf-push uploads the correct snapshot after last.ckpt is overwritten Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index f70f3431..010272b3 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -2,6 +2,7 @@ import json import logging +import shutil from pathlib import Path from typing import TYPE_CHECKING @@ -72,9 +73,13 @@ def on_validation_end(self, trainer, pl_module) -> None: # noqa: ARG002 if not last_ckpt.exists(): return - self._queue_checkpoint( - last_ckpt, epoch, path_in_repo=f"last_epoch{epoch:03d}.ckpt" - ) + # Copy to a stable filename so later hf-push uploads the correct + # snapshot even after last.ckpt is overwritten by subsequent epochs. + stable_name = f"last_epoch{epoch + 1:03d}.ckpt" + stable_path = self.ckpt_path / stable_name + shutil.copy2(last_ckpt, stable_path) + + self._queue_checkpoint(stable_path, epoch, path_in_repo=stable_name) if self.upload_immediate: _upload_single(self._manifest[-1]) From dd6af61c5b65f5b1938b9fb9ead1101a5aac3235 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Mon, 9 Mar 2026 10:26:23 -0400 Subject: [PATCH 07/11] Address comments - Remove unreachable else branch in main.py (argparse handles unknown commands) - Re-export HuggingFaceCallback from callbacks __init__ for ergonomic imports Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/__init__.py | 4 ++++ src/electrai/entrypoints/main.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/electrai/callbacks/__init__.py b/src/electrai/callbacks/__init__.py index 9d48db4f..9d84fafb 100644 --- a/src/electrai/callbacks/__init__.py +++ b/src/electrai/callbacks/__init__.py @@ -1 +1,5 @@ from __future__ import annotations + +from electrai.callbacks.hf_upload import HuggingFaceCallback + +__all__ = ["HuggingFaceCallback"] diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py index 8c61b7cb..2acad67f 100644 --- a/src/electrai/entrypoints/main.py +++ b/src/electrai/entrypoints/main.py @@ -49,8 +49,6 @@ def main() -> None: from electrai.callbacks.hf_upload import hf_push hf_push(args.ckpt_path) - else: - raise ValueError(f"Unknown command: {args.command}") if __name__ == "__main__": From cb60f703c2e0f8fde3b750841cef433c7a4dd9e4 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Mon, 9 Mar 2026 11:08:25 -0400 Subject: [PATCH 08/11] Address comments - Delete stable copies after successful immediate upload - Add --clean flag to hf-push to delete local copies after deferred upload - Add logging.basicConfig to main() for CLI log visibility - Fix Path.open() style and add encoding="utf-8" in hf_push Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 10 +++++++--- src/electrai/entrypoints/main.py | 9 ++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 010272b3..13729dea 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -83,6 +83,8 @@ def on_validation_end(self, trainer, pl_module) -> None: # noqa: ARG002 if self.upload_immediate: _upload_single(self._manifest[-1]) + if self._manifest[-1]["uploaded"]: + stable_path.unlink(missing_ok=True) self._save_manifest() def on_train_end(self, trainer, pl_module) -> None: # noqa: ARG002 @@ -136,7 +138,7 @@ def _upload_single(entry: dict) -> None: ) -def hf_push(ckpt_path: str) -> None: +def hf_push(ckpt_path: str, *, clean: bool = False) -> None: """Upload pending checkpoints from a manifest file. Run this from a login node or machine with internet access. @@ -146,7 +148,7 @@ def hf_push(ckpt_path: str) -> None: if not manifest_path.exists(): raise SystemExit(f"No manifest found at {manifest_path}") - with Path.open(manifest_path) as f: + with manifest_path.open(encoding="utf-8") as f: manifest = json.load(f) pending = [e for e in manifest if not e["uploaded"]] @@ -157,8 +159,10 @@ def hf_push(ckpt_path: str) -> None: logger.info("Uploading %d pending checkpoint(s)...", len(pending)) for entry in pending: _upload_single(entry) + if clean and entry["uploaded"]: + Path(entry["path"]).unlink(missing_ok=True) - with Path.open(manifest_path, "w") as f: + with manifest_path.open("w", encoding="utf-8") as f: json.dump(manifest, f, indent=2) still_pending = sum(1 for e in manifest if not e["uploaded"]) diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py index 2acad67f..dd8aafe0 100644 --- a/src/electrai/entrypoints/main.py +++ b/src/electrai/entrypoints/main.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import logging import torch @@ -23,6 +24,7 @@ def main() -> None: RuntimeError if no command was input """ + logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser(description="Electrai Entry Point") subparsers = parser.add_subparsers(dest="command", required=True) @@ -38,6 +40,11 @@ def main() -> None: hf_push_parser.add_argument( "--ckpt-path", type=str, required=True, help="Path to checkpoint directory" ) + hf_push_parser.add_argument( + "--clean", + action="store_true", + help="Delete local copies after successful upload", + ) args = parser.parse_args() @@ -48,7 +55,7 @@ def main() -> None: elif args.command == "hf-push": from electrai.callbacks.hf_upload import hf_push - hf_push(args.ckpt_path) + hf_push(args.ckpt_path, clean=args.clean) if __name__ == "__main__": From 161b8847031860e42a948d8290ff098d1483ed2a Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Mon, 9 Mar 2026 11:16:42 -0400 Subject: [PATCH 09/11] Clarify --clean help text to mention best-model checkpoints Co-Authored-By: Claude Opus 4.6 --- src/electrai/entrypoints/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/electrai/entrypoints/main.py b/src/electrai/entrypoints/main.py index dd8aafe0..f8f9e791 100644 --- a/src/electrai/entrypoints/main.py +++ b/src/electrai/entrypoints/main.py @@ -43,7 +43,7 @@ def main() -> None: hf_push_parser.add_argument( "--clean", action="store_true", - help="Delete local copies after successful upload", + help="Delete local checkpoint files after successful upload (includes best-model checkpoints)", ) args = parser.parse_args() From f782897b5234b9e77e43af81b93eca9b0cab3bf5 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 18 Mar 2026 10:51:01 -0400 Subject: [PATCH 10/11] Fix checkpoint hook ordering and `ImportError` handling in HF callback - Switch from `on_validation_end` to `on_train_epoch_end` so `last.ckpt` is current (Lightning reorders `ModelCheckpoint` to run last in `on_validation_end`, making the copy stale). - Separate `ImportError` from upload failures in `_upload_single` with a clear "install huggingface-hub" message. Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 13729dea..1aeb23da 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -60,7 +60,7 @@ def _queue_checkpoint( self._save_manifest() logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name) - def on_validation_end(self, trainer, pl_module) -> None: # noqa: ARG002 + def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 if trainer.sanity_checking: return epoch = trainer.current_epoch @@ -118,7 +118,13 @@ def _upload_single(entry: dict) -> None: path = Path(entry["path"]) try: from huggingface_hub import upload_file - + except ImportError: + logger.warning( + "huggingface-hub is not installed. " + "Run 'pip install huggingface-hub' to enable uploads." + ) + return + try: if not path.exists(): logger.warning("Checkpoint file not found, skipping: %s", path) return From 1b4078d2d9de78cafbf516222745e0a9e71e27b3 Mon Sep 17 00:00:00 2001 From: Betsy Cannon Date: Tue, 7 Apr 2026 17:22:01 -0400 Subject: [PATCH 11/11] Fix stale checkpoint and hf-push exit code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - on_validation_end + trainer.save_checkpoint: previously on_train_epoch_end copied last.ckpt, which is saved by ModelCheckpoint during on_validation_end (after our hook) — so last.ckpt was always from the previous epoch. Save the current state explicitly to avoid any dependency on callback ordering. - hf_push: preflight the huggingface_hub import so the CLI fails fast with a clear install message instead of silently looping and exiting 0. Co-Authored-By: Claude Opus 4.6 --- src/electrai/callbacks/hf_upload.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/electrai/callbacks/hf_upload.py b/src/electrai/callbacks/hf_upload.py index 1aeb23da..b8ec2e20 100644 --- a/src/electrai/callbacks/hf_upload.py +++ b/src/electrai/callbacks/hf_upload.py @@ -2,7 +2,6 @@ import json import logging -import shutil from pathlib import Path from typing import TYPE_CHECKING @@ -60,7 +59,7 @@ def _queue_checkpoint( self._save_manifest() logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name) - def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 + def on_validation_end(self, trainer, pl_module) -> None: # noqa: ARG002 if trainer.sanity_checking: return epoch = trainer.current_epoch @@ -69,15 +68,12 @@ def on_train_epoch_end(self, trainer, pl_module) -> None: # noqa: ARG002 if trainer.global_rank != 0: return - last_ckpt = self.ckpt_path / "last.ckpt" - if not last_ckpt.exists(): - return - - # Copy to a stable filename so later hf-push uploads the correct - # snapshot even after last.ckpt is overwritten by subsequent epochs. + # Save the current state to a stable epoch-specific file. This is + # independent of ModelCheckpoint's last.ckpt (which Lightning reorders + # to run after us in this hook, so last.ckpt would still be stale). stable_name = f"last_epoch{epoch + 1:03d}.ckpt" stable_path = self.ckpt_path / stable_name - shutil.copy2(last_ckpt, stable_path) + trainer.save_checkpoint(stable_path) self._queue_checkpoint(stable_path, epoch, path_in_repo=stable_name) @@ -149,6 +145,14 @@ def hf_push(ckpt_path: str, *, clean: bool = False) -> None: Run this from a login node or machine with internet access. """ + try: + import huggingface_hub # noqa: F401 + except ImportError as e: + raise SystemExit( + "huggingface-hub is not installed. " + "Run 'uv sync --extra hf' to enable uploads." + ) from e + ckpt_dir = Path(ckpt_path) manifest_path = ckpt_dir / MANIFEST_FILENAME if not manifest_path.exists():