Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions kwave/enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from enum import Enum


class AlphaMode(str, Enum):
"""Controls which absorption/dispersion terms are included in the equation of state."""

NO_ABSORPTION = "no_absorption"
NO_DISPERSION = "no_dispersion"
STOKES = "stokes"

def __str__(self):
return self.value


################################################################
# literals that link the discrete cosine and sine transform types with
# their type definitions in the functions dtt1D, dtt2D, and dtt3D
Expand Down
22 changes: 12 additions & 10 deletions kwave/kmedium.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
from dataclasses import dataclass
from typing import List
from typing import List, Optional, Union

import numpy as np

import kwave.utils.checks
from kwave.enums import AlphaMode


@dataclass
Expand All @@ -20,8 +21,8 @@ class kWaveMedium(object):
# power law absorption exponent
alpha_power: np.array = None
# optional input to force either the absorption or dispersion terms in the equation of state to be excluded;
# valid inputs are 'no_absorption' or 'no_dispersion'
alpha_mode: np.array = None
# valid inputs are AlphaMode.NO_ABSORPTION, AlphaMode.NO_DISPERSION, or the equivalent strings
alpha_mode: Optional[Union[AlphaMode, str]] = None
# frequency domain filter applied to the absorption and dispersion terms in the equation of state
alpha_filter: np.array = None
# two element array used to control the sign of absorption and dispersion terms in the equation of state
Expand All @@ -43,6 +44,8 @@ class kWaveMedium(object):

def __post_init__(self):
self.sound_speed = np.atleast_1d(self.sound_speed)
if isinstance(self.alpha_mode, str) and not isinstance(self.alpha_mode, AlphaMode):
self.alpha_mode = AlphaMode(self.alpha_mode)
Comment on lines +47 to +48
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Invalid string raises unhelpful ValueError from enum constructor

When an invalid string is passed at construction time (e.g. kWaveMedium(sound_speed=1500, alpha_mode="typo")), AlphaMode("typo") raises ValueError: 'typo' is not a valid AlphaMode — the more descriptive message in check_fields is never reached. Consider wrapping the conversion to re-raise with a clearer hint:

if isinstance(self.alpha_mode, str) and not isinstance(self.alpha_mode, AlphaMode):
    try:
        self.alpha_mode = AlphaMode(self.alpha_mode)
    except ValueError:
        valid = [m.value for m in AlphaMode]
        raise ValueError(
            f"medium.alpha_mode must be one of {valid}, got {self.alpha_mode!r}"
        )


def check_fields(self, kgrid_shape: np.ndarray) -> None:
"""
Expand All @@ -54,13 +57,12 @@ def check_fields(self, kgrid_shape: np.ndarray) -> None:
Returns:
None
"""
# check the absorption mode input is valid
if self.alpha_mode is not None:
assert self.alpha_mode in [
"no_absorption",
"no_dispersion",
"stokes",
], "medium.alpha_mode must be set to 'no_absorption', 'no_dispersion', or 'stokes'."
# check the absorption mode input is valid (already normalized to AlphaMode in __post_init__)
if self.alpha_mode is not None and not isinstance(self.alpha_mode, AlphaMode):
raise ValueError(
f"medium.alpha_mode must be an AlphaMode enum value or one of "
f"'no_absorption', 'no_dispersion', 'stokes', got {self.alpha_mode!r}"
)
Comment on lines +61 to +65
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Post-construction string assignment silently breaks check_fields

check_fields is called from kWaveSimulation after the object is fully constructed (line 610 in kWaveSimulation.py). Since kWaveMedium is a plain dataclass, its attributes are mutable — a caller can legally write medium.alpha_mode = "no_dispersion" after __post_init__ runs, bypassing the normalization step. Under the old code that assignment would still pass the in [...] check; under this PR it raises a ValueError even for a valid string value, silently breaking previously-working patterns.

The type hint Optional[Union[AlphaMode, str]] reinforces the expectation that plain strings remain valid inputs throughout the object's lifetime. Fix by normalising strings in check_fields as well, or by converting alpha_mode to a property setter so normalisation is always applied on write:

# check the absorption mode input is valid (already normalized to AlphaMode in __post_init__)
if self.alpha_mode is not None:
    if isinstance(self.alpha_mode, str) and not isinstance(self.alpha_mode, AlphaMode):
        self.alpha_mode = AlphaMode(self.alpha_mode)  # late-normalise post-construction strings
    elif not isinstance(self.alpha_mode, AlphaMode):
        raise ValueError(
            f"medium.alpha_mode must be an AlphaMode enum value or one of "
            f"'no_absorption', 'no_dispersion', 'stokes', got {self.alpha_mode!r}"
        )


# check the absorption filter input is valid
if self.alpha_filter is not None and not (self.alpha_filter.shape == kgrid_shape).all():
Expand Down
Loading