diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5f31a3f..aebc387 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -54,8 +54,8 @@ for each new feature simplifies the development, review and merge processes by
 maintining logical separation. To create a feature branch:
 
 ```bash
-    git fetch agrifoodpy
-    git checkout -b agrifoodpy/main
+    git fetch afp
+    git checkout -b my-new-feature afp/main
 ```
 
 ### Hack away!
diff --git a/agrifoodpy/pipeline/pipeline.py b/agrifoodpy/pipeline/pipeline.py
index 06e98b0..10c0e41 100644
--- a/agrifoodpy/pipeline/pipeline.py
+++ b/agrifoodpy/pipeline/pipeline.py
@@ -32,6 +32,24 @@ def _load_function(path):
         module = importlib.import_module(module_path)
         return getattr(module, func_name)
 
+    @staticmethod
+    def _is_supported_yaml_function(path):
+        """Return True if ``path`` is a dotted name under numpy or xarray.
+
+        Only the module part of ``path`` is inspected; whether the named
+        attribute exists is not checked here.
+        """
+        if not isinstance(path, str) or "." not in path:
+            return False
+
+        module_path, _ = path.rsplit(".", 1)
+        # Accept the package itself or one of its submodules, but not
+        # look-alike names such as "numpyfoo".
+        return any(
+            module_path == package or module_path.startswith(package + ".")
+            for package in ("numpy", "xarray")
+        )
+
     @classmethod
     def read(cls, filename):
         """Read a pipeline configuration from a YAML file
@@ -47,6 +65,71 @@ def read(cls, filename):
             The pipeline object.
         """
 
+        def dynamic_call_constructor(package_name):
+            """Build a multi-constructor for ``!<package_name>.*`` tags.
+
+            The returned constructor resolves the tag to an object and,
+            depending on the node type, returns it as is (scalar), calls
+            it with positional arguments (sequence) or calls it with
+            keyword arguments (mapping).
+            """
+
+            def constructor(loader, suffix, node):
+                func_path = (
+                    f"{package_name}.{suffix}" if suffix else package_name
+                )
+
+                if not cls._is_supported_yaml_function(func_path):
+                    raise yaml.constructor.ConstructorError(
+                        None,
+                        None,
+                        f"Unsupported YAML function tag '!{func_path}'.",
+                        node.start_mark,
+                    )
+
+                # Report an unknown module or attribute as a YAML error
+                # that points at the offending tag.
+                try:
+                    func = cls._load_function(func_path)
+                except (ImportError, AttributeError) as err:
+                    raise yaml.constructor.ConstructorError(
+                        None,
+                        None,
+                        f"Cannot load YAML function '!{func_path}': {err}",
+                        node.start_mark,
+                    ) from err
+
+                if isinstance(node, yaml.ScalarNode):
+                    return func
+                if isinstance(node, yaml.SequenceNode):
+                    args = loader.construct_sequence(node, deep=True)
+                    return func(*args)
+                if isinstance(node, yaml.MappingNode):
+                    kwargs = loader.construct_mapping(node, deep=True)
+                    return func(**kwargs)
+
+                raise yaml.constructor.ConstructorError(
+                    None,
+                    None,
+                    f"Unsupported YAML node type for '!{func_path}'.",
+                    node.start_mark,
+                )
+
+            return constructor
+
+        # NOTE(security): these tags can call any function reachable
+        # under numpy or xarray, so only read configuration files you trust.
+        class PipelineLoader(yaml.FullLoader):
+            """FullLoader with the numpy/xarray function tags added."""
+
+        # Register on a private subclass so that the process-wide
+        # yaml.FullLoader is not modified.
+        for package_name in ("numpy", "xarray"):
+            PipelineLoader.add_multi_constructor(
+                f"!{package_name}.",
+                dynamic_call_constructor(package_name),
+            )
+
         with open(filename, "r") as f:
-            config = yaml.load(f, Loader=yaml.FullLoader)
+            config = yaml.load(f, Loader=PipelineLoader)
 
