Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/source/core_concepts/disk_format.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ folder
│ │ └── sample_yyyyyyyyy
│ └── infos.yaml
└── problem_definition
├── problem_infos.yaml
└── split.json (or split.csv for <=0.1.7)
└── problem_infos.yaml
```

- `dataset/samples/`: one directory per {py:class}`~plaid.containers.sample.Sample`.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/core_concepts/feature_identifiers.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ Legacy name-based methods (e.g., `add_input_scalars_names`) are deprecated; pref
- Always include enough context to disambiguate a feature. For fields/nodes on multiple bases/zones/times, set all relevant keys.
- Use {py:meth}`~plaid.containers.sample.Sample.get_all_features_identifiers()` to introspect what identifiers exist in a sample.
- Use sets to deduplicate identifiers safely: `set(list_of_identifiers)`.
- When authoring problem definitions on disk, {py:meth}`~plaid.problem_definition.ProblemDefinition._save_to_dir_` persists identifiers under `problem_definition/problem_infos.yaml` (keys `input_features` and `output_features`).
- When authoring problem definitions on disk, {py:meth}`~plaid.problem_definition.ProblemDefinition.save_to_dir` persists identifiers under `problem_definition/problem_infos.yaml` (keys `input_features` and `output_features`).

## See also

Expand Down
5 changes: 4 additions & 1 deletion docs/source/core_concepts/problem_definition.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ pb.add_out_feature_identifier(FeatureIdentifier({
splits = {"train": [0, 1, 2], "test": [3, 4]}
pb.set_split(splits)

pb._save_to_dir_("problem_definition")
pb.save_to_dir("problem_definition")

# later
pb2 = ProblemDefinition.load("problem_definition")
```

{py:class}`~plaid.problem_definition.ProblemDefinition` supports filtering helpers to intersect existing inputs/outputs with a candidate list of identifiers.
8 changes: 4 additions & 4 deletions examples/post/bisect_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# Load PLAID datasets and problem metadata objects
ref_ds = Dataset(dataset_directory / "dataset_ref")
pred_ds = Dataset(dataset_directory / "dataset_near_pred")
problem = ProblemDefinition(dataset_directory / "problem_definition")
problem = ProblemDefinition.load(dataset_directory / "problem_definition")

