diff --git a/environment.yml b/environment.yml index 051ad2ea..5be05c2e 100644 --- a/environment.yml +++ b/environment.yml @@ -12,6 +12,7 @@ dependencies: #---# storage backends - pycgns-core - zarr + - webdataset #---# SciML - scikit-learn - datasets @@ -41,4 +42,4 @@ dependencies: - myst-nb - myst-parser - furo - - jupytext \ No newline at end of file + - jupytext diff --git a/fix_none_features_plan.md b/fix_none_features_plan.md new file mode 100644 index 00000000..17bb4098 --- /dev/null +++ b/fix_none_features_plan.md @@ -0,0 +1,123 @@ +# Plan de Travail: Résolution du Problème des Features None dans WebDataset + +## Problème Identifié + +**Symptôme:** `AssertionError: did you forget to specify the features arg?` dans `flat_dict_to_sample_dict` + +**Cause Racine:** +Lorsque `flat_cst` contient `feature_times=None` sans la `feature` correspondante, le merge avec `var_sample_dict` crée un dictionnaire déséquilibré: +- `row_tim` contient la clé extraite de `feature_times` +- `row_val` ne contient pas cette clé +- Le `zip(row_tim.items(), row_val.items())` pair des clés différentes + +**Exemple:** +```python +flat_dict = { + 'Global/global_1': [1,2,3], + 'Global/global_1_times': [0,0,-1], + 'Global/global_2_times': None, # ← Orpheline +} +# Après _split_dict: +# row_val = {'Global/global_1'} +# row_tim = {'Global/global_1', 'Global/global_2'} ← Déséquilibre! 
+``` + +## Pistes de Solution + +### Piste 1: Filtrage dans bridge.py (to_var_sample_dict) +**Approche:** Ajouter automatiquement `feature=None` quand `feature_times` existe + +**Avantages:** +- Fix localisé dans le bridge WebDataset +- Pas de changement aux autres backends + +**Inconvénients:** +- Ajoute artificiellement des clés None +- Peut créer confusion sur ce qui existe réellement + +**Implémentation:** +```python +if features is None: + result = dict(wds_sample) + # Pour chaque _times, s'assurer que la feature de base existe + for key in list(result.keys()): + if key.endswith("_times"): + base_feat = key[:-6] + if base_feat not in result: + result[base_feat] = None + return result +``` + +### Piste 2: Nettoyage dans Converter (reader.py) +**Approche:** Filtrer `flat_cst` avant le merge pour retirer les `_times` orphelines + +**Avantages:** +- Nettoie les données avant utilisation +- Fix applicable à tous les backends + +**Inconvénients:** +- Modifie le Converter (code partagé) +- Peut affecter autres backends si mal fait + +**Implémentation:** +```python +# Dans Converter.to_dict(), avant l'appel à to_sample_dict: +clean_flat_cst = {} +for key, val in self.flat_cst.items(): + if key.endswith("_times"): + base_key = key[:-6] + # Only keep _times if base feature will be in var_sample_dict + if base_key in self.variable_schema or val is not None: + clean_flat_cst[key] = val + else: + clean_flat_cst[key] = val +``` + +### Piste 3: Modification du preprocessing +**Approche:** Empêcher la création de `_times` orphelines dans flat_cst dès le preprocessing + +**Avantages:** +- Fix à la source du problème +- Données cohérentes dès la génération + +**Inconvénients:** +- Modification du preprocessing (code complexe) +- Risque d'impact sur autres backends + +**Implémentation:** +Modifier `preprocess_splits` dans `src/plaid/storage/common/preprocessor.py` pour exclure les `_times` des features constantes None du constant_schema. 
+ +### Piste 4: Relaxer l'assertion dans flat_dict_to_sample_dict +**Approche:** Rendre `flat_dict_to_sample_dict` plus tolérant aux clés manquantes + +**Avantages:** +- Fix générique pour tous les backends +- Robustesse accrue du code + +**Inconvénients:** +- Modifie comportement existant +- Peut masquer d'autres bugs + +**Implémentation:** +```python +# Au lieu de zip strict, itérer sur row_tim et chercher dans row_val +for path_t, times_struc in row_tim.items(): + val = row_val.get(path_t, None) + # Traiter val même si None... +``` + +## Recommandation + +**Piste 1** est la plus simple et localisée. Elle résout le problème directement dans le bridge WebDataset sans affecter le reste du système. + +## Implémentation de la Piste 1 + +Modifier `src/plaid/storage/webdataset/bridge.py` pour s'assurer que toutes les `_times` ont leur feature de base correspondante, même si None. + +## Tests de Validation + +Après fix, vérifier: +1. `pytest tests/storage/test_storage.py::Test_Storage::test_webdataset` passe +2. `pytest tests/storage/test_storage.py::Test_Storage::test_registry` passe toujours +3. Aucune régression sur autres backends +4. Pre-commit hooks passent diff --git a/implementation_plan.md b/implementation_plan.md new file mode 100644 index 00000000..491d15d5 --- /dev/null +++ b/implementation_plan.md @@ -0,0 +1,521 @@ +# Implementation Plan: WebDataset Storage Backend for PLAID + +## [Overview] + +Implementation of a WebDataset storage backend for the PLAID library to provide tar-based, streaming-friendly dataset storage with seamless HuggingFace Hub integration. + +The WebDataset backend will be added as the fourth storage option alongside the existing cgns, hf_datasets, and zarr backends. WebDataset uses tar archives where samples sharing the same basename (stripped of extensions) belong together, making it ideal for streaming large physics datasets. 
This format aligns well with PLAID's architecture and provides efficient I/O for both local and cloud storage. + +The implementation follows the established backend pattern in PLAID: a module directory (`src/plaid/storage/webdataset/`) containing reader, writer, and bridge components, registered in the central registry, with full test coverage and documentation. The backend will support all standard PLAID operations: local disk save/load, HuggingFace Hub push/download, streaming access, and bidirectional conversion between WebDataset format and PLAID Sample objects. + +Key design considerations: +- Each PLAID sample becomes a set of files in a tar archive with shared basename (e.g., `sample_000000000.json`, `sample_000000000.npy`) +- Variable features stored as individual .npy files for efficient array storage +- Metadata and constant features stored in JSON format +- Split-based tar sharding for scalability +- Compatible with HuggingFace Hub's tar file hosting +- Streaming support via webdataset library's pipeline architecture + +## [Types] + +Type system changes to support WebDataset format and integration with PLAID's type system. 
+ +**New Type Definitions:** + +```python +# In src/plaid/storage/webdataset/reader.py +from typing import Iterator, Iterable, Any, Optional, Union +from pathlib import Path +import webdataset as wds + +# WebDataset pipeline type (returned by wds.WebDataset) +WebDatasetPipeline = wds.WebDataset + +# Sample dictionary type for WebDataset +WebDatasetSample = dict[str, Any] # Keys: feature paths, Values: numpy arrays or None +``` + +**Modified Type Definitions:** + +```python +# In src/plaid/storage/registry.py - BackendSpec already supports Optional callables +# No changes needed to BackendSpec dataclass definition +``` + +**Type Annotations:** + +All new functions will use complete type annotations following the existing codebase patterns: +- `Union[str, Path]` for file paths +- `dict[str, Any]` for sample dictionaries +- `Optional[...]` for optional parameters +- `Callable[..., Generator[Sample, None, None]]` for generator functions +- `Iterator[dict[str, Any]]` for WebDataset iteration + +## [Files] + +File modifications required to implement the WebDataset backend. + +**New Files to Create:** + +1. `src/plaid/storage/webdataset/__init__.py` + - Purpose: Package initialization and public API exports + - Exports: `init_datasetdict_from_disk`, `download_datasetdict_from_hub`, `init_datasetdict_streaming_from_hub`, `generate_datasetdict_to_disk`, `push_local_datasetdict_to_hub`, `configure_dataset_card`, `to_var_sample_dict`, `sample_to_var_sample_dict` + +2. `src/plaid/storage/webdataset/reader.py` + - Purpose: Dataset loading and streaming functionality + - Key components: `WebDatasetDict` class (wrapper for split-based access), `init_datasetdict_from_disk`, `download_datasetdict_from_hub`, `init_datasetdict_streaming_from_hub`, helper functions for tar file iteration + +3. 
`src/plaid/storage/webdataset/writer.py` + - Purpose: Dataset generation and Hub upload + - Key components: `generate_datasetdict_to_disk`, `push_local_datasetdict_to_hub`, `configure_dataset_card`, tar archive creation logic, sample serialization + +4. `src/plaid/storage/webdataset/bridge.py` + - Purpose: Conversion between WebDataset format and PLAID samples + - Key components: `to_var_sample_dict`, `sample_to_var_sample_dict`, feature extraction helpers + +5. `tests/storage/test_webdataset.py` + - Purpose: Unit tests for WebDataset backend + - Test coverage: reader/writer functionality, conversion operations, edge cases, integration with registry + +6. `docs/source/core_concepts/webdataset_backend.md` + - Purpose: Documentation for WebDataset backend usage + - Content: Format specification, usage examples, performance characteristics, comparison with other backends + +**Existing Files to Modify:** + +1. `src/plaid/storage/registry.py` + - Line ~70: Add "webdataset" entry to BACKENDS dict + - Specify all required backend functions following the pattern of zarr/hf_datasets entries + +2. `src/plaid/storage/__init__.py` + - No changes needed (uses registry dynamically) + +3. `pyproject.toml` + - Line ~37 (dependencies section): Add `"webdataset"` to dependencies list + +4. `docs/source/tutorials/storage.md` + - Add WebDataset backend to the list of available backends + - Include WebDataset in the example loops: `all_backends = ["hf_datasets", "cgns", "zarr", "webdataset"]` + +5. `tests/storage/test_storage.py` + - Add test method: `test_webdataset` following the pattern of `test_zarr` + - Add webdataset to registry test assertions (line ~230) + +**Configuration Files:** + +1. `.pre-commit-config.yaml` - No changes needed (linting applies automatically) +2. `ruff.toml` - No changes needed +3. `pyrightconfig.json` - No changes needed + +## [Functions] + +Function-level implementation details for the WebDataset backend. 
+ +**New Functions in `src/plaid/storage/webdataset/reader.py`:** + +1. `init_datasetdict_from_disk(path: Union[str, Path]) -> dict[str, WebDatasetDict]` + - Purpose: Load WebDataset from local tar files + - Returns: Dictionary mapping split names to WebDatasetDict objects + - Logic: Scan for tar files in `path/data/`, create WebDatasetDict wrappers + +2. `download_datasetdict_from_hub(repo_id: str, local_dir: Union[str, Path], split_ids: Optional[dict[str, list[int]]] = None, features: Optional[list[str]] = None, overwrite: bool = False) -> None` + - Purpose: Download WebDataset from HuggingFace Hub + - Uses: `snapshot_download` from huggingface_hub + - Logic: Download tar files with filtering patterns if split_ids/features specified + +3. `init_datasetdict_streaming_from_hub(repo_id: str, split_ids: Optional[dict[str, list[int]]] = None, features: Optional[list[str]] = None) -> dict[str, wds.WebDataset]` + - Purpose: Create streaming dataset from Hub + - Returns: Dictionary of split names to streaming WebDataset pipelines + - Logic: Use wds.WebDataset with Hub URLs, apply filters for split_ids/features + +4. `_create_webdataset_pipeline(tar_path: str, features: Optional[list[str]] = None) -> wds.WebDataset` + - Purpose: Helper to create WebDataset pipeline with decoding + - Returns: Configured WebDataset pipeline + - Logic: Set up .decode(), .to_tuple(), .map() operations + +**New Functions in `src/plaid/storage/webdataset/writer.py`:** + +1. `generate_datasetdict_to_disk(output_folder: Union[str, Path], generators: dict[str, Callable], variable_schema: dict, gen_kwargs: Optional[dict] = None, num_proc: int = 1, verbose: bool = False) -> None` + - Purpose: Generate and save WebDataset tar files from sample generators + - Logic: Iterate samples, serialize to numpy/json, write to tar archives + - Supports: Both sequential and parallel (multiprocess with sharding) modes + +2. 
`push_local_datasetdict_to_hub(repo_id: str, local_dir: Union[str, Path], num_workers: int = 1) -> None` + - Purpose: Upload local WebDataset to HuggingFace Hub + - Uses: `HfApi.upload_large_folder` + - Logic: Upload tar files with appropriate patterns + +3. `configure_dataset_card(repo_id: str, infos: dict, local_dir: Optional[Union[str, Path]] = None, viewer: Optional[bool] = None, pretty_name: Optional[str] = None, dataset_long_description: Optional[str] = None, illustration_urls: Optional[list[str]] = None, arxiv_paper_urls: Optional[list[str]] = None) -> None` + - Purpose: Create and push dataset card to Hub + - Logic: Generate README.md with metadata, usage examples, format description + +4. `_write_sample_to_tar(tar_writer: wds.TarWriter, sample: Sample, var_features_keys: list[str], sample_idx: int) -> None` + - Purpose: Helper to serialize one sample to tar + - Logic: Convert sample to dict, write .npy files for arrays, .json for metadata + +**New Functions in `src/plaid/storage/webdataset/bridge.py`:** + +1. `to_var_sample_dict(wds_sample: dict[str, Any], idx: int, features: Optional[list[str]]) -> dict[str, Any]` + - Purpose: Extract variable features from WebDataset sample + - Returns: Dictionary of feature paths to values + - Logic: Filter and return requested features from sample dict + +2. `sample_to_var_sample_dict(wds_sample: dict[str, Any]) -> dict[str, Any]` + - Purpose: Convert raw WebDataset sample to variable sample dict + - Returns: Processed sample dictionary + - Logic: Pass through (WebDataset samples are already in correct format) + +3. `_decode_sample(sample: dict[str, bytes]) -> dict[str, Any]` + - Purpose: Helper to decode WebDataset sample bytes to numpy arrays + - Logic: Deserialize .npy files, parse .json metadata + +**Modified Functions:** + +None - all integration is through the registry system, so no existing functions need modification. + +## [Classes] + +Class definitions and modifications for WebDataset backend support. 
+ +**New Classes:** + +1. `WebDatasetDict` (in `src/plaid/storage/webdataset/reader.py`) + - Purpose: Wrapper class for WebDataset tar archives providing dict-like split access + - Inherits: None (standalone class, similar to ZarrDataset) + - Attributes: + - `path: Union[str, Path]` - Path to dataset root + - `split_tar_paths: dict[str, Path]` - Mapping of split names to tar file paths + - `_extra_fields: dict[str, Any]` - Additional metadata + - Methods: + - `__init__(self, path: Union[str, Path], split_tar_paths: dict[str, Path], **kwargs)` + - `__getitem__(self, split: str) -> wds.WebDataset` - Return WebDataset for a split + - `__len__(self) -> int` - Return number of splits + - `__iter__(self) -> Iterator[tuple[str, wds.WebDataset]]` - Iterate over splits + - `__getattr__(self, name: str) -> Any` - Access extra fields + - `__setattr__(self, name: str, value: Any) -> None` - Set extra fields + - `__repr__(self) -> str` - String representation + - Purpose: Provides consistent interface matching ZarrDataset and CGNSDataset patterns + +2. `WebDatasetWrapper` (in `src/plaid/storage/webdataset/reader.py`) + - Purpose: Wrapper for individual WebDataset splits with indexing support + - Inherits: None + - Attributes: + - `wds_pipeline: wds.WebDataset` - Underlying WebDataset pipeline + - `path: Union[str, Path]` - Path to tar file + - `_cache: Optional[list]` - Optional cache for random access + - `ids: np.ndarray` - Array of sample IDs + - Methods: + - `__init__(self, tar_path: Union[str, Path], cache: bool = False)` + - `__getitem__(self, idx: int) -> dict[str, Any]` - Get sample by index + - `__len__(self) -> int` - Return number of samples + - `__iter__(self) -> Iterator[dict[str, Any]]` - Iterate over samples + - Purpose: Enable random access to WebDataset samples (required for PLAID's indexing pattern) + +**Modified Classes:** + +None - WebDataset backend integrates via the BackendSpec dataclass which already exists. 
+ +**Class Relationships:** + +``` +registry.BackendSpec + └─> Configured with webdataset functions + ├─> reader.init_datasetdict_from_disk → WebDatasetDict + ├─> reader.download_datasetdict_from_hub → None + ├─> reader.init_datasetdict_streaming_from_hub → dict[str, wds.WebDataset] + ├─> writer.generate_datasetdict_to_disk → None + ├─> writer.push_local_datasetdict_to_hub → None + ├─> writer.configure_dataset_card → None + ├─> bridge.to_var_sample_dict → dict[str, Any] + └─> bridge.sample_to_var_sample_dict → dict[str, Any] +``` + +## [Dependencies] + +Dependency additions and version requirements for WebDataset backend. + +**New Dependencies:** + +1. `webdataset` (PyPI package) + - Version requirement: `>=0.2.0` (stable release with core features) + - Reason: Core library for WebDataset format handling + - Features used: + - `wds.TarWriter` for tar archive creation + - `wds.WebDataset` for reading and streaming + - `.decode()`, `.to_tuple()`, `.map()` pipeline operations + - Add to `pyproject.toml` dependencies list + +**Existing Dependencies (no changes):** + +- `huggingface_hub` - Already present, used for Hub integration +- `numpy` - Already present, used for array serialization +- `pyyaml` - Already present, used for metadata +- `tqdm` - Already present, used for progress bars + +**Installation:** + +Add to `pyproject.toml` line ~37: +```python +dependencies = [ + "tqdm", + "pyyaml", + "pycgns", + "zarr", + "scikit-learn", + "datasets", + "numpy", + "matplotlib", + "pydantic", + "webdataset>=0.2.0", # ADD THIS LINE +] +``` + +**Compatibility:** + +- Python 3.11-3.13: webdataset supports these versions +- No conflicts with existing dependencies +- Optional GPU acceleration not required (webdataset is pure Python for basic operations) + +## [Testing] + +Comprehensive testing strategy for WebDataset backend implementation. + +**New Test Files:** + +1. 
`tests/storage/test_webdataset.py` + - Purpose: WebDataset-specific unit tests + - Structure: + ```python + class TestWebDataset: + def test_write_and_read_local(self, tmp_path, generator_split, infos, problem_definition) + def test_sample_iteration(self, tmp_path, generator_split, infos, problem_definition) + def test_converter_operations(self, tmp_path, generator_split, infos, problem_definition) + def test_feature_filtering(self, tmp_path, generator_split, infos, problem_definition) + def test_webdataset_dict_class(self, tmp_path, generator_split, infos, problem_definition) + def test_parallel_generation(self, tmp_path, generator_split_with_kwargs, gen_kwargs, infos, problem_definition) + ``` + - Coverage targets: >90% line coverage for webdataset module + +**Existing Test File Modifications:** + +1. `tests/storage/test_storage.py` + - Add method: `test_webdataset(self, tmp_path, generator_split, infos, problem_definition)` + - Location: After `test_cgns` method (around line 220) + - Content: Following the pattern of `test_zarr`, test basic operations: + - save_to_disk with webdataset backend + - init_from_disk and sample conversion + - plaid_sample reconstruction and feature access + - converter.to_dict and converter.sample_to_dict operations + +2. 
`tests/storage/test_storage.py` - Registry test + - Line ~240: Add to registry test: + ```python + assert "webdataset" in backends + webdataset_module = registry.get_backend("webdataset") + assert webdataset_module is not None + ``` + +**Test Fixtures:** + +Reuse existing fixtures from `tests/conftest.py` and `tests/storage/test_storage.py`: +- `samples` - Sample PLAID objects +- `infos` - Dataset metadata +- `problem_definition` - Problem definition object +- `dataset` - PLAID Dataset +- `main_splits` - Split configuration +- `generator_split` - Split-based generators +- `generator_split_with_kwargs` - Generators with kwargs for parallel processing +- `gen_kwargs` - Generator arguments for parallel mode + +**Test Coverage Requirements:** + +1. Core functionality: + - ✓ Generate dataset to disk (sequential and parallel) + - ✓ Load dataset from disk + - ✓ Iterate over samples + - ✓ Convert samples to PLAID format + - ✓ Feature extraction and filtering + +2. Edge cases: + - ✓ Empty datasets + - ✓ None values in features + - ✓ Large arrays + - ✓ Unicode strings + - ✓ Missing features + +3. Error handling: + - ✓ Invalid tar files + - ✓ Missing split files + - ✓ Corrupted data + - ✓ Feature key mismatches + +4. Integration: + - ✓ Registry integration + - ✓ Converter class operations + - ✓ Problem definition compatibility + +**Validation Strategy:** + +Run existing test suite with new backend to ensure no regressions: +```bash +pytest tests/storage/test_storage.py::Test_Storage::test_webdataset -v +pytest tests/storage/test_webdataset.py -v +pytest tests/storage/test_storage.py::Test_Storage::test_registry -v +``` + +## [Implementation Order] + +Step-by-step implementation sequence to minimize conflicts and ensure successful integration. + +**Phase 1: Foundation (Dependencies & Structure)** + +1. Update `pyproject.toml` + - Add `webdataset>=0.2.0` to dependencies + - Rationale: Required before any code can import webdataset + +2. 
Create directory structure + - Create `src/plaid/storage/webdataset/` directory + - Create empty `__init__.py`, `reader.py`, `writer.py`, `bridge.py` files + - Rationale: Establishes module structure for imports + +**Phase 2: Core Bridge Layer** + +3. Implement `src/plaid/storage/webdataset/bridge.py` + - Implement `to_var_sample_dict` function + - Implement `sample_to_var_sample_dict` function + - Add docstrings and type hints + - Rationale: Bridge functions are dependencies for reader/writer + +4. Implement `src/plaid/storage/common/bridge.py` helpers (if needed) + - No changes required (existing helpers sufficient) + - Rationale: Reuse existing flatten_path/unflatten_path utilities + +**Phase 3: Writer Implementation** + +5. Implement `src/plaid/storage/webdataset/writer.py` - Basic structure + - Implement `_write_sample_to_tar` helper function + - Implement `generate_datasetdict_to_disk` (sequential mode only) + - Rationale: Writing capability needed before testing read operations + +6. Implement `src/plaid/storage/webdataset/writer.py` - Advanced features + - Add parallel processing support to `generate_datasetdict_to_disk` + - Implement `push_local_datasetdict_to_hub` + - Implement `configure_dataset_card` + - Rationale: Complete write functionality before reader + +**Phase 4: Reader Implementation** + +7. Implement `src/plaid/storage/webdataset/reader.py` - Core classes + - Implement `WebDatasetWrapper` class + - Implement `WebDatasetDict` class + - Rationale: Dataset wrapper classes needed for consistent interface + +8. Implement `src/plaid/storage/webdataset/reader.py` - Load functions + - Implement `init_datasetdict_from_disk` + - Implement `_create_webdataset_pipeline` helper + - Rationale: Local loading enables testing without Hub dependency + +9. 
Implement `src/plaid/storage/webdataset/reader.py` - Hub functions + - Implement `download_datasetdict_from_hub` + - Implement `init_datasetdict_streaming_from_hub` + - Rationale: Hub integration completes reader functionality + +**Phase 5: Integration** + +10. Implement `src/plaid/storage/webdataset/__init__.py` + - Export all public functions + - Add module docstring + - Rationale: Establishes public API + +11. Update `src/plaid/storage/registry.py` + - Add webdataset BackendSpec to BACKENDS dict + - Import webdataset module + - Rationale: Register backend for system-wide availability + +**Phase 6: Testing** + +12. Create `tests/storage/test_webdataset.py` + - Implement core test class and methods + - Test write and read operations + - Test converter operations + - Rationale: Dedicated tests for WebDataset-specific functionality + +13. Update `tests/storage/test_storage.py` + - Add `test_webdataset` method + - Update registry test to include webdataset + - Rationale: Integration tests with existing test infrastructure + +14. Run full test suite + - Execute: `pytest tests/storage/ -v` + - Verify: No regressions in existing backends + - Rationale: Ensure system-wide compatibility + +**Phase 7: Documentation** + +15. Create `docs/source/core_concepts/webdataset_backend.md` + - Document WebDataset format specification + - Provide usage examples + - Compare with other backends + - Rationale: User-facing documentation + +16. Update `docs/source/tutorials/storage.md` + - Add "webdataset" to all_backends list + - Add WebDataset examples to tutorial + - Rationale: Integration into existing documentation + +**Phase 8: Code Quality** + +17. Run linting and formatting + - Execute: `ruff check src/plaid/storage/webdataset/` + - Execute: `ruff format src/plaid/storage/webdataset/` + - Fix any issues + - Rationale: Ensure code quality standards + +18. 
Run type checking + - Execute: `pyright src/plaid/storage/webdataset/` + - Fix any type errors + - Rationale: Ensure type safety + +19. Final validation + - Run complete test suite: `pytest tests/ -v` + - Check test coverage: `pytest tests/storage/ --cov=src/plaid/storage/webdataset` + - Verify: Coverage >90% + - Rationale: Final quality gate before completion + +**Critical Path Dependencies:** + +``` +1. pyproject.toml → 2. Directory structure → 3. Bridge layer + ↓ + 5-6. Writer implementation + ↓ + 7-9. Reader implementation + ↓ + 10-11. Integration + ↓ + 12-14. Testing + ↓ + 15-16. Documentation + ↓ + 17-19. Quality checks +``` + +**Estimated Implementation Time:** + +- Phase 1-2: 1 hour (setup) +- Phase 3: 2 hours (bridge layer) +- Phase 4: 4 hours (writer) +- Phase 5: 4 hours (reader) +- Phase 6: 2 hours (integration) +- Phase 7: 4 hours (testing) +- Phase 8: 2 hours (documentation) +- Phase 9: 1 hour (quality) + +**Total: ~20 hours** + +**Risk Mitigation:** + +- Test each phase independently before proceeding +- Use existing zarr/hf_datasets backends as reference implementations +- Create checkpoints after each major phase (commit to version control) +- If Hub integration issues arise, implement local-only first, then add Hub support \ No newline at end of file diff --git a/piste3_detailed_explanation.md b/piste3_detailed_explanation.md new file mode 100644 index 00000000..5c38e8fd --- /dev/null +++ b/piste3_detailed_explanation.md @@ -0,0 +1,172 @@ +# Piste 3: Modification du Preprocessing - Analyse Détaillée + +## Lignes Critiques Identifiées + +**Fichier:** `src/plaid/storage/common/preprocessor.py` +**Lignes:** 559-560 (dans la fonction `preprocess()`) + +```python +for split_name in split_flat_cst.keys(): + for path in var_features: + if not path.endswith("_times") and path not in split_all_paths[split_name]: + split_flat_cst[split_name][path + "_times"] = None # ← LIGNE PROBLÉMATIQUE + if path in split_flat_cst[split_name]: + 
split_flat_cst[split_name].pop(path) # pragma: no cover +``` + +## Explication du Code Actuel + +Cette boucle traite chaque feature variable (`var_features`) pour chaque split: + +1. **Ligne 559:** Si la feature n'est PAS une `_times` ET n'existe pas dans `split_all_paths` +2. **Ligne 560:** Ajouter `path_times=None` à `split_flat_cst` +3. **Ligne 561-562:** Si la feature de base existe dans flat_cst, la retirer (car elle est variable) + +**Intention Originale:** S'assurer que toutes les features variables ont une entrée `_times` dans flat_cst, même pour les splits où elles n'apparaissent pas. + +**Problème Créé:** Cela ajoute des `_times` orphelines (sans leur feature de base) dans `flat_cst`, ce qui cause le déséquilibre dans `_split_dict`. + +## Changements Nécessaires pour Piste 3 + +### Option 3A: Ne pas ajouter de _times orphelines + +**Modification:** +```python +for split_name in split_flat_cst.keys(): + for path in var_features: + # CHANGEMENT: Ne pas ajouter _times orphelines + # Commenté la ligne problématique: + # if not path.endswith("_times") and path not in split_all_paths[split_name]: + # split_flat_cst[split_name][path + "_times"] = None + + if path in split_flat_cst[split_name]: + split_flat_cst[split_name].pop(path) +``` + +**Impact:** +- ✅ Résout le problème pour WebDataset +- ⚠️ Change le comportement pour TOUS les backends +- ⚠️ Peut affecter zarr et hf_datasets si ils dépendent de ces _times orphelines +- ⚠️ Besoin de valider les 400 tests + +### Option 3B: Ajouter aussi la feature de base + +**Modification:** +```python +for split_name in split_flat_cst.keys(): + for path in var_features: + if not path.endswith("_times") and path not in split_all_paths[split_name]: + # CHANGEMENT: Ajouter AUSSI la feature de base, pas seulement _times + split_flat_cst[split_name][path] = None + split_flat_cst[split_name][path + "_times"] = None + if path in split_flat_cst[split_name]: + split_flat_cst[split_name].pop(path) +``` + +**Impact:** +- ✅ Crée des 
paires cohérentes feat/feat_times +- ⚠️ Ajoute des None artificiels dans flat_cst +- ⚠️ Change le comportement pour tous les backends +- ⚠️ Peut avoir effets de bord sur la reconstruction des samples + +### Option 3C: Nettoyage post-preprocessing + +**Modification:** +Ajouter une étape de nettoyage après `preprocess_splits()` mais avant le retour: + +```python +# À la fin de preprocess(), après construction des schemas: +# Nettoyer les _times orphelines dans split_flat_cst +for split_name in split_flat_cst.keys(): + orphan_times = [] + for key in split_flat_cst[split_name].keys(): + if key.endswith("_times"): + base_key = key[:-6] + # Si la feature de base n'est ni dans flat_cst ni dans variable_schema + if base_key not in split_flat_cst[split_name] and base_key not in variable_schema: + orphan_times.append(key) + + for key in orphan_times: + del split_flat_cst[split_name][key] +``` + +**Impact:** +- ✅ Nettoie après coup sans changer la logique principale +- ✅ Plus sûr pour les autres backends +- ⚠️ Ajoute complexité au preprocessing +- ⚠️ Peut masquer un problème de design plus profond + +## Analyse des Risques + +### Risques Globaux de la Piste 3 +1. **Code Partagé:** `preprocessor.py` est utilisé par TOUS les backends +2. **Comportement Établi:** Ce code existe depuis longtemps, peut-être avec raison +3. **Tests Indirects:** Modification peut casser des tests non-évidents +4. 
**Maintenance:** Complexifie le code de preprocessing déjà complexe + +### Tests à Valider si Piste 3 Implémentée +```bash +# Test complet pour détecter régressions +pytest tests/storage/test_storage.py -v + +# Tests des autres backends +pytest tests/storage/test_storage.py::Test_Storage::test_hf_datasets -xvs +pytest tests/storage/test_storage.py::Test_Storage::test_zarr -xvs +pytest tests/storage/test_storage.py::Test_Storage::test_cgns -xvs + +# Tests de conversion +pytest tests/bridges/test_huggingface_bridge.py -v +``` + +## Comparaison Piste 2 vs Piste 3 + +| Aspect | Piste 2 (Converter) | Piste 3 (Preprocessing) | +|--------|---------------------|-------------------------| +| **Localisation** | Converter.to_dict() | preprocess() | +| **Impact** | Webdataset uniquement | Tous les backends | +| **Risque** | Faible | Moyen-Élevé | +| **Complexité** | Simple (5-10 lignes) | Moyenne (15-25 lignes) | +| **Tests Required** | 2 tests | 400 tests | +| **Maintenance** | Isolé | Code critique partagé | + +## Recommandation Finale + +**NE PAS implémenter Piste 3** sans validation approfondie. Les risques sont trop élevés pour un problème qui peut être résolu avec Piste 2. + +**Piste 2 est préférable car:** +- Modification localisée dans Converter +- Peut être conditionnelle au backend (if self.backend == "webdataset") +- Pas d'impact sur les autres backends +- Plus facile à reverter si problème +- Plus facile à maintenir + +## Implémentation Alternative: Piste 2 Améliorée + +Si nécessaire, voici une implémentation plus robuste de la Piste 2: + +```python +# Dans Converter.to_dict(), après var_sample_dict = ... +# Et avant to_sample_dict(var_sample_dict, self.flat_cst, ...) 
+ +# Clean flat_cst for backends that don't store None features +if self.backend in ["webdataset", "zarr"]: + clean_flat_cst = {} + for key, val in self.flat_cst.items(): + if key.endswith("_times"): + base_key = key[:-6] + # Keep _times only if base feature is in variable_schema + # OR if it's in flat_cst with non-None value + if base_key in self.variable_schema or base_key in self.flat_cst: + clean_flat_cst[key] = val + else: + clean_flat_cst[key] = val + use_flat_cst = clean_flat_cst +else: + use_flat_cst = self.flat_cst + +return to_sample_dict(var_sample_dict, use_flat_cst, self.cgns_types, features) +``` + +## Conclusion + +La Piste 3 nécessiterait des changements au cœur du preprocessing, avec des risques significatifs pour tous les backends. **La Piste 2 est fortement recommandée** comme solution pragmatique et sûre. diff --git a/pyproject.toml b/pyproject.toml index 2e7a3a12..586860a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "numpy", "matplotlib", "pydantic", + "webdataset>=0.2.0", ] dynamic = ["version"] diff --git a/src/plaid/storage/reader.py b/src/plaid/storage/reader.py index d93610a2..23ca3d8e 100644 --- a/src/plaid/storage/reader.py +++ b/src/plaid/storage/reader.py @@ -21,6 +21,8 @@ from pathlib import Path from typing import Any, Iterable, Optional, Union +import numpy as np + from plaid import Sample from plaid.storage.common.bridge import ( plaid_to_sample_dict, @@ -118,7 +120,82 @@ def to_dict( var_sample_dict = self.backend_spec.to_var_sample_dict( dataset, idx, features=req_var_feat ) - return to_sample_dict(var_sample_dict, self.flat_cst, self.cgns_types, features) + + # PISTE 2 FIX: Clean flat_cst for backends that don't store None features + # Problem: flat_cst may contain orphan "_times" entries (e.g., "feature_times": None + # without "feature"), causing mismatch in _split_dict when merging with var_sample_dict + if self.backend in ["webdataset", "zarr"]: + # Create case-insensitive lookup sets 
for var_sample_dict + # (webdataset uses lowercase paths, flat_cst uses original case) + var_keys_lower = {k.lower() for k in var_sample_dict.keys()} + flat_keys_lower = {k.lower() for k in self.flat_cst.keys()} + + clean_flat_cst = {} + for key, val in self.flat_cst.items(): + if key.endswith("_times"): + base_key = key[:-6] + base_key_lower = base_key.lower() + # Keep _times only if the base feature is ALSO in flat_cst (i.e., it's a constant) + # If the base is in var_sample_dict (i.e., it's a variable), exclude its _times + # because the variable will handle its own timing + if ( + base_key_lower in flat_keys_lower + and base_key_lower not in var_keys_lower + ): + clean_flat_cst[key] = val + # else: skip orphan _times (or _times for variable features) + else: + clean_flat_cst[key] = val + + # Normalize case: Create a case mapping from var_sample_dict to use for flat_cst keys + # This ensures that when merged, keys will have consistent case for proper zip() alignment + case_map = {k.lower(): k for k in var_sample_dict.keys()} + normalized_flat_cst = {} + for key, val in clean_flat_cst.items(): + key_lower = key.lower() + # If this key (or its base for _times keys) exists in var_sample_dict, use its case + if key_lower in case_map: + normalized_flat_cst[case_map[key_lower]] = val + elif key.endswith("_times"): + base_lower = key_lower[:-6] + if base_lower in case_map: + # Use the var_sample_dict case for the base, then add _times + normalized_flat_cst[case_map[base_lower] + "_times"] = val + else: + # Keep original case if not in var_sample_dict + normalized_flat_cst[key] = val + else: + # Keep original case for keys not in var_sample_dict + normalized_flat_cst[key] = val + + # Add synthetic _times for variables that don't have them + # Variables from var_sample_dict need _times entries for _split_dict to work + merged_with_times = {} + for key, val in normalized_flat_cst.items(): + merged_with_times[key] = val + + for key, val in var_sample_dict.items(): + 
merged_with_times[key] = val + # If this is a variable and doesn't have a _times entry, create one + times_key = key + "_times" + if times_key not in merged_with_times and val is not None: + # Create synthetic _times: single time point at 0.0, covering whole array + # Format: [[time, start_idx, end_idx]] + if hasattr(val, "shape") and len(val.shape) > 0: + merged_with_times[times_key] = np.array( + [[0.0, 0, -1]], dtype=np.float64 + ) + else: + # Scalar or 0-d array + merged_with_times[times_key] = np.array( + [[0.0, 0, -1]], dtype=np.float64 + ) + + use_flat_cst = merged_with_times + else: + use_flat_cst = self.flat_cst + + return to_sample_dict(var_sample_dict, use_flat_cst, self.cgns_types, features) def to_plaid(self, dataset: Any, idx: int) -> Sample: """Convert a dataset sample to PLAID Sample object. diff --git a/src/plaid/storage/registry.py b/src/plaid/storage/registry.py index 657d7188..29b24996 100644 --- a/src/plaid/storage/registry.py +++ b/src/plaid/storage/registry.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from typing import Any, Callable, Optional -from . import cgns, hf_datasets, zarr +from . 
import cgns, hf_datasets, webdataset, zarr @dataclass(frozen=True) @@ -66,6 +66,17 @@ class BackendSpec: to_var_sample_dict=zarr.to_var_sample_dict, sample_to_var_sample_dict=zarr.sample_to_var_sample_dict, ), + "webdataset": BackendSpec( + name="webdataset", + init_from_disk=webdataset.init_datasetdict_from_disk, + download_from_hub=webdataset.download_datasetdict_from_hub, + init_streaming_from_hub=webdataset.init_datasetdict_streaming_from_hub, + generate_to_disk=webdataset.generate_datasetdict_to_disk, + push_local_to_hub=webdataset.push_local_datasetdict_to_hub, + configure_dataset_card=webdataset.configure_dataset_card, + to_var_sample_dict=webdataset.to_var_sample_dict, + sample_to_var_sample_dict=webdataset.sample_to_var_sample_dict, + ), } diff --git a/src/plaid/storage/webdataset/__init__.py b/src/plaid/storage/webdataset/__init__.py new file mode 100644 index 00000000..b8e98b35 --- /dev/null +++ b/src/plaid/storage/webdataset/__init__.py @@ -0,0 +1,56 @@ +"""WebDataset storage backend for PLAID. + +This module provides WebDataset format support for PLAID datasets, enabling tar-based +storage with streaming capabilities and Hugging Face Hub integration. + +The WebDataset backend uses tar archives where samples with the same basename belong +together (e.g., sample_000000000.json and sample_000000000.npy). This format is +ideal for streaming large physics datasets and has excellent compatibility with +Hugging Face Hub. 
+ +Public API: + - init_datasetdict_from_disk: Load dataset from local tar files + - download_datasetdict_from_hub: Download dataset from Hub + - init_datasetdict_streaming_from_hub: Stream dataset from Hub + - generate_datasetdict_to_disk: Generate and save dataset to tar archives + - push_local_datasetdict_to_hub: Upload dataset to Hub + - configure_dataset_card: Create and push dataset card + - to_var_sample_dict: Extract variable features from sample + - sample_to_var_sample_dict: Convert sample to variable sample dict +""" + +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +from plaid.storage.webdataset.bridge import ( + sample_to_var_sample_dict, + to_var_sample_dict, +) +from plaid.storage.webdataset.reader import ( + download_datasetdict_from_hub, + init_datasetdict_from_disk, + init_datasetdict_streaming_from_hub, +) +from plaid.storage.webdataset.writer import ( + configure_dataset_card, + generate_datasetdict_to_disk, + push_local_datasetdict_to_hub, +) + +__all__ = [ + # Reader functions + "init_datasetdict_from_disk", + "download_datasetdict_from_hub", + "init_datasetdict_streaming_from_hub", + # Writer functions + "generate_datasetdict_to_disk", + "push_local_datasetdict_to_hub", + "configure_dataset_card", + # Bridge functions + "to_var_sample_dict", + "sample_to_var_sample_dict", +] diff --git a/src/plaid/storage/webdataset/bridge.py b/src/plaid/storage/webdataset/bridge.py new file mode 100644 index 00000000..e3ae3470 --- /dev/null +++ b/src/plaid/storage/webdataset/bridge.py @@ -0,0 +1,69 @@ +"""WebDataset bridge utilities. + +This module provides utility functions for bridging between PLAID samples and WebDataset storage format. +It includes functions for sample data conversion and feature extraction. 
+""" + +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +from typing import Any, Optional + + +def to_var_sample_dict( + wds_dataset, idx: int, features: Optional[list[str]] +) -> dict[str, Any]: + """Extracts variable features from a WebDataset. + + Args: + wds_dataset: The WebDataset wrapper object. + idx: The sample index to extract. + features: Optional list of feature names to extract. If None, all features are returned. + + Returns: + dict[str, Any]: Dictionary of variable features for the sample. + """ + # Get the sample from the dataset + wds_sample = wds_dataset[idx] + + if features is None: + # Return only what's actually stored in the tar (non-None features) + # The Converter.to_dict() now cleans orphan _times from flat_cst + return wds_sample + + # Return requested features + # For features not in the sample, return None + # But for _times features, only return if base feature exists OR both are missing + result = {} + for feat in features: + if feat in wds_sample: + result[feat] = wds_sample[feat] + elif feat.endswith("_times"): + # For _times, only add if the base feature is also requested + base_feat = feat[:-6] + if base_feat in features: + # Both requested, return None for _times + result[feat] = None + else: + # Feature not in sample, return None + result[feat] = None + + return result + + +def sample_to_var_sample_dict(wds_sample: dict[str, Any]) -> dict[str, Any]: + """Converts a WebDataset sample to a variable sample dictionary. + + This is a pass-through function since WebDataset samples are already in the correct format. + + Args: + wds_sample: The raw WebDataset sample data. + + Returns: + dict[str, Any]: The processed variable sample dictionary (same as input). 
+ """ + return wds_sample diff --git a/src/plaid/storage/webdataset/reader.py b/src/plaid/storage/webdataset/reader.py new file mode 100644 index 00000000..1eca4108 --- /dev/null +++ b/src/plaid/storage/webdataset/reader.py @@ -0,0 +1,490 @@ +"""WebDataset reader module. + +This module provides functionality for reading and streaming datasets stored in WebDataset format +for the PLAID library. It includes utilities for loading datasets from local disk or +streaming directly from Hugging Face Hub, with support for selective loading of splits +and features. + +Key features: +- Local dataset loading from tar archives +- Streaming datasets from Hugging Face Hub +- Selective loading of splits and features +- WebDatasetWrapper class for convenient data access with indexing support +""" + +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +import io +import logging +import os +import shutil +from pathlib import Path +from typing import Any, Iterator, Optional, Union + +import numpy as np +import webdataset as wds +import yaml +from huggingface_hub import hf_hub_download, snapshot_download + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------ +# Classes +# ------------------------------------------------------ + + +class WebDatasetWrapper: + """A wrapper class for WebDataset providing indexing support. + + This class wraps a WebDataset tar archive and provides indexed access to samples, + which is required for PLAID's data access patterns. Since WebDataset is designed + for streaming, this wrapper caches samples on first iteration to enable random access. + """ + + def __init__( + self, tar_path: Union[str, Path], features: Optional[list[str]] = None + ) -> None: + """Initialize a WebDatasetWrapper. + + Args: + tar_path: Path to the tar file containing the dataset. + features: Optional list of features to load. 
If None, all features are loaded. + """ + self.tar_path = Path(tar_path) + self.features = features + self._cache = None + self._ids = None + + def _load_cache(self) -> None: + """Load all samples into cache for random access.""" + if self._cache is not None: + return + + self._cache = [] + self._ids = [] + + # Read tar directly to preserve case in filenames + import tarfile + + with tarfile.open(str(self.tar_path), "r") as tar: + # Group files by sample + samples_dict = {} + for member in tar.getmembers(): + if not member.isfile(): + continue + + # Extract sample ID and feature from filename + # Format: sample_XXXXXXXXX.Feature__Path.npy + parts = member.name.split(".", 1) + if len(parts) != 2: + continue + + sample_key = parts[0] + rest = parts[1] + + if not rest.endswith(".npy"): + continue + + if sample_key not in samples_dict: + samples_dict[sample_key] = {} + + samples_dict[sample_key][rest] = member + + # Process each sample in order + for sample_key in sorted(samples_dict.keys()): + sample_members = samples_dict[sample_key] + decoded_sample = {} + + for feature_key, member in sample_members.items(): + # Remove .npy extension and convert __ to / + feature_path = feature_key[:-4].replace("__", "/") + + # Read and decode the numpy array + file_obj = tar.extractfile(member) + if file_obj: + buffer = io.BytesIO(file_obj.read()) + array = np.load(buffer, allow_pickle=True) + decoded_sample[feature_path] = array + + # Filter features if specified + if self.features is not None: + decoded_sample = { + k: v for k, v in decoded_sample.items() if k in self.features + } + + self._cache.append(decoded_sample) + + # Extract sample index + if sample_key.startswith("sample_"): + idx = int(sample_key.split("_")[1]) + self._ids.append(idx) + + self._ids = np.array(self._ids, dtype=int) + + def _decode_sample(self, sample: dict[str, bytes]) -> dict[str, Any]: + """Decode a WebDataset sample from bytes to numpy arrays. 
+ + Args: + sample: Dictionary of extension -> bytes from tar archive. + + Returns: + dict[str, Any]: Decoded sample with feature paths as keys. + """ + decoded = {} + + for key, value in sample.items(): + # Skip __key__ metadata + if key == "__key__": + continue + + # Handle .npy files + # Format in dict: "feature__path.npy" -> bytes + if key.endswith(".npy"): + # Remove .npy extension to get feature path + feature_path = key[:-4] + # Convert __ back to / + feature_path = feature_path.replace("__", "/") + + # Decode numpy array + buffer = io.BytesIO(value) + array = np.load(buffer, allow_pickle=True) + decoded[feature_path] = array + + return decoded + + def __iter__(self) -> Iterator[dict[str, Any]]: + """Iterate over all samples in the dataset. + + Yields: + dict[str, Any]: Dictionary containing sample data. + """ + self._load_cache() + yield from self._cache + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Get a sample by index. + + Args: + idx: Sample index. + + Returns: + dict[str, Any]: Dictionary containing sample data. + """ + self._load_cache() + + # Find the position in cache + if idx not in self._ids: + raise IndexError(f"Sample index {idx} not found in dataset") + + position = np.where(self._ids == idx)[0][0] + return self._cache[position] + + def __len__(self) -> int: + """Get the number of samples in the dataset. + + Returns: + int: Number of samples. + """ + self._load_cache() + return len(self._cache) + + @property + def ids(self) -> np.ndarray: + """Get array of sample IDs in the dataset. + + Returns: + np.ndarray: Array of sample indices. + """ + self._load_cache() + return self._ids + + +class WebDatasetDict: + """A dataset dictionary class for WebDataset format. + + This class provides a dictionary-like interface to access multiple splits of a + WebDataset, similar to ZarrDataset pattern in PLAID. 
+ """ + + def __init__( + self, + path: Union[str, Path], + split_tar_paths: dict[str, Path], + features: Optional[list[str]] = None, + **kwargs, + ) -> None: + """Initialize a WebDatasetDict. + + Args: + path: Path to the dataset root directory. + split_tar_paths: Dictionary mapping split names to tar file paths. + features: Optional list of features to load. + **kwargs: Additional metadata to attach to the dataset. + """ + self.path = path + self.split_tar_paths = split_tar_paths + self.features = features + self._extra_fields = dict(kwargs) + self._splits = {} + + def __getitem__(self, split: str) -> WebDatasetWrapper: + """Get a split by name. + + Args: + split: Split name. + + Returns: + WebDatasetWrapper: Wrapper for the split's tar archive. + """ + if split not in self._splits: + if split not in self.split_tar_paths: + raise KeyError(f"Split '{split}' not found in dataset") + + self._splits[split] = WebDatasetWrapper( + self.split_tar_paths[split], self.features + ) + + return self._splits[split] + + def __len__(self) -> int: + """Get the number of splits. + + Returns: + int: Number of splits. + """ + return len(self.split_tar_paths) + + def __iter__(self) -> Iterator[tuple[str, WebDatasetWrapper]]: + """Iterate over splits. + + Yields: + tuple[str, WebDatasetWrapper]: (split_name, dataset_wrapper) pairs. + """ + for split_name in self.split_tar_paths.keys(): + yield split_name, self[split_name] + + def keys(self): + """Get split names. + + Returns: + Iterator of split names. + """ + return self.split_tar_paths.keys() + + def values(self): + """Get dataset wrappers. + + Yields: + WebDatasetWrapper instances. + """ + for split_name in self.split_tar_paths.keys(): + yield self[split_name] + + def items(self): + """Get (split_name, dataset_wrapper) pairs. + + Yields: + tuple[str, WebDatasetWrapper] pairs. + """ + return self.__iter__() + + def __getattr__(self, name: str) -> Any: + """Get attribute from extra fields. + + Args: + name: Attribute name. 
+
+        Returns:
+            Any: Attribute value.
+
+        Raises:
+            AttributeError: If attribute not found.
+        """
+        if name in self._extra_fields:
+            return self._extra_fields[name]
+        raise AttributeError(f"{type(self).__name__} has no attribute '{name}'")
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Set attribute in extra fields.
+
+        Args:
+            name: Attribute name.
+            value: Attribute value.
+        """
+        if name in ("path", "split_tar_paths", "features", "_extra_fields", "_splits"):
+            super().__setattr__(name, value)
+        else:
+            self._extra_fields[name] = value
+
+    def __repr__(self) -> str:
+        """String representation of the dataset.
+
+        Returns:
+            str: String representation.
+        """
+        splits = list(self.split_tar_paths.keys())
+        return f"<WebDatasetDict splits={splits}>"
+
+
+# ------------------------------------------------------
+# Load from disk
+# ------------------------------------------------------
+
+
+def init_datasetdict_from_disk(
+    path: Union[str, Path],
+) -> dict[str, WebDatasetWrapper]:
+    """Initializes dataset dictionaries from local WebDataset tar files.
+
+    Args:
+        path: Path to the local directory containing the dataset.
+
+    Returns:
+        dict[str, WebDatasetWrapper]: Dictionary mapping split names to WebDatasetWrapper objects. 
+ """ + local_path = Path(path) / "data" + + if not local_path.exists(): + raise ValueError(f"Data directory not found: {local_path}") + + # Find all tar files + tar_files = list(local_path.glob("*.tar")) + + if not tar_files: + raise ValueError(f"No tar files found in {local_path}") + + # Create split_tar_paths mapping + split_tar_paths = {f.stem: f for f in tar_files} + + # Create WebDatasetDict + dataset_dict = WebDatasetDict(path, split_tar_paths) + + # Return as plain dict for compatibility + return {split: dataset_dict[split] for split in split_tar_paths.keys()} + + +# ------------------------------------------------------ +# Load from Hub +# ------------------------------------------------------ + + +def download_datasetdict_from_hub( + repo_id: str, + local_dir: Union[str, Path], + split_ids: Optional[dict[str, list[int]]] = None, + features: Optional[list[str]] = None, + overwrite: bool = False, +) -> None: # pragma: no cover + """Downloads dataset from Hugging Face Hub to local directory. + + Args: + repo_id: The Hugging Face repository ID. + local_dir: Local directory to download to. + split_ids: Optional split IDs for selective download (not implemented for WebDataset). + features: Optional features for selective download (not implemented for WebDataset). + overwrite: Whether to overwrite existing directory. + + Returns: + None + """ + output_folder = Path(local_dir) + + if output_folder.is_dir(): + if overwrite: + shutil.rmtree(local_dir) + logger.warning(f"Existing {local_dir} directory has been reset.") + elif any(output_folder.iterdir()): + raise ValueError( + f"directory {local_dir} already exists and is not empty. " + "Set `overwrite` to True if needed." 
+ ) + + # Note: split_ids and features filtering not implemented for WebDataset + # These would require streaming and re-packing tar files + if split_ids is not None: + logger.warning( + "split_ids filtering not supported for WebDataset backend, " + "downloading full dataset" + ) + + if features is not None: + logger.warning( + "features filtering not supported for WebDataset backend, " + "downloading full dataset" + ) + + # Download tar files and metadata + allow_patterns = ["data/*.tar", "*.yaml", "*.yml", "*.json", "README.md"] + + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + allow_patterns=allow_patterns, + local_dir=local_dir, + ) + + +def init_datasetdict_streaming_from_hub( + repo_id: str, + split_ids: Optional[dict[str, list[int]]] = None, + features: Optional[list[str]] = None, # noqa: ARG001 +) -> dict[str, wds.WebDataset]: # pragma: no cover + """Initializes streaming dataset dictionaries from Hugging Face Hub. + + This function creates WebDataset pipelines that stream tar data directly from + the Hugging Face Hub without downloading files locally. + + Args: + repo_id: The Hugging Face repository ID. + split_ids: Optional dictionary mapping split names to sample IDs (not supported). + features: Optional list of feature names to include. + + Returns: + dict[str, wds.WebDataset]: Dictionary mapping split names to WebDataset pipelines. 
+ """ + hf_endpoint = os.getenv("HF_ENDPOINT", "").strip() + if hf_endpoint: + raise RuntimeError("Streaming mode not compatible with private mirror.") + + if split_ids is not None: + logger.warning( + "split_ids filtering not supported for WebDataset streaming, " + "loading all samples" + ) + + # Get list of splits from infos + yaml_path = hf_hub_download( + repo_id=repo_id, + filename="infos.yaml", + repo_type="dataset", + ) + with open(yaml_path, "r", encoding="utf-8") as f: + infos = yaml.safe_load(f) + + splits = list(infos.get("num_samples", {}).keys()) + + if not splits: + raise ValueError(f"No splits found in dataset {repo_id}") + + # Create streaming WebDataset for each split + base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/data" + + datasets = {} + for split in splits: + tar_url = f"{base_url}/{split}.tar" + + # Create WebDataset pipeline + dataset = wds.WebDataset(tar_url) + + # Add decoding if needed + # dataset = dataset.decode() + + datasets[split] = dataset + + return datasets diff --git a/src/plaid/storage/webdataset/writer.py b/src/plaid/storage/webdataset/writer.py new file mode 100644 index 00000000..4d804723 --- /dev/null +++ b/src/plaid/storage/webdataset/writer.py @@ -0,0 +1,523 @@ +"""WebDataset writer module. + +This module provides functionality for writing and managing datasets in WebDataset format +for the PLAID library. It includes utilities for generating datasets from sample +generators, saving them to tar archives, uploading to Hugging Face Hub, and configuring +dataset cards with metadata and usage examples. 
+ +Key features: +- Parallel and sequential dataset generation from generators +- Tar-based storage format for streaming compatibility +- Integration with Hugging Face Hub for dataset sharing +- Dataset card generation with splits, features, and documentation +""" + +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +import io +import multiprocessing as mp +from pathlib import Path +from typing import Callable, Generator, Optional, Union + +import numpy as np +import webdataset as wds +import yaml +from huggingface_hub import DatasetCard, HfApi +from tqdm import tqdm + +from plaid import Sample +from plaid.storage.common.preprocessor import build_sample_dict +from plaid.types import IndexType + + +def _write_sample_to_tar( + tar_writer: wds.TarWriter, + sample: Sample, + var_features_keys: list[str], + sample_idx: int, +) -> None: + """Write a single PLAID sample to a WebDataset tar archive. + + This function serializes one Sample instance into a set of files in a tar archive. + Each sample is written with a common basename (e.g., sample_000000000) and different + extensions for different data types: + - .npy files for numpy arrays (variable features) + - .json for metadata and non-array data + + Args: + tar_writer: WebDataset TarWriter instance for writing to tar. + sample: PLAID Sample object to serialize. + var_features_keys: List of feature paths to extract and write. + sample_idx: Global index of the sample for naming. 
+ """ + sample_dict, _, _ = build_sample_dict(sample) + sample_data = {path: sample_dict.get(path, None) for path in var_features_keys} + + # Create a dictionary to hold the sample + basename = f"sample_{sample_idx:09d}" + sample_files = {"__key__": basename} + + # Separate arrays and metadata + # Track which base features have values + features_with_values = set() + + for key, value in sample_data.items(): + # Skip _times keys for now, we'll handle them after + if key.endswith("_times"): + continue + + if value is None: + continue + + # Mark that this feature has a value + features_with_values.add(key) + + # Convert numpy arrays to bytes + if isinstance(value, np.ndarray): + # Save numpy array to bytes + buffer = io.BytesIO() + np.save(buffer, value) + buffer.seek(0) + # Use key as filename with .npy extension, replacing / with __ + safe_key = key.replace("/", "__") + sample_files[f"{safe_key}.npy"] = buffer.read() + else: + # Store non-arrays as JSON metadata + if not hasattr(_write_sample_to_tar, "_metadata_keys"): + _write_sample_to_tar._metadata_keys = [] + if key not in _write_sample_to_tar._metadata_keys: + _write_sample_to_tar._metadata_keys.append(key) + + # Now add _times only for features that have values + for key, value in sample_data.items(): + if not key.endswith("_times"): + continue + + base_feature = key[:-6] # Remove "_times" suffix + + # Only write _times if the base feature has a value + if base_feature in features_with_values and value is not None: + buffer = io.BytesIO() + np.save(buffer, value) + buffer.seek(0) + safe_key = key.replace("/", "__") + sample_files[f"{safe_key}.npy"] = buffer.read() + + # Write all files for this sample to tar + tar_writer.write(sample_files) + + +def generate_datasetdict_to_disk( + output_folder: Union[str, Path], + generators: dict[str, Callable[..., Generator[Sample, None, None]]], + variable_schema: dict[str, dict], + gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, + num_proc: int = 1, + 
verbose: bool = False, +) -> None: + """Generates and saves a dataset dictionary to disk in WebDataset format. + + This function processes sample generators for different dataset splits, + converts samples to dictionaries, and writes them to tar archives on disk. + It supports both sequential and parallel processing modes. + + Args: + output_folder: Base directory where the dataset will be saved. + A 'data' subdirectory will be created inside this folder. + generators: Dictionary mapping split names to generator functions + that yield Sample objects. + variable_schema: Schema describing the structure and types of + variables/features in the samples. + gen_kwargs: Optional generator arguments for parallel processing. + Must include "shards_ids" for each split when num_proc > 1. + num_proc: Number of processes to use for parallel processing. + Defaults to 1 (sequential). + verbose: Whether to display progress bars during processing. + + Returns: + None: Writes the dataset directly to disk. + """ + assert (gen_kwargs is None and num_proc == 1) or ( + gen_kwargs is not None and num_proc > 1 + ), ( + "Invalid configuration: either provide only `generators` with " + "`num_proc == 1`, or provide `gen_kwargs` with " + "`num_proc > 1`." + ) + + output_folder = Path(output_folder) / "data" + output_folder.mkdir(exist_ok=True, parents=True) + + var_features_keys = list(variable_schema.keys()) + + def worker_batch( + tar_path: str, + gen_func: Callable[..., Generator[Sample, None, None]], + var_features_keys: list[str], + batch: list[IndexType], + start_index: int, + queue: mp.Queue, + ) -> None: # pragma: no cover + """Processes a single batch and writes samples to tar. + + Args: + tar_path: Path to the tar file for the split. + gen_func: Generator function for samples. + var_features_keys: List of feature keys. + batch: Batch of sample IDs. + start_index: Starting sample index. + queue: Queue for progress tracking. 
+ """ + # Create tar writer for this batch (will be appended to main tar) + temp_tar = tar_path.replace(".tar", f"_batch_{start_index}.tar") + # Open file explicitly to handle Windows paths + with open(temp_tar, 'wb') as f: + with wds.TarWriter(f) as tar_writer: + sample_counter = start_index + + for sample in gen_func([batch]): + _write_sample_to_tar( + tar_writer, sample, var_features_keys, sample_counter + ) + sample_counter += 1 + queue.put(1) + + def tqdm_updater( + total: int, queue: mp.Queue, desc: str = "Processing" + ) -> None: # pragma: no cover + """Tqdm process that listens to the queue to update progress. + + Args: + total: Total number of items to process. + queue: Queue to receive progress updates. + desc: Description for the progress bar. + """ + with tqdm(total=total, desc=desc, disable=not verbose) as pbar: + finished = 0 + while finished < total: + finished += queue.get() + pbar.update(1) + + for split_name, gen_func in generators.items(): + tar_path = str(output_folder / f"{split_name}.tar") + + gen_kwargs_ = gen_kwargs or {sn: {} for sn in generators.keys()} + batch_ids_list = gen_kwargs_.get(split_name, {}).get("shards_ids", []) + + total_samples = ( + sum(len(batch) for batch in batch_ids_list) if batch_ids_list else 0 + ) + + if num_proc > 1 and batch_ids_list: # pragma: no cover + # Parallel execution + queue = mp.Queue() + tqdm_proc = mp.Process( + target=tqdm_updater, + args=(total_samples, queue, f"Writing {split_name} split"), + ) + tqdm_proc.start() + + processes = [] + start_index = 0 + temp_tars = [] + + for batch in batch_ids_list: + temp_tar = tar_path.replace(".tar", f"_batch_{start_index}.tar") + temp_tars.append(temp_tar) + p = mp.Process( + target=worker_batch, + args=( + tar_path, + gen_func, + var_features_keys, + batch, + start_index, + queue, + ), + ) + p.start() + processes.append(p) + start_index += len(batch) + + for p in processes: + p.join() + + tqdm_proc.join() + + # Merge temporary tar files + # Open file explicitly 
to handle Windows paths + with open(tar_path, 'wb') as f: + with wds.TarWriter(f) as main_tar: + for temp_tar in temp_tars: + if Path(temp_tar).exists(): + with wds.ShardList([temp_tar]) as shard: + for sample in shard: + main_tar.write(sample) + Path(temp_tar).unlink() + + else: + # Sequential execution + sample_counter = 0 + + # Determine total for progress bar + if not batch_ids_list: + # No batch info, estimate or skip progress + total_samples = None + + # Open file explicitly to handle Windows paths + with open(tar_path, 'wb') as f: + with wds.TarWriter(f) as tar_writer: + with tqdm( + total=total_samples, + desc=f"Writing {split_name} split", + disable=not verbose, + ) as pbar: + for sample in gen_func(): + _write_sample_to_tar( + tar_writer, sample, var_features_keys, sample_counter + ) + sample_counter += 1 + if total_samples is not None: + pbar.update(1) + + +def push_local_datasetdict_to_hub( + repo_id: str, local_dir: Union[str, Path], num_workers: int = 1 +) -> None: # pragma: no cover + """Pushes a local dataset directory to Hugging Face Hub. + + This function uploads the contents of a local directory to a specified + Hugging Face repository as a dataset. It uses the HfApi to handle large + folder uploads with configurable parallelism. + + Args: + repo_id: The Hugging Face repository ID where the dataset will be uploaded. + local_dir: Path to the local directory containing the dataset files to upload. + num_workers: Number of worker threads to use for uploading. Defaults to 1. + + Returns: + None: Uploads the dataset directly to Hugging Face Hub. 
+ """ + api = HfApi() + api.upload_large_folder( + folder_path=local_dir, + repo_id=repo_id, + repo_type="dataset", + num_workers=num_workers, + ignore_patterns=["*.tmp"], + allow_patterns=["data/*.tar", "*.yaml", "*.yml", "*.json", "README.md"], + ) + + +def configure_dataset_card( + repo_id: str, + infos: dict[str, dict[str, str]], + local_dir: Union[str, Path], + viewer: Optional[bool] = None, # noqa: ARG001 + pretty_name: Optional[str] = None, + dataset_long_description: Optional[str] = None, + illustration_urls: Optional[list[str]] = None, + arxiv_paper_urls: Optional[list[str]] = None, +) -> None: # pragma: no cover + """Configures and pushes a dataset card to Hugging Face Hub for a WebDataset backend. + + This function generates a dataset card in YAML format with metadata, features, + splits information, and usage examples. It automatically detects splits and + sample counts from the local directory structure, then pushes the card to + the specified Hugging Face repository. + + Args: + repo_id: The Hugging Face repository ID where the dataset card will be pushed. + infos: Dictionary containing dataset metadata, including legal information. + local_dir: Path to the local directory containing the dataset files. + viewer: Unused parameter for viewer configuration. + pretty_name: A human-readable name for the dataset. + dataset_long_description: A detailed description of the dataset. + illustration_urls: List of URLs to images that illustrate the dataset. + arxiv_paper_urls: List of arXiv URLs for papers related to the dataset. + + Returns: + None: Pushes the dataset card directly to Hugging Face Hub. 
+ """ + dataset_card_str = """--- +task_categories: +- graph-ml +tags: +- physics learning +- geometry learning +--- +""" + local_folder = Path(local_dir) + + # Detect tar files in data directory + data_dir = local_folder / "data" + if not data_dir.exists(): + raise ValueError(f"Data directory not found: {data_dir}") + + tar_files = list(data_dir.glob("*.tar")) + split_names = [f.stem for f in tar_files] + + # Count samples and compute sizes + nbe_samples = {} + num_bytes = {} + size_bytes = 0 + + for tar_file in tar_files: + split_name = tar_file.stem + + # Count samples in tar + sample_count = 0 + with wds.WebDataset(str(tar_file)) as dataset: + for _ in dataset: + sample_count += 1 + + nbe_samples[split_name] = sample_count + num_bytes[split_name] = tar_file.stat().st_size + size_bytes += num_bytes[split_name] + + lines = dataset_card_str.splitlines() + lines = [s for s in lines if not s.startswith("license")] + + indices = [i for i, line in enumerate(lines) if line.strip() == "---"] + + assert len(indices) >= 2, ( + "Cannot find two instances of '---', dataset card format error." 
+ ) + lines = lines[: indices[1] + 1] + + count = 6 + lines.insert(count, f"license: {infos['legal']['license']}") + count += 1 + lines.insert(count, "viewer: false") + count += 1 + if pretty_name: + lines.insert(count, f"pretty_name: {pretty_name}") + count += 1 + + lines.insert(count, "dataset_info:") + count += 1 + lines.insert(count, " splits:") + count += 1 + for sn in split_names: + lines.insert(count, f" - name: {sn}") + count += 1 + lines.insert(count, f" num_bytes: {num_bytes[sn]}") + count += 1 + lines.insert(count, f" num_examples: {nbe_samples[sn]}") + count += 1 + lines.insert(count, f" download_size: {size_bytes}") + count += 1 + lines.insert(count, f" dataset_size: {size_bytes}") + count += 1 + lines.insert(count, "configs:") + count += 1 + lines.insert(count, "- config_name: default") + count += 1 + lines.insert(count, " data_files:") + count += 1 + for sn in split_names: + lines.insert(count, f" - split: {sn}") + count += 1 + lines.insert(count, f" path: data/{sn}.tar") + count += 1 + + str__ = "\n".join(lines) + "\n" + + if illustration_urls: + str__ += "
<div style=\"display: flex;\">\n"
+        for url in illustration_urls:
+            str__ += f"<img src=\"{url}\" width=\"300\"/>\n"
+        str__ += "</div>
\n\n" + + str__ += f"```yaml\n{yaml.dump(infos, sort_keys=False, allow_unicode=True)}\n```" + + str__ += """ +This dataset was generated with [`plaid`](https://plaid-lib.readthedocs.io/) using the WebDataset backend, +we refer to this documentation for additional details on how to extract data from `plaid_sample` objects. + +The simplest way to use this dataset is to first download it: +```python +from plaid.storage import download_from_hub + +repo_id = "channel/dataset" +local_folder = "downloaded_dataset" + +download_from_hub(repo_id, local_folder, backend="webdataset") +``` + +Then, to iterate over the dataset and instantiate samples: +```python +from plaid.storage import init_from_disk + +local_folder = "downloaded_dataset" +split_name = "train" + +datasetdict, converterdict = init_from_disk(local_folder, backend="webdataset") + +dataset = datasetdict[split_name] +converter = converterdict[split_name] + +for i in range(len(dataset)): + plaid_sample = converter.to_plaid(dataset, i) +``` + +It is possible to stream the data directly: +```python +from plaid.storage import init_streaming_from_hub + +repo_id = "channel/dataset" + +datasetdict, converterdict = init_streaming_from_hub(repo_id, backend="webdataset") + +dataset = datasetdict[split_name] +converter = converterdict[split_name] + +for sample_raw in dataset: + plaid_sample = converter.sample_to_plaid(sample_raw) +``` + +Plaid samples' features can be retrieved like the following: +```python +from plaid.storage import load_problem_definitions_from_disk +local_folder = "downloaded_dataset" +pb_defs = load_problem_definitions_from_disk(local_folder) + +# or +from plaid.storage import load_problem_definitions_from_hub +repo_id = "channel/dataset" +pb_defs = load_problem_definitions_from_hub(repo_id) + +pb_def = pb_defs[0] + +plaid_sample = ... 
# use a method from above to instantiate a plaid sample + +for t in plaid_sample.get_all_time_values(): + for path in pb_def.get_in_features_identifiers(): + plaid_sample.get_feature_by_path(path=path, time=t) + for path in pb_def.get_out_features_identifiers(): + plaid_sample.get_feature_by_path(path=path, time=t) +``` +""" + + if dataset_long_description: + str__ += f""" +### Dataset Description +{dataset_long_description} +""" + + if arxiv_paper_urls: + str__ += """ +### Dataset Sources + +- **Papers:** +""" + for url in arxiv_paper_urls: + str__ += f" - [arxiv]({url})\n" + + dataset_card = DatasetCard(str__) + dataset_card.push_to_hub(repo_id) diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py index 5d5a7018..badbc111 100644 --- a/tests/storage/test_storage.py +++ b/tests/storage/test_storage.py @@ -281,6 +281,57 @@ def test_zarr(self, tmp_path, generator_split, infos, problem_definition): with pytest.raises(KeyError): converter.to_dict(dataset, 0, features=["dummy"]) + def test_webdataset(self, tmp_path, generator_split, infos, problem_definition): + test_dir = tmp_path / "test_webdataset" + + save_to_disk( + output_folder=test_dir, + generators=generator_split, + backend="webdataset", + infos=infos, + pb_defs={"pb_def": problem_definition}, + overwrite=True, + verbose=True, + ) + + datasetdict, converterdict = init_from_disk(test_dir) + + dataset = datasetdict["train"] + converter = converterdict["train"] + + plaid_sample = converter.to_plaid(dataset, 0) + self.assert_sample(plaid_sample) + plaid_sample = converter.sample_to_plaid(dataset[0]) + self.assert_sample(plaid_sample) + + converter.plaid_to_dict(plaid_sample) + + # coverage of WebDatasetWrapper class + for sample in dataset: + sample + len(dataset) + dataset.ids + + for t in plaid_sample.get_all_time_values(): + for path in problem_definition.get_in_features_identifiers(): + plaid_sample.get_feature_by_path(path=path, time=t) + for path in 
problem_definition.get_out_features_identifiers(): + plaid_sample.get_feature_by_path(path=path, time=t) + + converter.to_dict(dataset, 0) + converter.sample_to_dict(dataset[0]) + + converter.to_dict( + dataset, + 0, + features=[ + "TestBaseName/TestZoneName/VertexFields/test_field_same_size", + "Global/global_0", + ], + ) + with pytest.raises(KeyError): + converter.to_dict(dataset, 0, features=["dummy"]) + def test_cgns(self, tmp_path, generator_split, infos, problem_definition): test_dir = tmp_path / "test_cgns" @@ -332,6 +383,7 @@ def test_registry(self): assert "hf_datasets" in backends assert "zarr" in backends assert "cgns" in backends + assert "webdataset" in backends hf_module = registry.get_backend("hf_datasets") assert hf_module is not None @@ -342,5 +394,8 @@ def test_registry(self): cgns_module = registry.get_backend("cgns") assert cgns_module is not None + webdataset_module = registry.get_backend("webdataset") + assert webdataset_module is not None + with pytest.raises(ValueError): _ = registry.get_backend("non_existent_backend")