Fix handling of type var syntax and types.GenericAlias #962

Open
provinzkraut wants to merge 3 commits into jcrist:main from provinzkraut:fix-tvar-typing-generic-alias

@provinzkraut
Contributor

@provinzkraut provinzkraut commented Nov 29, 2025

Fix #957.

This is a few fixes combined into one, since they were tightly coupled.

Handling of "new style" generics during type resolution

When subscripting a "new style" generic (such as collections.abc.Mapping), the result is a types.GenericAlias (vs. the "old style" typing._GenericAlias), which msgspec did not handle correctly during inspection.
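
To make the distinction concrete, here is a minimal standard-library sketch (Python 3.9+ assumed). It is not taken from the msgspec code base, and typing._GenericAlias is a private class whose name may change between Python versions.

```python
import collections.abc
import types
import typing

# "New style": subscripting the ABC itself yields a types.GenericAlias.
new_style = collections.abc.Mapping[str, int]

# "Old style": subscripting the typing alias yields a typing._GenericAlias
# (private class, referenced here for illustration only).
old_style = typing.Mapping[str, int]

print(type(new_style))
print(type(old_style))

# The public helpers report the same origin and arguments for both.
print(typing.get_origin(new_style) is typing.get_origin(old_style))
print(typing.get_args(new_style) == typing.get_args(old_style))
```

Code that inspects types via `isinstance(..., typing._GenericAlias)` alone would therefore miss the first form, even though both describe the same parametrization.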

Handling of type var syntax

When dealing with builtin generics that resolve to typing.TypeAlias, msgspec did not account for type var syntax correctly in all cases, so type information would get lost during the conversion process.

Type conversions on types.GenericAlias

During type conversion, msgspec caches certain information on the type objects themselves, if the types are complex (i.e. Structs or dataclass-like).

When decoding into a Foo[int], msgspec will set an __msgspec_cache__ attribute on the Foo[int] alias type.

For typing._GenericAlias, this works, since it has a __dict__, so you can just assign attributes to it. However, types.GenericAlias does not allow assigning arbitrary attributes to it.
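
The difference can be reproduced without msgspec. The following is a small sketch, assuming current CPython behaviour; `CACHE_ATTR` and `try_set_cache` are hypothetical names used only for illustration and are not msgspec's real attribute or API.

```python
import typing

# Hypothetical attribute name for the demo, not the one msgspec uses.
CACHE_ATTR = "__demo_cache__"


def try_set_cache(alias):
    """Return True if a dunder attribute can be stored on the alias object."""
    try:
        setattr(alias, CACHE_ATTR, {"hits": 0})
    except (AttributeError, TypeError):
        return False
    return True


# typing._GenericAlias: instances have a __dict__, so the write succeeds.
print(try_set_cache(typing.List[int]))

# types.GenericAlias: no instance __dict__, so the write is rejected.
print(try_set_cache(list[int]))
```

Because typing.List[int] is cached by the typing module, the attribute set in the first call stays visible on later `typing.List[int]` lookups for as long as that cached alias is alive.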

I fix this by converting a types.GenericAlias into a typing._GenericAlias when encountering a generic Struct or dataclass type. This keeps the existing caching mechanism in place.

This seemed like the most reasonable fix to me, as the other alternatives would likely incur some sort of performance penalty. By storing the typing info directly on the alias, msgspec can forgo maintaining a dedicated cache, making lookups very fast. It also means there is no need to invalidate a cache, since the info is garbage collected once the alias isn't referenced anymore.

One thing to note here, though, is that in the future typing._GenericAlias might just go away (at least from the stdlib), in which case we'll have to find another way to deal with this.

@provinzkraut
Contributor Author

provinzkraut commented Nov 30, 2025

Okay, maybe this isn't a good solution. I've discovered a slight issue with the typing._GenericAlias implementation:

It does not cache itself when called directly (e.g. doing _GenericAlias(<type>, <args>)), but only when subscripted. This means that List[int] is List[int] holds true, but _GenericAlias(List, int) is _GenericAlias(List, int) does not.
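
A short sketch of that identity behaviour, assuming current CPython. typing._GenericAlias is private, and calling it as `(origin, args)` is an assumption based on how the constructor looks today.

```python
import typing

# Subscription goes through typing's internal cache: same object each time.
print(typing.List[int] is typing.List[int])

# Direct construction bypasses that cache: equal, but distinct objects.
a = typing._GenericAlias(list, (int,))
b = typing._GenericAlias(list, (int,))
print(a == b, a is b)

# Builtin subscription (types.GenericAlias) is not cached either.
print(list[int] is list[int])
```

Anything stored as an attribute on `a` is therefore invisible on `b`, which is why a cache attached to a freshly constructed alias does not survive across calls.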

Since msgspec relies on that cache being preserved, the solution as currently proposed only maintains the cache within a decoder instance, meaning that calling decode() twice, even with the same type, does not maintain the cache.

Note that this only affects generic types bound by a types.GenericAlias. Everything else is unaffected, including "regular" generic classes (e.g. a class Foo[T](Struct)), which will be cached just fine.

@provinzkraut provinzkraut force-pushed the fix-tvar-typing-generic-alias branch from 43d8d35 to 4285b96 Compare November 30, 2025 11:53
@provinzkraut
Contributor Author

provinzkraut commented Nov 30, 2025

Alright, I think I managed to find a solution. Essentially, we're now reproducing a typing._GenericAlias that properly caches itself, from a types.GenericAlias.

It's basically doing

typing._GenericAlias(
    alias.__origin__, 
    alias.__origin__.__parameters__
).__getitem__(*alias.__args__)

which is functionally the same as

alias = typing._SpecialGenericAlias(list, 1)[int]
# this is the same as
alias = typing.List[int]
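
For readers who want to try the idea outside msgspec, here is a hedged, pure-Python sketch of the approach described above. `Box` and `to_typing_alias` are hypothetical names, the real change may be implemented differently, and the sketch relies on private typing internals that may differ between Python versions.

```python
import types
import typing

T = typing.TypeVar("T")


class Box(typing.Generic[T]):
    """Stand-in generic class; msgspec would see a Struct or dataclass here."""


def to_typing_alias(alias):
    """Rebuild a types.GenericAlias as a typing._GenericAlias via subscription."""
    origin = alias.__origin__
    # First parametrize the origin with its own type variables...
    unbound = typing._GenericAlias(origin, origin.__parameters__)
    # ...then subscript it, which routes through typing's subscription cache.
    return unbound[alias.__args__]


first = to_typing_alias(types.GenericAlias(Box, (int,)))
second = to_typing_alias(types.GenericAlias(Box, (int,)))

# Two separate conversions are expected to yield the very same alias object.
print(first is second)
print(typing.get_origin(first), typing.get_args(first))
```

Going through subscription rather than constructing the final alias directly is what lets typing's cache hand back one shared object, so attributes stored on it can be found again on the next conversion.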

I've also added some more test cases to ensure the caching works properly in all newly supported cases.

@provinzkraut
Contributor Author

Another friendly ping to @ofek for a review :)
Hope you don't mind, I just didn't want these PRs to get buried, but I also don't want to continue to send PRs and have them pile up 😅
LMK if I can do something to make the review process easier for you.

@ofek
Collaborator

ofek commented Dec 19, 2025

I've been really busy, sorry for the wait! Can you confirm locally that the minor performance regression here compared to the last run on main is just noise?

I really appreciate all the work you've been doing 🙏

@provinzkraut
Contributor Author

provinzkraut commented Dec 19, 2025

> Can you confirm locally that the minor performance regression here compared to the last run on main is just noise?

I'll check it out!

@provinzkraut
Contributor Author

> I've been really busy, sorry for the wait! Can you confirm locally that the minor performance regression here compared to the last run on main is just noise?

Seems to be just noise, on my machine (Apple Silicon M4) at least.

@provinzkraut provinzkraut force-pushed the fix-tvar-typing-generic-alias branch from 65f7010 to 866b311 Compare April 8, 2026 19:43
@provinzkraut provinzkraut force-pushed the fix-tvar-typing-generic-alias branch from 866b311 to 33ba0ed Compare April 10, 2026 11:16
Comment thread src/msgspec/_utils.py
return _get_type_hints(obj, include_extras=True)


PY_31PLUS = sys.version_info >= (3, 12)
Collaborator

Typo in variable name: PY_31PLUS reads as "Python 3.1+", but the comparison is >= (3, 12). Should be PY_312PLUS (consistent with PY312_PLUS used in _core.c).

@Siyet
Collaborator

Siyet commented Apr 12, 2026

Performance & correctness review

Tested on a dedicated VPS, pinned to a single core via taskset -c 0.

VPS specs
  • OS: Ubuntu 24.04.3 LTS, kernel 6.8.0-107-generic
  • CPU: AMD EPYC (with IBPB), 1 vCPU
  • RAM: 1.9 GiB
  • Python: 3.12.3
  • Method: A/B/A (main -> PR -> main), 3x100 rounds per benchmark, trimmed mean (drop top/bottom 10%), taskset -c 0

Benchmark results

| Benchmark | main_1 (ms) | PR (ms) | main_2 (ms) | PR vs avg(main) | stdev (m1/pr/m2) | verdict |
|---|---|---|---|---|---|---|
| point_encode | 0.0788 | 0.0810 | 0.0841 | -0.5% | 2.4/2.5/15.8% | noise |
| point_decode | 0.0837 | 0.0888 | 0.0871 | +4.0% | 2.6/1.4/4.4% | noise (main drifts) |
| generic_struct_decode | 0.0587 | 0.0603 | 0.0586 | +2.7% | 1.2/1.5/1.4% | borderline, see below |
| decoder_creation_simple | 0.0903 | 0.0925 | 0.0914 | +1.9% | 1.4/1.9/2.3% | noise |
| decoder_creation_generic | 0.1681 | 0.1696 | 0.1691 | +0.6% | 1.1/3.7/1.6% | noise |
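
As a worked example of the "PR vs avg(main)" column, the sketch below assumes it is the PR timing relative to the mean of the two main runs; that reading is inferred from the column header, and `pct_vs_main` is a hypothetical helper. Because the table shows rounded timings, recomputing other rows this way can differ from the listed value in the last digit.

```python
def pct_vs_main(main_1, pr, main_2):
    """Relative difference of the PR timing vs. the mean of both main runs, in percent."""
    baseline = (main_1 + main_2) / 2
    return (pr / baseline - 1) * 100


# point_decode row, using the rounded timings from the table
print(f"{pct_vs_main(0.0837, 0.0888, 0.0871):+.1f}%")  # prints +4.0%
```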

generic_struct_decode shows +2.7%, but this benchmark uses a pre-created Decoder(Box[int]) where Box subclasses Struct and Generic[T], so Box[int] is a typing._GenericAlias, not a types.GenericAlias. The convert_types_generic_alias path is not hit here, so this is VPS noise despite the low stdev.

Conclusion: no measurable performance regression.

Correctness verification

6/6 tests pass on the PR branch:

| Test | Status |
|---|---|
| dataclass inheriting collections.abc.Mapping[str, T] (original #957) | PASS |
| Struct + ABCMeta + collections.abc.Mapping[str, T] | PASS |
| typing.Mapping[str, T] variant | PASS |
| PEP 695 type var syntax (Python 3.12) | PASS |
| Refcount / cache across decoder instances | PASS |
| Cache persistence across 10 decode calls | PASS |

Memory leak test

10,000 iterations of Decoder(Foo[int]) where Foo inherits collections.abc.Mapping (exercising convert_types_generic_alias on every call):

  • RSS growth: 60 KB (within noise)
  • tracemalloc: 3.4 KB growth (from test scaffolding)

No memory leak detected.

Benchmark script
import timeit
import statistics
import json
import typing

import msgspec
from msgspec import Struct

T = typing.TypeVar("T")

class Point(Struct):
    x: int
    y: int
    z: float

class Box(Struct, typing.Generic[T]):
    value: T

enc = msgspec.json.Encoder()
dec_point = msgspec.json.Decoder(Point)
dec_box_int = msgspec.json.Decoder(Box[int])
point_msg = enc.encode(Point(1, 2, 3.14))
box_msg = enc.encode(Box(value=42))

def bench_point_encode():
    p = Point(1, 2, 3.14)
    for _ in range(1000):
        enc.encode(p)

def bench_point_decode():
    for _ in range(1000):
        dec_point.decode(point_msg)

def bench_generic_struct_decode():
    for _ in range(1000):
        dec_box_int.decode(box_msg)

def bench_decoder_creation_simple():
    for _ in range(200):
        msgspec.json.Decoder(Point)

def bench_decoder_creation_generic():
    for _ in range(200):
        msgspec.json.Decoder(Box[int])

benchmarks = [
    ("point_encode", bench_point_encode),
    ("point_decode", bench_point_decode),
    ("generic_struct_decode", bench_generic_struct_decode),
    ("decoder_creation_simple", bench_decoder_creation_simple),
    ("decoder_creation_generic", bench_decoder_creation_generic),
]

# Warmup
for name, func in benchmarks:
    timeit.repeat(func, number=1, repeat=10)

# 3 full runs of 100 rounds each
all_results = {}
for run in range(3):
    for name, func in benchmarks:
        times = timeit.repeat(func, number=1, repeat=100)
        times_sorted = sorted(times)
        trimmed = times_sorted[10:-10]  # drop top/bottom 10%
        if name not in all_results:
            all_results[name] = []
        all_results[name].extend(trimmed)

results = {}
for name, times in all_results.items():
    results[name] = {
        "mean": statistics.mean(times),
        "stdev": statistics.stdev(times),
        "min": min(times),
        "p50": statistics.median(times),
    }

print(json.dumps(results, indent=2))
Verification script
import sys
import typing
import collections.abc
import dataclasses
import gc

import msgspec
from msgspec import Struct


results = []

def test(name, func):
    try:
        func()
        results.append((name, "PASS", ""))
    except Exception as e:
        results.append((name, "FAIL", str(e)))


def test_dataclass_mapping():
    @dataclasses.dataclass
    class Bar(typing.Generic[typing.T], collections.abc.Mapping[str, typing.T]):
        data: dict[str, typing.T]
        def __getitem__(self, x): return self.data[x]
        def __len__(self): return len(self.data)
        def __iter__(self): return iter(self.data)

    x = Bar(data={"x": 3})
    encoded = msgspec.msgpack.encode(x)
    decoded = msgspec.msgpack.decode(encoded, type=Bar[int])
    assert decoded == x

test("dataclass_mapping_generic", test_dataclass_mapping)


def test_struct_mapping():
    import abc
    class CombinedMeta(msgspec.structs.StructMeta, abc.ABCMeta):
        pass

    T = typing.TypeVar("T")

    class Foo(collections.abc.Mapping[str, T], Struct, typing.Generic[T], metaclass=CombinedMeta):
        data: dict[str, T]
        def __getitem__(self, x): return self.data[x]
        def __len__(self): return len(self.data)
        def __iter__(self): return iter(self.data)

    encoded = msgspec.msgpack.encode(Foo(data={"x": 1}))
    decoded = msgspec.msgpack.decode(encoded, type=Foo[int])
    assert decoded.data == {"x": 1}

    try:
        msgspec.msgpack.decode(
            msgspec.msgpack.encode(Foo(data={"x": "foo"})), type=Foo[int]
        )
        assert False, "Should have raised ValidationError"
    except msgspec.ValidationError:
        pass

test("struct_mapping_generic", test_struct_mapping)


def test_typing_mapping():
    T = typing.TypeVar("T")

    @dataclasses.dataclass
    class Bar(typing.Generic[T], typing.Mapping[str, T]):
        data: dict[str, T]
        def __getitem__(self, x): return self.data[x]
        def __len__(self): return len(self.data)
        def __iter__(self): return iter(self.data)

    x = Bar(data={"x": 3})
    encoded = msgspec.msgpack.encode(x)
    decoded = msgspec.msgpack.decode(encoded, type=Bar[int])
    assert decoded == x

test("typing_mapping_generic", test_typing_mapping)


if sys.version_info >= (3, 12):
    def test_typevar_syntax():
        code = '''
from msgspec import Struct
class Ex[T](Struct):
    x: T
    y: list[T]
'''
        ns = {}
        exec(code, ns)
        Ex = ns["Ex"]
        msg = msgspec.json.encode(Ex(1, [1, 2]))
        res = msgspec.json.decode(msg, type=Ex[int])
        assert res.x == 1 and res.y == [1, 2]

    test("pep695_typevar_syntax", test_typevar_syntax)


def test_refcount_cache():
    T = typing.TypeVar("T")

    @dataclasses.dataclass
    class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
        data: dict[str, T]
        def __getitem__(self, x): return self.data[x]
        def __len__(self): return len(self.data)
        def __iter__(self): return iter(self.data)

    typ = Foo[int]
    dec1 = msgspec.json.Decoder(typ)
    dec2 = msgspec.json.Decoder(typ)
    msg = msgspec.json.encode(Foo(data={"a": 1}))
    r1 = dec1.decode(msg)
    r2 = dec2.decode(msg)
    assert r1 == r2
    del dec1, dec2
    gc.collect()

test("refcount_cache", test_refcount_cache)


def test_cache_persistence():
    T = typing.TypeVar("T")

    @dataclasses.dataclass
    class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
        data: dict[str, T]
        def __getitem__(self, x): return self.data[x]
        def __len__(self): return len(self.data)
        def __iter__(self): return iter(self.data)

    typ = Foo[int]
    msg = msgspec.msgpack.encode(Foo(data={"a": 1}))
    for i in range(10):
        dec = msgspec.msgpack.Decoder(typ)
        result = dec.decode(msg)
        assert result.data == {"a": 1}

test("cache_persistence", test_cache_persistence)

for name, status, detail in results:
    marker = "OK" if status == "PASS" else "FAIL"
    print(f"[{marker}] {name}" + (f" - {detail}" if detail else ""))
print(f"\n{len(results)} tests, {sum(1 for _, s, _ in results if s == 'FAIL')} failed")
Memory leak test script
import gc
import sys
import typing
import collections.abc
import dataclasses
import tracemalloc

import msgspec
from msgspec import Struct

T = typing.TypeVar("T")

@dataclasses.dataclass
class Foo(typing.Generic[T], collections.abc.Mapping[str, T]):
    data: dict[str, T]
    def __getitem__(self, x): return self.data[x]
    def __len__(self): return len(self.data)
    def __iter__(self): return iter(self.data)

tracemalloc.start()
gc.collect()

with open("/proc/self/status") as f:
    rss_before = int([l for l in f if l.startswith("VmRSS:")][0].split()[1])

snap1 = tracemalloc.take_snapshot()

for i in range(10000):
    dec = msgspec.json.Decoder(Foo[int])
    msg = msgspec.json.encode(Foo(data={"a": 1}))
    dec.decode(msg)
    del dec

gc.collect()

with open("/proc/self/status") as f:
    rss_after = int([l for l in f if l.startswith("VmRSS:")][0].split()[1])

snap2 = tracemalloc.take_snapshot()

print(f"RSS diff: {rss_after - rss_before} KB")
for stat in snap2.compare_to(snap1, "lineno")[:5]:
    print(f"  {stat}")

Development

Successfully merging this pull request may close these issues.

Types inheriting from GenericAlias types (such as Mapping) cannot be decoded.

3 participants