Skip to content

Commit 7725252

Browse files
bpiwowarclaude
andcommitted
fix: handle Union types in generic type argument validation
GenericType.validate() now correctly checks compatibility when a generic type argument is a Union (e.g., Converter[str, str] is accepted for Param[Converter[str, Union[str, List[str], ...]]]). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent cf54bf5 commit 7725252

2 files changed

Lines changed: 90 additions & 10 deletions

File tree

src/experimaestro/core/types.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,47 @@ def __repr__(self):
837837
return str(self)
838838

839839

840+
def _is_type_arg_compatible(actual, expected) -> bool:
841+
"""Check if an actual type argument is compatible with the expected one.
842+
843+
Handles Union types: actual is compatible with Union[X, Y, Z] if actual
844+
is a subtype of any member of the union.
845+
"""
846+
# Both concrete types: use subclass check
847+
if isinstance(expected, type) and isinstance(actual, type):
848+
return issubclass(actual, expected)
849+
850+
# Expected is a Union: actual must be compatible with at least one member
851+
if get_origin(expected) is typing.Union:
852+
return any(
853+
_is_type_arg_compatible(actual, member) for member in get_args(expected)
854+
)
855+
856+
# Actual is a Union: each member of actual must be compatible with expected
857+
if get_origin(actual) is typing.Union:
858+
return all(
859+
_is_type_arg_compatible(member, expected) for member in get_args(actual)
860+
)
861+
862+
# For generic types (e.g., List[str]), check origin and args recursively
863+
expected_origin = get_origin(expected)
864+
actual_origin = get_origin(actual)
865+
if expected_origin is not None and actual_origin is not None:
866+
if not issubclass(actual_origin, expected_origin):
867+
return False
868+
expected_args = get_args(expected)
869+
actual_args = get_args(actual)
870+
if expected_args and actual_args:
871+
return all(
872+
_is_type_arg_compatible(a, e)
873+
for a, e in zip(actual_args, expected_args)
874+
)
875+
return True
876+
877+
# Fallback: exact equality
878+
return expected == actual
879+
880+
840881
class GenericType(Type):
841882
def __init__(self, type: typing.Type):
842883
self.type = type
@@ -873,17 +914,11 @@ def validate(self, value):
873914
for expected, actual in zip(self.args, matching_args):
874915
if isinstance(expected, TypeVar) or isinstance(actual, TypeVar):
875916
continue
876-
if isinstance(expected, type) and isinstance(actual, type):
877-
if not issubclass(actual, expected):
878-
raise TypeError(
879-
f"{type(value).__qualname__} has type argument "
880-
f"{actual.__qualname__} which is not a subtype of "
881-
f"expected {expected.__qualname__}"
882-
)
883-
elif expected != actual:
917+
if not _is_type_arg_compatible(actual, expected):
884918
raise TypeError(
885919
f"{type(value).__qualname__} has type argument "
886-
f"{actual} which does not match expected {expected}"
920+
f"{actual} which is not compatible with "
921+
f"expected {expected}"
887922
)
888923

889924
return value

src/experimaestro/tests/core/test_generics.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for the use of generics in configurations"""
22

3-
from typing import Generic, Optional, TypeVar
3+
from typing import Generic, List, Optional, Tuple, TypeVar, Union
44

55
import pytest
66
from experimaestro import field, Config, Param
@@ -398,3 +398,48 @@ class Container(Config):
398398

399399
with pytest.raises((ValueError, TypeError)):
400400
Container.C(sampler=MyBatchSampler.C())
401+
402+
403+
# =============================================================================
404+
# Tests for Union type compatibility in generic type arguments
405+
# =============================================================================
406+
407+
V = TypeVar("V")
408+
W = TypeVar("W")
409+
410+
411+
class Converter(Config, Generic[V, W]):
412+
pass
413+
414+
415+
HFTokenizerInput = Union[str, List[str], List[Tuple[str, str]]]
416+
417+
418+
class TopicTextConverter(Converter[str, str]):
419+
pass
420+
421+
422+
def test_core_generics_union_type_arg_accepts_member():
423+
"""Param[Base[str, Union[str, ...]]] should accept Base[str, str]
424+
since str is a member of the Union"""
425+
426+
class HFTokenizerAdapter(Config):
427+
converter: Param[Converter[str, HFTokenizerInput]]
428+
429+
converter = TopicTextConverter.C()
430+
adapter = HFTokenizerAdapter.C(converter=converter)
431+
assert adapter.converter is converter
432+
433+
434+
def test_core_generics_union_type_arg_rejects_non_member():
435+
"""Param[Base[str, Union[str, ...]]] should reject Base[str, int]
436+
since int is not a member of the Union"""
437+
438+
class IntConverter(Converter[str, int]):
439+
pass
440+
441+
class HFTokenizerAdapter(Config):
442+
converter: Param[Converter[str, HFTokenizerInput]]
443+
444+
with pytest.raises(TypeError):
445+
HFTokenizerAdapter.C(converter=IntConverter.C())

0 commit comments

Comments
 (0)