-
Notifications
You must be signed in to change notification settings - Fork 2
Add optional HuggingFace Hub checkpoint uploading #95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8802778
4aea650
1304594
942636e
a03a169
6d3937b
dd6af61
cb60f70
161b884
6cdee93
f782897
1b4078d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from electrai.callbacks.hf_upload import HuggingFaceCallback | ||
|
|
||
| __all__ = ["HuggingFaceCallback"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import logging | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from lightning.pytorch.callbacks import Callback | ||
|
|
||
| if TYPE_CHECKING: | ||
| from types import SimpleNamespace | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| MANIFEST_FILENAME = "hf_upload_manifest.json" | ||
|
|
||
|
|
||
class HuggingFaceCallback(Callback):
    """Tracks saved checkpoints for deferred upload to HuggingFace Hub.

    On clusters without internet (e.g. Princeton Della), checkpoints are
    queued in a JSON manifest and uploaded later via ``electrai hf-push``.
    When ``hf.upload_immediate`` is True, uploads are attempted inline
    (failures are logged but never crash training).

    The manifest is a JSON list; each entry carries the keys ``path``,
    ``path_in_repo``, ``epoch``, ``repo_id`` and ``uploaded``.
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        """Read upload settings from ``cfg`` and load any existing manifest.

        Args:
            cfg: Run configuration. ``cfg.hf`` is accessed like a mapping
                (``hf["repo_id"]``, ``hf.get(...)``); ``cfg.ckpt_path`` is
                optional and defaults to ``./checkpoints``.

        Raises:
            KeyError: If ``cfg.hf`` has no ``repo_id``.
            ValueError: If ``upload_every_n_epochs`` is smaller than 1.
        """
        super().__init__()
        hf = cfg.hf
        self.repo_id: str = hf["repo_id"]
        self.every_n_epochs: int = hf.get("upload_every_n_epochs", 5)
        if self.every_n_epochs < 1:
            # Fail at start-up rather than with a ZeroDivisionError in the
            # modulo check after the first validation epoch.
            raise ValueError(
                "hf.upload_every_n_epochs must be >= 1, "
                f"got {self.every_n_epochs!r}"
            )
        self.upload_immediate: bool = hf.get("upload_immediate", False)
        self.ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
        self.manifest_path = self.ckpt_path / MANIFEST_FILENAME
        self._manifest: list[dict] = []
        self._load_existing_manifest()

    def _load_existing_manifest(self) -> None:
        """Load a manifest left by a previous run, if one exists."""
        if self.manifest_path.exists():
            with self.manifest_path.open(encoding="utf-8") as f:
                self._manifest = json.load(f)

    def _save_manifest(self) -> None:
        """Write the in-memory manifest to ``self.manifest_path``.

        The data is written to a sibling temporary file and then renamed
        over the manifest, so an interrupted write cannot leave a truncated
        manifest behind.
        """
        self.ckpt_path.mkdir(parents=True, exist_ok=True)
        tmp_path = self.manifest_path.with_name(self.manifest_path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(self._manifest, f, indent=2)
        tmp_path.replace(self.manifest_path)

    def _queue_checkpoint(
        self, ckpt_file: Path, epoch: int | None, *, path_in_repo: str | None = None
    ) -> None:
        """Append ``ckpt_file`` to the manifest as pending and persist it.

        Args:
            ckpt_file: Local checkpoint file to upload later.
            epoch: Epoch index recorded with the entry, or ``None``.
            path_in_repo: Destination name in the Hub repo; defaults to the
                file name of ``ckpt_file``.
        """
        entry = {
            "path": str(ckpt_file),
            "path_in_repo": path_in_repo or ckpt_file.name,
            "epoch": epoch,
            "repo_id": self.repo_id,
            "uploaded": False,
        }
        # Replace any earlier entry for the same file and repo (e.g. when a
        # resumed run rewrites the same epoch file) so the manifest never
        # holds duplicates and the fresh file is treated as pending again.
        self._manifest = [
            existing
            for existing in self._manifest
            if (existing.get("path"), existing.get("repo_id"))
            != (entry["path"], entry["repo_id"])
        ]
        self._manifest.append(entry)
        self._save_manifest()
        logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name)

    def on_validation_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Every ``every_n_epochs`` epochs, save and queue a checkpoint."""
        if trainer.sanity_checking:
            return
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs != 0:
            return

        # Save the current state to a stable epoch-specific file. This is
        # independent of ModelCheckpoint's last.ckpt (which Lightning reorders
        # to run after us in this hook, so last.ckpt would still be stale).
        stable_name = f"last_epoch{epoch + 1:03d}.ckpt"
        stable_path = self.ckpt_path / stable_name
        # Trainer.save_checkpoint must run on every rank: Lightning documents
        # that it needs to be called on all processes and it synchronizes the
        # ranks after saving, so calling it from rank 0 only can hang
        # multi-process runs. Only the bookkeeping below is rank-0-only.
        trainer.save_checkpoint(stable_path)

        if trainer.global_rank != 0:
            return

        self._queue_checkpoint(stable_path, epoch, path_in_repo=stable_name)

        if self.upload_immediate:
            entry = self._manifest[-1]
            _upload_single(entry)
            if entry["uploaded"]:
                # The epoch snapshot only exists for the upload; drop it once
                # it is safely on the Hub.
                stable_path.unlink(missing_ok=True)
            self._save_manifest()

    def on_train_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Queue not-yet-queued ``ckpt_*.ckpt`` files and report what is pending."""
        if trainer.global_rank != 0:
            return
        # Queue best checkpoints that haven't been queued yet
        queued_paths = {e["path"] for e in self._manifest}
        had_immediate = False
        for ckpt_file in self.ckpt_path.glob("ckpt_*.ckpt"):
            if str(ckpt_file) in queued_paths:
                continue
            self._queue_checkpoint(ckpt_file, epoch=None)
            if self.upload_immediate:
                _upload_single(self._manifest[-1])
                had_immediate = True
        if had_immediate:
            # Persist the ``uploaded`` flags flipped by the inline uploads.
            self._save_manifest()

        pending = sum(1 for e in self._manifest if not e["uploaded"])
        if pending:
            logger.info(
                "%d checkpoint(s) pending upload. "
                "Run 'electrai hf-push --ckpt-path %s' from a node with "
                "internet access.",
                pending,
                self.ckpt_path,
            )
|
|
||
|
|
||
def _upload_single(entry: dict) -> None:
    """Attempt to upload a single checkpoint; upload failures are logged, not raised.

    On success ``entry["uploaded"]`` is set to True in place. The caller is
    responsible for persisting the manifest afterwards.

    Args:
        entry: Manifest entry. ``path`` and ``repo_id`` are required;
            ``path_in_repo`` defaults to the file name of ``path``.
    """
    path = Path(entry["path"])
    try:
        from huggingface_hub import upload_file
    except ImportError:
        # Same install hint as ``hf_push`` so users see one consistent command.
        logger.warning(
            "huggingface-hub is not installed. "
            "Run 'uv sync --extra hf' (or 'pip install huggingface-hub') "
            "to enable uploads."
        )
        return
    try:
        if not path.exists():
            logger.warning("Checkpoint file not found, skipping: %s", path)
            return
        path_in_repo = entry.get("path_in_repo", path.name)
        upload_file(
            path_or_fileobj=str(path),
            path_in_repo=path_in_repo,
            repo_id=entry["repo_id"],
        )
        entry["uploaded"] = True
        logger.info("Uploaded %s to %s", path.name, entry["repo_id"])
    except Exception:
        # Deliberately broad: this runs inside training callbacks, where a
        # network, auth or filesystem error must never abort the run. The
        # traceback is logged and the entry stays pending for ``hf-push``.
        logger.warning(
            "HF upload failed for %s (will retry with hf-push)",
            path.name,
            exc_info=True,
        )
ryan-williams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def hf_push(ckpt_path: str, *, clean: bool = False) -> None:
    """Upload pending checkpoints from a manifest file.

    Run this from a login node or machine with internet access.

    Args:
        ckpt_path: Checkpoint directory containing the upload manifest.
        clean: If True, delete each local checkpoint file once it has been
            uploaded successfully.

    Raises:
        SystemExit: If ``huggingface-hub`` is not installed or no manifest
            exists in ``ckpt_path``.
    """
    try:
        import huggingface_hub  # noqa: F401
    except ImportError as e:
        raise SystemExit(
            "huggingface-hub is not installed. "
            "Run 'uv sync --extra hf' to enable uploads."
        ) from e

    ckpt_dir = Path(ckpt_path)
    manifest_path = ckpt_dir / MANIFEST_FILENAME
    if not manifest_path.exists():
        raise SystemExit(f"No manifest found at {manifest_path}")

    with manifest_path.open(encoding="utf-8") as f:
        manifest = json.load(f)

    def _persist() -> None:
        # Write to a sibling temp file and rename it over the manifest so an
        # interrupted run never leaves a truncated manifest behind.
        tmp_path = manifest_path.with_name(manifest_path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as out:
            json.dump(manifest, out, indent=2)
        tmp_path.replace(manifest_path)

    # ``.get`` treats an entry without an ``uploaded`` key as still pending.
    pending = [e for e in manifest if not e.get("uploaded")]
    if not pending:
        logger.info("All checkpoints already uploaded.")
        return

    logger.info("Uploading %d pending checkpoint(s)...", len(pending))
    for entry in pending:
        _upload_single(entry)
        if not entry.get("uploaded"):
            continue
        # Record the success before deleting anything: if the run is
        # interrupted later, the manifest must not list a removed file as
        # still pending.
        _persist()
        if clean:
            Path(entry["path"]).unlink(missing_ok=True)

    _persist()

    still_pending = sum(1 for e in manifest if not e.get("uploaded"))
    if still_pending:
        logger.warning("%d checkpoint(s) still failed to upload.", still_pending)
    else:
        logger.info("All checkpoints uploaded successfully.")
ryan-williams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| import logging | ||
|
|
||
| import torch | ||
|
|
||
|
|
@@ -23,6 +24,7 @@ def main() -> None: | |
| RuntimeError | ||
| if no command was input | ||
| """ | ||
| logging.basicConfig(level=logging.INFO) | ||
| parser = argparse.ArgumentParser(description="Electrai Entry Point") | ||
| subparsers = parser.add_subparsers(dest="command", required=True) | ||
|
|
||
|
|
@@ -32,14 +34,28 @@ def main() -> None: | |
| test_parser = subparsers.add_parser("test", help="Evaluate the model") | ||
| test_parser.add_argument("--config", type=str, required=True) | ||
|
|
||
| hf_push_parser = subparsers.add_parser( | ||
| "hf-push", help="Upload pending checkpoints to HuggingFace Hub" | ||
| ) | ||
| hf_push_parser.add_argument( | ||
| "--ckpt-path", type=str, required=True, help="Path to checkpoint directory" | ||
| ) | ||
| hf_push_parser.add_argument( | ||
| "--clean", | ||
| action="store_true", | ||
| help="Delete local checkpoint files after successful upload (includes best-model checkpoints)", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| if args.command == "train": | ||
| train(args) | ||
| elif args.command == "test": | ||
| test(args) | ||
| else: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removed because argparse with |
||
| raise ValueError(f"Unknown command: {args.command}") | ||
| elif args.command == "hf-push": | ||
| from electrai.callbacks.hf_upload import hf_push | ||
|
|
||
| hf_push(args.ckpt_path, clean=args.clean) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.