diff --git a/kwave/enums.py b/kwave/enums.py index d17fdb3d..931cbea5 100644 --- a/kwave/enums.py +++ b/kwave/enums.py @@ -1,5 +1,17 @@ from enum import Enum + +class AlphaMode(str, Enum): + """Controls which absorption/dispersion terms are included in the equation of state.""" + + NO_ABSORPTION = "no_absorption" + NO_DISPERSION = "no_dispersion" + STOKES = "stokes" + + def __str__(self): + return self.value + + ################################################################ # literals that link the discrete cosine and sine transform types with # their type definitions in the functions dtt1D, dtt2D, and dtt3D diff --git a/kwave/kmedium.py b/kwave/kmedium.py index 2886299d..84b7186a 100644 --- a/kwave/kmedium.py +++ b/kwave/kmedium.py @@ -1,10 +1,11 @@ import logging from dataclasses import dataclass -from typing import List +from typing import List, Optional, Union import numpy as np import kwave.utils.checks +from kwave.enums import AlphaMode @dataclass @@ -20,8 +21,8 @@ class kWaveMedium(object): # power law absorption exponent alpha_power: np.array = None # optional input to force either the absorption or dispersion terms in the equation of state to be excluded; - # valid inputs are 'no_absorption' or 'no_dispersion' - alpha_mode: np.array = None + # valid inputs are AlphaMode.NO_ABSORPTION, AlphaMode.NO_DISPERSION, or the equivalent strings + alpha_mode: Optional[Union[AlphaMode, str]] = None # frequency domain filter applied to the absorption and dispersion terms in the equation of state alpha_filter: np.array = None # two element array used to control the sign of absorption and dispersion terms in the equation of state @@ -43,6 +44,8 @@ class kWaveMedium(object): def __post_init__(self): self.sound_speed = np.atleast_1d(self.sound_speed) + if isinstance(self.alpha_mode, str) and not isinstance(self.alpha_mode, AlphaMode): + self.alpha_mode = AlphaMode(self.alpha_mode) def check_fields(self, kgrid_shape: np.ndarray) -> None: """ @@ -54,13 +57,12 @@ def check_fields(self, kgrid_shape: np.ndarray) -> None: Returns: None """ - # check the absorption mode input is valid - if self.alpha_mode is not None: - assert self.alpha_mode in [ - "no_absorption", - "no_dispersion", - "stokes", - ], "medium.alpha_mode must be set to 'no_absorption', 'no_dispersion', or 'stokes'." + # check the absorption mode input is valid (already normalized to AlphaMode in __post_init__) + if self.alpha_mode is not None and not isinstance(self.alpha_mode, AlphaMode): + raise ValueError( + f"medium.alpha_mode must be an AlphaMode enum value or one of " + f"'no_absorption', 'no_dispersion', 'stokes', got {self.alpha_mode!r}" + ) # check the absorption filter input is valid if self.alpha_filter is not None and not (self.alpha_filter.shape == kgrid_shape).all():