Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ for each new feature simplifies the development, review and merge processes by
maintaining logical separation. To create a feature branch:

```bash
git fetch agrifoodpy
git checkout -b <your-branch-name> agrifoodpy/main
git fetch afp
git checkout -b <your-branch-name> afp/main
```

### Hack away!
Expand Down
60 changes: 60 additions & 0 deletions agrifoodpy/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ def _load_function(path):
module = importlib.import_module(module_path)
return getattr(module, func_name)

@staticmethod
def _is_supported_yaml_function(path):
    """Return True for well-formed dotted numpy/xarray function paths.

    Parameters
    ----------
    path : object
        Candidate dotted path such as ``"numpy.array"`` or
        ``"xarray.DataArray"``. Non-string values are rejected.

    Returns
    -------
    bool
        True when ``path`` is a string with at least two dot-separated
        segments, every segment is a valid identifier, and the first
        segment is ``numpy`` or ``xarray``. False otherwise.
    """
    if not isinstance(path, str) or "." not in path:
        return False

    # Each segment must be a plain identifier. Without this check a
    # malformed path such as "numpy.linalg." or "numpy..array" would be
    # accepted because only the module prefix was inspected, and the
    # failure would surface later during import/attribute lookup.
    segments = path.split(".")
    if not all(segment.isidentifier() for segment in segments):
        return False

    # With two or more segments, "module part is numpy or numpy.*" is
    # equivalent to "first segment is numpy" (same for xarray).
    return segments[0] in ("numpy", "xarray")
Comment on lines +35 to +47

@classmethod
def read(cls, filename):
"""Read a pipeline configuration from a YAML file
Expand All @@ -47,6 +61,52 @@ def read(cls, filename):
The pipeline object.
"""

def dynamic_call_constructor(package_name):
    """Build a multi-constructor for supported package functions.

    Parameters
    ----------
    package_name : str
        Package prefix the constructor is registered for; the tag suffix
        handed over by the YAML loader is appended to it to form the
        dotted function path.

    Returns
    -------
    callable
        A ``constructor(loader, suffix, node)`` function suitable for
        ``yaml.add_multi_constructor``.
    """

    def constructor(loader, suffix, node):
        """Resolve the tagged function and apply it to the node content.

        A scalar node yields the function object itself, a sequence node
        is passed as positional arguments and a mapping node as keyword
        arguments.

        Raises
        ------
        yaml.constructor.ConstructorError
            If the tag is not supported, the function cannot be
            resolved, or the node type is not handled.
        """
        func_path = f"{package_name}.{suffix}" if suffix else package_name

        # Check if the function path is supported
        if not cls._is_supported_yaml_function(func_path):
            raise yaml.constructor.ConstructorError(
                None,
                None,
                f"Unsupported YAML function tag '!{func_path}'.",
                node.start_mark,
            )

        # NOTE(review): any function reachable under the supported
        # packages is called with arguments taken from the YAML file;
        # presumably configuration files are trusted input -- confirm.
        try:
            func = cls._load_function(func_path)
        except (ImportError, AttributeError) as err:
            # Report an unknown module/attribute as a YAML error with
            # the position of the offending tag, instead of leaking a
            # bare import or attribute error.
            raise yaml.constructor.ConstructorError(
                None,
                None,
                f"Cannot resolve YAML function tag '!{func_path}': {err}",
                node.start_mark,
            ) from err

        if isinstance(node, yaml.ScalarNode):
            return func
        if isinstance(node, yaml.SequenceNode):
            args = loader.construct_sequence(node, deep=True)
            return func(*args)
        if isinstance(node, yaml.MappingNode):
            kwargs = loader.construct_mapping(node, deep=True)
            return func(**kwargs)

        raise yaml.constructor.ConstructorError(
            None,
            None,
            f"Unsupported YAML node type for '!{func_path}'.",
            node.start_mark,
        )

    return constructor

yaml.add_multi_constructor(
"!numpy.",
dynamic_call_constructor("numpy"),
Loader=yaml.FullLoader,
)
yaml.add_multi_constructor(
"!xarray.",
dynamic_call_constructor("xarray"),
Loader=yaml.FullLoader,
)
Comment on lines +99 to +108

with open(filename, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)

