Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def from_matched_data(
quantity=quantity,
)
data.attrs["weight"] = weight

# FIXME: Consider changes needed for vertical
return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

def __repr__(self):
Expand Down Expand Up @@ -1228,6 +1230,8 @@ def save(self, filename: Union[str, Path]) -> None:
"""
ds = self.data

# FIXME: Consider changes needed for vertical

# add self.raw_mod_data to ds with prefix 'raw_' to avoid name conflicts
# an alternative strategy would be to use NetCDF groups
# https://docs.xarray.dev/en/stable/user-guide/io.html#groups
Expand Down Expand Up @@ -1265,7 +1269,8 @@ def load(filename: Union[str, Path]) -> "Comparer":
return Comparer(matched_data=data)

if data.gtype == "vertical":
return Comparer(matched_data=data) # FIXME: consider during Phase3
# FIXME: consider during Phase3
return Comparer(matched_data=data)

if data.gtype == "point":
raw_mod_data: Dict[
Expand Down
6 changes: 2 additions & 4 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,8 @@ def _match_space_time(
)
case PointModelResult() as pmr, PointObservation():
aligned = pmr.align(observation, max_gap=max_model_gap)
case VerticalModelResult(), VerticalObservation():
raise NotImplementedError("Vertical matching not implemented yet!")
# aligned = vmr.align(observation, max_gap=max_model_gap)
case VerticalModelResult() as vmr, VerticalObservation():
aligned = vmr.align(observation)
case _:
raise TypeError(
f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}"
Expand All @@ -393,7 +392,6 @@ def _match_space_time(
raise ValueError(
f"Aux variables are not allowed to have identical names. Choose either aux from obs or model. Overlapping: {overlapping}"
)

for dv in aligned:
data[dv] = aligned[dv]

Expand Down
18 changes: 15 additions & 3 deletions src/modelskill/model/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .point import PointModelResult
from .track import TrackModelResult
from .vertical import VerticalModelResult
from .dfsu import DfsuModelResult
from .grid import GridModelResult

Expand All @@ -16,6 +17,7 @@
_modelresult_lookup = {
GeometryType.POINT: PointModelResult,
GeometryType.TRACK: TrackModelResult,
GeometryType.VERTICAL: VerticalModelResult,
GeometryType.UNSTRUCTURED: DfsuModelResult,
GeometryType.GRID: GridModelResult,
}
Expand All @@ -25,9 +27,17 @@ def model_result(
data: DataInputType,
*,
aux_items: Optional[list[int | str]] = None,
gtype: Optional[Literal["point", "track", "unstructured", "grid"]] = None,
gtype: Optional[
Literal["point", "track", "vertical", "unstructured", "grid"]
] = None,
**kwargs: Any,
) -> PointModelResult | TrackModelResult | DfsuModelResult | GridModelResult:
) -> (
PointModelResult
| TrackModelResult
| VerticalModelResult
| DfsuModelResult
| GridModelResult
):
"""A factory function for creating an appropriate object based on the data input.

