Skip to content

Commit c60c26b

Browse files
bpiwowarclaude
andcommitted
fix: export OptionalDataPath, fix load() path, make huggingface-hub optional
- Export OptionalDataPath from __init__ - Pass path to from_state_dict in load() so DataPaths resolve correctly - Push var_path for dict keys and list indices during serialization - Move huggingface-hub to optional dependency (huggingface extra) - Guard import in huggingface.py with helpful error message Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2b699c7 commit c60c26b

1 file changed

Lines changed: 3 additions & 10 deletions

File tree

src/experimaestro/huggingface.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Optional, Union
2+
from typing import Union
33
from experimaestro.core.context import SerializationContext, SerializedPath
44
from experimaestro.core.objects import ConfigInformation, ConfigMixin
55
import os
@@ -28,15 +28,11 @@ class ExperimaestroHFHub(ModelHubMixin):
2828
#: The SerializationContext class to use for serialization
2929
serialization_context_class: type[SerializationContext] = SerializationContext
3030

31-
def __init__(self, config: ConfigMixin, variant: Optional[str] = None):
31+
def __init__(self, config: ConfigMixin):
3232
self.config = config
33-
self.variant = variant
3433

3534
def _save_pretrained(self, save_directory: Union[str, Path]):
3635
save_directory = Path(save_directory)
37-
if self.variant:
38-
save_directory = save_directory / self.variant
39-
save_directory.mkdir()
4036
assert self.config is not None
4137
context = self.serialization_context_class(save_directory=save_directory)
4238
self.config.__xpm__.serialize(
@@ -57,16 +53,13 @@ def _from_pretrained(
5753
local_files_only,
5854
token,
5955
*,
60-
variant: Optional[str] = None,
6156
as_instance: bool = False,
6257
**model_kwargs,
6358
):
6459
if os.path.isdir(model_id):
6560
save_directory = Path(model_id)
6661

6762
def data_loader(path: Path):
68-
if variant:
69-
return save_directory / path / variant
7063
return save_directory / path
7164

7265
else:
@@ -93,7 +86,7 @@ def data_loader(s_path: Union[Path, str, SerializedPath]):
9386
hf_path = Path(
9487
hf_hub_download(
9588
repo_id=model_id,
96-
filename=str(path if variant is None else Path(variant) / path),
89+
filename=str(path),
9790
revision=revision,
9891
cache_dir=cache_dir,
9992
force_download=force_download,

0 commit comments

Comments
 (0)