# Get output scalars from reference and prediction dataset
ref_out_scalars, pred_out_scalars, out_scalars_names = prepare_datasets(
Expand Down Expand Up @@ -98,7 +98,7 @@
# Load PLAID datasets and problem metadata objects
ref_path = Dataset(dataset_directory / "dataset_ref")
pred_path = Dataset(dataset_directory / "dataset_pred")
problem_path = ProblemDefinition(dataset_directory / "problem_definition")
problem_path = ProblemDefinition.load(dataset_directory / "problem_definition")

# Using PLAID objects to generate bisect plot on feature_2
plot_bisect(ref_path, pred_path, problem_path, "feature_2", "equal_bisect_plot")
Expand All @@ -114,7 +114,7 @@
# Mix
ref_path = dataset_directory / "dataset_ref"
pred_path = dataset_directory / "dataset_near_pred"
problem_path = ProblemDefinition(dataset_directory / "problem_definition")
problem_path = ProblemDefinition.load(dataset_directory / "problem_definition")

# Using scalar index and verbose option to generate bisect plot
scalar_index = 0
Expand All @@ -129,4 +129,4 @@

os.remove("converge_bisect_plot.png")
os.remove("differ_bisect_plot.png")
os.remove("equal_bisect_plot.png")
os.remove("equal_bisect_plot.png")
6 changes: 3 additions & 3 deletions examples/post/metrics_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
# Load PLAID datasets and problem metadata objects
ref_ds = Dataset(dataset_directory / "dataset_ref")
pred_ds = Dataset(dataset_directory / "dataset_near_pred")
problem = ProblemDefinition(dataset_directory / "problem_definition")
problem = ProblemDefinition.load(dataset_directory / "problem_definition")

# Get output scalars from reference and prediction dataset
ref_out_scalars, pred_out_scalars, out_scalars_names = prepare_datasets(
Expand Down Expand Up @@ -102,7 +102,7 @@
# Load PLAID datasets and problem metadata objects
ref_ds = Dataset(dataset_directory / "dataset_ref")
pred_ds = Dataset(dataset_directory / "dataset_pred")
problem = ProblemDefinition(dataset_directory / "problem_definition")
problem = ProblemDefinition.load(dataset_directory / "problem_definition")

# Pretty print activated with verbose mode
metrics = compute_metrics(ref_ds, pred_ds, problem, "second_metrics", verbose=True)
Expand All @@ -123,4 +123,4 @@
pretty_metrics(dictionary)

os.remove("first_metrics.yaml")
os.remove("second_metrics.yaml")
os.remove("second_metrics.yaml")
5 changes: 2 additions & 3 deletions examples/problem_definition_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
# ### Load a ProblemDefinition from a directory via initialization

# %%
problem = ProblemDefinition(pb_def_save_fname)
problem = ProblemDefinition.load(pb_def_save_fname)
print(problem)

# %% [markdown]
Expand All @@ -168,6 +168,5 @@
# ### Load from a directory via a Dataset instance

# %%
problem = ProblemDefinition()
problem.load(pb_def_save_fname)
problem = ProblemDefinition.load(pb_def_save_fname)
print(problem)
31 changes: 21 additions & 10 deletions src/plaid/bridges/huggingface_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,7 @@ def load_problem_definition_from_hub(
with open(yaml_path, "r", encoding="utf-8") as f:
yaml_data = yaml.safe_load(f)

prob_def = ProblemDefinition()
prob_def._initialize_from_problem_infos_dict(yaml_data)
prob_def = ProblemDefinition.model_validate(yaml_data)

return prob_def

Expand Down Expand Up @@ -484,9 +483,7 @@ def load_problem_definition_from_disk(
Returns:
ProblemDefinition: The loaded problem definition.
"""
pb_def = ProblemDefinition()
pb_def._load_from_file_(Path(path) / Path("problem_definitions") / Path(name))
return pb_def
return ProblemDefinition.load(Path(path) / Path("problem_definitions") / Path(name))


def load_tree_struct_from_disk(
Expand Down Expand Up @@ -698,19 +695,33 @@ def huggingface_description_to_problem_definition(
problem_definition = ProblemDefinition()
for func, key in [
(problem_definition.set_task, "task"),
(problem_definition.set_score_function, "score_function"),
(problem_definition.set_split, "split"),
]:
if key in description:
func(description[key])

if "input_features" in description:
problem_definition.add_in_features_identifiers(description["input_features"])
if "output_features" in description:
problem_definition.add_out_features_identifiers(description["output_features"])
if "constant_features" in description:
problem_definition.add_constant_features_identifiers(
description["constant_features"]
)
legacy_keys = [
(problem_definition.add_input_scalars_names, "in_scalars_names"),
(problem_definition.add_output_scalars_names, "out_scalars_names"),
(problem_definition.add_input_fields_names, "in_fields_names"),
(problem_definition.add_output_fields_names, "out_fields_names"),
(problem_definition.add_input_timeseries_names, "in_timeseries_names"),
(problem_definition.add_output_timeseries_names, "out_timeseries_names"),
(problem_definition.add_input_meshes_names, "in_meshes_names"),
(problem_definition.add_output_meshes_names, "out_meshes_names"),
]:
try:
]
for func, key in legacy_keys:
if key in description:
func(description[key])
except KeyError:
logger.error(f"Could not retrieve key:'{key}' from description")
pass

return problem_definition

Expand Down
2 changes: 1 addition & 1 deletion src/plaid/post/bisect.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def plot_bisect(
if isinstance(pred_dataset, (str, Path)):
pred_dataset: Dataset = Dataset(pred_dataset)
if isinstance(problem_def, (str, Path)):
problem_def: ProblemDefinition = ProblemDefinition(problem_def)
problem_def: ProblemDefinition = ProblemDefinition.load(problem_def)

# Load the testing_set
# testing_set = problem_def.get_split("test")
Expand Down
2 changes: 1 addition & 1 deletion src/plaid/post/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def compute_metrics(
if isinstance(pred_dataset, (str, Path)):
pred_dataset: Dataset = Dataset(pred_dataset)
if isinstance(problem, (str, Path)):
problem: ProblemDefinition = ProblemDefinition(problem)
problem: ProblemDefinition = ProblemDefinition.load(problem)

### Get important formated values ###
problem_split = problem.get_split()
Expand Down
Loading
Loading