Parameters
Expand All @@ -36,7 +46,7 @@ def model_result(
The data to be used for creating the ModelResult object.
aux_items : Optional[list[int | str]]
Auxiliary items, by default None
gtype : Optional[Literal["point", "track", "unstructured", "grid"]]
gtype : Optional[Literal["point", "track", "vertical", "unstructured", "grid"]]
The geometry type of the data. If not specified, it will be guessed from the data.
**kwargs
Additional keyword arguments to be passed to the ModelResult constructor.
Expand All @@ -48,6 +58,8 @@ def model_result(
<DfsuModelResult> 'Oresund2D'
>>> ms.model_result("ERA5_DutchCoast.nc", item="swh", name="ERA5")
<GridModelResult> 'ERA5'
>>> ms.model_result("VerticalProfile_obs1.dfs0", z_item="z", item="Salinity", name="vmod", gtype="vertical")
<VerticalModelResult> 'vmod'
"""
if gtype is None:
geometry = _guess_gtype(data)
Expand Down
98 changes: 98 additions & 0 deletions src/modelskill/model/vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from typing import Any, Literal, Sequence

import xarray as xr
import pandas as pd
import numpy as np

from ..types import VerticalType
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_vertical_input
from ..obs import VerticalObservation


class VerticalModelResult(TimeSeries):
Expand Down Expand Up @@ -75,3 +78,98 @@ def __init__(
def z(self) -> Any:
"""z-coordinate"""
return self._coordinate_values("z")

def _match_to_nearest_times(
self, obs_df, mod_df, t_tol: pd.Timedelta | None = None
) -> pd.DataFrame:
obs_times = obs_df.index.unique().sort_values()
mod_times_unique = mod_df.index.unique().sort_values()

# get_indexer requires a unique, monotonic index - work on unique times first
idx = mod_times_unique.get_indexer(obs_times, method="nearest", tolerance=t_tol)
valid = idx != -1

matched_mod_times = mod_times_unique[idx[valid]]
obs_times_valid = obs_times[valid]

return pd.DataFrame(
{"obs_time": obs_times_valid, "mod_time": matched_mod_times}
)

def _interpolate_to_obs_depths(
self,
obs_df,
mod_df,
obs_times_valid,
matched_mod_times,
*,
obs_value_col: str,
mod_value_col: str,
) -> pd.DataFrame:
records = []

for obs_t, mod_t in zip(obs_times_valid, matched_mod_times):
obs_at_t = obs_df.loc[[obs_t]].sort_values("z")
mod_at_t = mod_df.loc[[mod_t]].sort_values("z")

obs_z = obs_at_t["z"].to_numpy(dtype=float)
mod_z = mod_at_t["z"].to_numpy(dtype=float)
mod_values = mod_at_t[mod_value_col].to_numpy(dtype=float)

if mod_z.size < 2:
continue

mod_interp = np.interp(obs_z, mod_z, mod_values, left=np.nan, right=np.nan)
for z, mod_v in zip(obs_z, mod_interp):
records.append({"time": obs_t, "z": z, self.name: mod_v})

if not records:
return pd.DataFrame(
columns=["z", self.name], index=pd.Index([], name="time")
)

return pd.DataFrame(records).set_index("time")

def align(
self, vo: VerticalObservation, temporal_tolerance: pd.Timedelta | None = None
) -> xr.Dataset:
"""Align model result to observation by matching nearest times and interpolating to observation depths.
Model depths outside the range of observation depths are extrapolated using nearest model values.

Parameters
----------
vo : VerticalObservation
Vertical observation to align with
temporal_tolerance : pd.Timedelta, optional
Maximum allowed time difference for matching, by default None

Returns
-------
xr.Dataset
Aligned model result

"""
# if temporal_tolerance is not given. Estimate on half the median time step of the model data.
if temporal_tolerance is None:
median_dt = self.time.unique().to_series().diff().median()
temporal_tolerance = median_dt / 2

matched_times = self._match_to_nearest_times(
vo.data[["z"]].to_dataframe(),
self.data[["z"]].to_dataframe(),
t_tol=temporal_tolerance,
)

pairs = self._interpolate_to_obs_depths(
vo.data.to_dataframe(),
self.data.to_dataframe(),
matched_times["obs_time"],
matched_times["mod_time"],
obs_value_col=vo.name,
mod_value_col=self.name,
)
# Convert to xarray Dataset and set kind attribute
xarr = pairs.reset_index().set_index(["time"]).to_xarray()
xarr[self.name].attrs["kind"] = "model"

return xarr
8 changes: 8 additions & 0 deletions tests/model/test_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def test_open_and_parse(self, request, input_fixture):
assert mr.x == pytest.approx(12.0)
assert mr.y == pytest.approx(55.0)

def test_open_with_factory(self, dfs0_fpath):
mr = ms.model_result(
dfs0_fpath, z_item="z", item="Salinity", name="test", gtype="vertical"
)
assert isinstance(mr, ms.VerticalModelResult)
assert mr.gtype == "vertical"
assert mr.name == "test"

# ================
# Test failing and optional args
# ================
Expand Down
66 changes: 66 additions & 0 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def o3():
return ms.TrackObservation(fn, item=3, name="c2")


@pytest.fixture
def o4():
fn = "tests/testdata/vertical/VerticalProfile_obs1.dfs0"
return ms.VerticalObservation(fn, z_item="z", name="vobs", x=657500, y=6553600)


@pytest.fixture
def mr12_gaps():
fn = "tests/testdata/SW/ts_storm_4.dfs0"
Expand Down Expand Up @@ -71,6 +77,66 @@ def mr3():
return ms.model_result(fn, item=0, name="SW_3")


@pytest.fixture
def mr4():
fn = "tests/testdata/vertical/VerticalModel_at_obs.dfs0"
return ms.model_result(fn, item="Salinity", name="vmod", gtype="vertical")


@pytest.fixture
def mr5():
fn = "tests/testdata/vertical/sigma_z_coast.dfsu"
return ms.model_result(fn, item="Salinity", name="3dmod")


class TestVerticalObservation:
# ============
# Check vartions of match with vertical data
def test_match_dfs0_dfs0(self, o4, mr4):
cmp = ms.match(o4, mr4)
assert cmp.n_models == 1
assert cmp.n_points > 0
assert cmp.x == pytest.approx(657500)
assert cmp.y == pytest.approx(6553600)
assert cmp.z is not None
assert cmp.name == "vobs"
assert cmp.gtype == "vertical"
assert cmp.mod_names == ["vmod"]

def test_match_dfsu_dfs0(self, o4, mr5):
cmp = ms.match(o4, mr5)
assert cmp.n_models == 1
assert cmp.n_points > 0
assert cmp.x == pytest.approx(657500)
assert cmp.y == pytest.approx(6553600)
assert cmp.z is not None
assert cmp.name == "vobs"
assert cmp.gtype == "vertical"
assert cmp.mod_names == ["3dmod"]

def test_match_multiple(self, o4, mr4, mr5):
cmp = ms.match(o4, [mr4, mr5])
assert cmp.n_models == 2
assert cmp.x == pytest.approx(657500)
assert cmp.y == pytest.approx(6553600)
assert cmp.z is not None
assert cmp.name == "vobs"
assert cmp.gtype == "vertical"
assert cmp.mod_names == ["vmod", "3dmod"]

# ==========
# Test from_matched
# ==========

# ==========
# Test slicing
# ==========

# ==========
# Test correct results
# ==========


def test_properties_after_match(o1, mr1):
cmp = ms.match(o1, mr1)
assert cmp.n_models == 1
Expand Down