Expand Down
6 changes: 6 additions & 0 deletions agrifoodpy/pipeline/tests/data/test_config_numpy_array.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
nodes:
- function: agrifoodpy.utils.nodes.write_to_datablock
name: Numpy Array
params:
key: "test_numpy_array"
value: !numpy.array [[1, 2, 3]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
nodes:
- function: agrifoodpy.utils.nodes.write_to_datablock
name: Numpy Array
params:
key: "test_numpy_array"
value: !numpy.array {object: [1, 2, 3]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
nodes:
- function: agrifoodpy.utils.nodes.write_to_datablock
name: Unsupported Function
params:
key: "test_pandas_array"
value: !pandas.DataFrame {data: [1, 2, 3], columns: ["A", "B", "C"]}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
nodes:
- function: agrifoodpy.utils.nodes.write_to_datablock
name: Xarray DataArray
params:
key: "test_value"
value: !xarray.DataArray [[1,2,3] , {Year: [2020, 2021, 2022]}, "Year"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
nodes:
- function: agrifoodpy.utils.nodes.write_to_datablock
name: Xarray DataArray
params:
key: "test_value"
value: !xarray.DataArray {data: [1, 2, 3], dims: ["Year"], coords: {Year: [2020, 2021, 2022]}}
56 changes: 56 additions & 0 deletions agrifoodpy/pipeline/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from agrifoodpy.pipeline import Pipeline, standalone
import numpy as np
import xarray as xr
import pytest
import os

def test_init():
pipeline = Pipeline()
Expand Down Expand Up @@ -359,3 +362,56 @@ def reserved_param_node(x, datablock=None):
@pipeline_node(['wrong_key'])
def unknown_input_node(right_key):
pass


# Test reading YAML config with numpy array parameters and values
def test_read_yaml_numpy_array():
    """Read a config using a positional ``!numpy.array`` tag and run it.

    Both the stored node parameter and the datablock entry are compared
    with ``np.array([1, 2, 3])``.
    """
    yaml_file = os.path.join(
        os.path.dirname(__file__),
        "data/test_config_numpy_array.yaml",
    )
    reference = np.array([1, 2, 3])

    loaded = Pipeline.read(str(yaml_file))
    loaded.run()

    assert np.array_equal(loaded.params[0]['value'], reference)
    assert np.array_equal(loaded.datablock["test_numpy_array"], reference)

def test_read_yaml_numpy_array_kwargs():
    """Read a config using a keyword-style ``!numpy.array`` tag and run it.

    Both the stored node parameter and the datablock entry are compared
    with ``np.array([1, 2, 3])``.
    """
    yaml_file = os.path.join(
        os.path.dirname(__file__),
        "data/test_config_numpy_array_kwargs.yaml",
    )
    reference = np.array([1, 2, 3])

    loaded = Pipeline.read(str(yaml_file))
    loaded.run()

    assert np.array_equal(loaded.params[0]['value'], reference)
    assert np.array_equal(loaded.datablock["test_numpy_array"], reference)

def test_read_yaml_xarray_dataarray():
    """Read a config using a positional ``!xarray.DataArray`` tag and run it.

    Both the stored node parameter and the datablock entry are compared
    with an equivalent ``xr.DataArray`` indexed by ``Year``.
    """
    yaml_file = os.path.join(
        os.path.dirname(__file__),
        "data/test_config_xarray_dataarray.yaml",
    )

    loaded = Pipeline.read(str(yaml_file))
    loaded.run()

    reference = xr.DataArray(
        [1, 2, 3],
        dims=["Year"],
        coords={"Year": [2020, 2021, 2022]},
    )
    xr.testing.assert_equal(loaded.params[0]['value'], reference)
    xr.testing.assert_equal(loaded.datablock["test_value"], reference)

def test_read_yaml_xarray_dataarray_kwargs():
    """Read a config using a keyword-style ``!xarray.DataArray`` tag and run it.

    Both the stored node parameter and the datablock entry are compared
    with an equivalent ``xr.DataArray`` indexed by ``Year``.
    """
    yaml_file = os.path.join(
        os.path.dirname(__file__),
        "data/test_config_xarray_dataarray_kwargs.yaml",
    )

    loaded = Pipeline.read(str(yaml_file))
    loaded.run()

    reference = xr.DataArray(
        [1, 2, 3],
        dims=["Year"],
        coords={"Year": [2020, 2021, 2022]},
    )
    xr.testing.assert_equal(loaded.params[0]['value'], reference)
    xr.testing.assert_equal(loaded.datablock["test_value"], reference)

def test_read_yaml_unsupported_function():
    """Reading a config with a ``!pandas.DataFrame`` tag must fail.

    ``Pipeline.read`` is expected to raise ``ConstructorError`` while
    loading the file.
    """
    from yaml.constructor import ConstructorError

    script_dir = os.path.dirname(__file__)
    config_path = os.path.join(
        script_dir, "data/test_config_unsupported_function.yaml"
    )

    # Only the raised exception matters here, so the return value of
    # read() is deliberately not bound to a name.
    with pytest.raises(ConstructorError):
        Pipeline.read(str(config_path))
2 changes: 1 addition & 1 deletion docs/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Command line tool
=================

The ``agrifoodpy`` command line tool allows you to run a pipeline of functions
defined in a configuration file. This is useful for automating workflows and
defined in a YAML configuration file. This is useful for automating workflows and
reproducibility. You can specify the configuration file and an output file for
the results.

Expand Down
Loading