diff --git a/agrifoodpy/pipeline/tests/data/test_config_numpy_array.yaml b/agrifoodpy/pipeline/tests/data/test_config_numpy_array.yaml
new file mode 100644
index 0000000..2a8fdcf
--- /dev/null
+++ b/agrifoodpy/pipeline/tests/data/test_config_numpy_array.yaml
@@ -0,0 +1,6 @@
+nodes:
+  - function: agrifoodpy.utils.nodes.write_to_datablock
+    name: Numpy Array
+    params:
+      key: "test_numpy_array"
+      value: !numpy.array [[1, 2, 3]]
diff --git a/agrifoodpy/pipeline/tests/data/test_config_numpy_array_kwargs.yaml b/agrifoodpy/pipeline/tests/data/test_config_numpy_array_kwargs.yaml
new file mode 100644
index 0000000..82ab817
--- /dev/null
+++ b/agrifoodpy/pipeline/tests/data/test_config_numpy_array_kwargs.yaml
@@ -0,0 +1,6 @@
+nodes:
+  - function: agrifoodpy.utils.nodes.write_to_datablock
+    name: Numpy Array
+    params:
+      key: "test_numpy_array"
+      value: !numpy.array {object: [1, 2, 3]}
diff --git a/agrifoodpy/pipeline/tests/data/test_config_unsupported_function.yaml b/agrifoodpy/pipeline/tests/data/test_config_unsupported_function.yaml
new file mode 100644
index 0000000..995e9ed
--- /dev/null
+++ b/agrifoodpy/pipeline/tests/data/test_config_unsupported_function.yaml
@@ -0,0 +1,6 @@
+nodes:
+  - function: agrifoodpy.utils.nodes.write_to_datablock
+    name: Unsupported Function
+    params:
+      key: "test_pandas_array"
+      value: !pandas.DataFrame {data: [1, 2, 3], columns: ["A", "B", "C"]}
diff --git a/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray.yaml b/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray.yaml
new file mode 100644
index 0000000..71f9a71
--- /dev/null
+++ b/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray.yaml
@@ -0,0 +1,6 @@
+nodes:
+  - function: agrifoodpy.utils.nodes.write_to_datablock
+    name: Xarray DataArray
+    params:
+      key: "test_value"
+      value: !xarray.DataArray [[1,2,3] , {Year: [2020, 2021, 2022]}, "Year"]
diff --git a/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray_kwargs.yaml b/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray_kwargs.yaml
new file mode 100644
index 0000000..7da04f6
--- /dev/null
+++ b/agrifoodpy/pipeline/tests/data/test_config_xarray_dataarray_kwargs.yaml
@@ -0,0 +1,6 @@
+nodes:
+  - function: agrifoodpy.utils.nodes.write_to_datablock
+    name: Xarray DataArray
+    params:
+      key: "test_value"
+      value: !xarray.DataArray {data: [1, 2, 3], dims: ["Year"], coords: {Year: [2020, 2021, 2022]}}
diff --git a/agrifoodpy/pipeline/tests/test_pipeline.py b/agrifoodpy/pipeline/tests/test_pipeline.py
index 413149f..6f05ad2 100644
--- a/agrifoodpy/pipeline/tests/test_pipeline.py
+++ b/agrifoodpy/pipeline/tests/test_pipeline.py
@@ -1,5 +1,8 @@
 from agrifoodpy.pipeline import Pipeline, standalone
+import numpy as np
+import xarray as xr
 import pytest
+import os
 
 def test_init():
     pipeline = Pipeline()
@@ -359,3 +362,56 @@ def reserved_param_node(x, datablock=None):
         @pipeline_node(['wrong_key'])
         def unknown_input_node(right_key):
             pass
+
+
+# Test reading YAML config with numpy array parameters and values
+def test_read_yaml_numpy_array():
+
+    script_dir = os.path.dirname(__file__)
+    config_path = os.path.join(script_dir, "data/test_config_numpy_array.yaml")
+
+    pipeline = Pipeline.read(str(config_path))
+    pipeline.run()
+
+    assert np.array_equal(pipeline.params[0]['value'], np.array([1, 2, 3]))
+    assert np.array_equal(pipeline.datablock["test_numpy_array"], np.array([1, 2, 3]))
+
+def test_read_yaml_numpy_array_kwargs():
+    script_dir = os.path.dirname(__file__)
+    config_path = os.path.join(script_dir, "data/test_config_numpy_array_kwargs.yaml")
+
+    pipeline = Pipeline.read(str(config_path))
+    pipeline.run()
+
+    assert np.array_equal(pipeline.params[0]['value'], np.array([1, 2, 3]))
+    assert np.array_equal(pipeline.datablock["test_numpy_array"], np.array([1, 2, 3]))
+
+def test_read_yaml_xarray_dataarray():
+    script_dir = os.path.dirname(__file__)
+    config_path = os.path.join(script_dir, "data/test_config_xarray_dataarray.yaml")
+
+    pipeline = Pipeline.read(str(config_path))
+    pipeline.run()
+
+    expected_array = xr.DataArray([1, 2, 3], coords={"Year": [2020, 2021, 2022]}, dims=["Year"])
+    xr.testing.assert_equal(pipeline.params[0]['value'], expected_array)
+    xr.testing.assert_equal(pipeline.datablock["test_value"], expected_array)
+
+def test_read_yaml_xarray_dataarray_kwargs():
+    script_dir = os.path.dirname(__file__)
+    config_path = os.path.join(script_dir, "data/test_config_xarray_dataarray_kwargs.yaml")
+
+    pipeline = Pipeline.read(str(config_path))
+    pipeline.run()
+
+    expected_array = xr.DataArray([1, 2, 3], coords={"Year": [2020, 2021, 2022]}, dims=["Year"])
+    xr.testing.assert_equal(pipeline.params[0]['value'], expected_array)
+    xr.testing.assert_equal(pipeline.datablock["test_value"], expected_array)
+
+def test_read_yaml_unsupported_function():
+    from yaml.constructor import ConstructorError
+    script_dir = os.path.dirname(__file__)
+    config_path = os.path.join(script_dir, "data/test_config_unsupported_function.yaml")
+
+    with pytest.raises(ConstructorError):
+        Pipeline.read(str(config_path))
diff --git a/docs/config_file.rst b/docs/config_file.rst
index 63eac7b..7e09384 100644
--- a/docs/config_file.rst
+++ b/docs/config_file.rst
@@ -4,7 +4,7 @@ Command line tool
 =================
 
 The ``agrifoodpy`` command line tool allows you to run a pipeline of functions
-defined in a configuration file. This is useful for automating workflows and
+defined in a YAML configuration file. This is useful for automating workflows and
 reproducibility. You can specify the configuration file and an output file for
 the results.
 