Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions src/msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ typedef struct {
PyObject *typing_final;
PyObject *typing_generic;
PyObject *typing_generic_alias;
PyObject *types_generic_alias;
PyObject *typing_annotated_alias;
PyObject *concrete_types;
PyObject *get_type_hints;
Expand Down Expand Up @@ -4954,6 +4955,49 @@ is_dataclass_or_attrs_class(TypeNodeCollectState *state, PyObject *t) {
);
}

static MS_INLINE PyObject*
convert_types_generic_alias(TypeNodeCollectState *state, PyObject *obj, PyObject *origin, PyObject *args) {
// if 'obj' is a 'types.GenericAlias', convert it into a 'typing._GenericAlias', so
// we can cache type info on it. 'types.GenericAlias' has __slots__, so caching on
// it directly does not work.
// it's unlikely to hit this case, as it will mostly occur when subclassing a
// built-in container generic, such as 'collections.abc.Mapping'

if (MS_UNLIKELY(Py_TYPE(obj) == (PyTypeObject *)state->mod->types_generic_alias)) {
// subscribed typing._GenericAlias instances are cached within the typing module
// we make use of this fact, by storing a __msgspec_cache__ attribute on the
// subscribed instance. only subscribed types are cache, so
// 'typing._GenericAlias(list, int) is typing._GenericAlias(list, int)' would be
// false.
// to achieve the same behaviour when re-creating a typing._GenericAlias from a
// types.GenericAlias, we first construct a temporary *unbound*
// typing._GenericAlias, on which we then call __getattr__. effectively doing
// typing._GenericAlias(list, T)[int], for which
// 'typing._GenericAlias(list, T)[int] is typing._GenericAlias(list, T)[int]'
// holds true
PyObject *params = PyObject_GetAttrString(origin, "__parameters__");
if (params == NULL) {
Py_DECREF(origin);
return NULL;
}

// create a new typing._GenericAlias with the unbound type params of the
// original types.GenericAlias.
// given a Mapping[str, int], this would produce a _GenericAlias(Mapping, (~K, ~V))
PyObject *new_alias = PyObject_CallFunctionObjArgs(state->mod->typing_generic_alias, origin, params, NULL);
if (new_alias == NULL) {
return NULL;
}

// bind it to the concrete types.
// given a _GenericAlias(Mapping, (~K, ~V)), produce a Mapping[str, int] again
PyObject *result = PyObject_CallMethod(new_alias, "__getitem__", "O", args);
Py_DECREF(new_alias);
return result;
}
return obj;
}

static int
typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
int out = 0;
Expand Down Expand Up @@ -5035,7 +5079,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
ms_is_struct_cls(t) ||
(origin != NULL && ms_is_struct_cls(origin))
) {
out = typenode_collect_struct(state, t);
out = typenode_collect_struct(state, convert_types_generic_alias(state, t, origin, args));
}
else if (PyType_IsSubtype(Py_TYPE(t), state->mod->EnumMetaType)) {
out = typenode_collect_enum(state, t);
Expand Down Expand Up @@ -5143,7 +5187,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
is_dataclass_or_attrs_class(state, t) ||
(origin != NULL && is_dataclass_or_attrs_class(state, origin))
) {
out = typenode_collect_dataclass(state, t);
out = typenode_collect_dataclass(state, convert_types_generic_alias(state, t, origin, args));
}
else {
if (origin != NULL) {
Expand Down Expand Up @@ -22299,6 +22343,7 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->typing_final);
Py_CLEAR(st->typing_generic);
Py_CLEAR(st->typing_generic_alias);
Py_CLEAR(st->types_generic_alias);
Py_CLEAR(st->typing_annotated_alias);
Py_CLEAR(st->concrete_types);
Py_CLEAR(st->get_type_hints);
Expand Down Expand Up @@ -22371,6 +22416,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(st->typing_final);
Py_VISIT(st->typing_generic);
Py_VISIT(st->typing_generic_alias);
Py_VISIT(st->types_generic_alias);
Py_VISIT(st->typing_annotated_alias);
Py_VISIT(st->concrete_types);
Py_VISIT(st->get_type_hints);
Expand Down Expand Up @@ -22603,6 +22649,7 @@ PyInit__core(void)
temp_module = PyImport_ImportModule("types");
if (temp_module == NULL) return NULL;
SET_REF(types_uniontype, "UnionType");
SET_REF(types_generic_alias, "GenericAlias");
Py_DECREF(temp_module);

/* Get the EnumMeta type */
Expand Down
28 changes: 23 additions & 5 deletions src/msgspec/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore
import collections
import sys
import types
import typing
from typing import _AnnotatedAlias # noqa: F401

Expand All @@ -22,6 +23,8 @@ def get_type_hints(obj):
return _get_type_hints(obj, include_extras=True)


PY_31PLUS = sys.version_info >= (3, 12)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in variable name: PY_31PLUS reads as "Python 3.1+", but the comparison is >= (3, 12). Should be PY_312PLUS (consistent with PY312_PLUS used in _core.c).


# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
# check is to try it and see. This check can be removed when we drop support
Expand Down Expand Up @@ -89,13 +92,23 @@ def inner(c, scope):
cls = c
new_scope = {}
else:
cls = getattr(c, "__origin__", None)
cls = typing.get_origin(c)
if cls in (None, object, typing.Generic) or cls in mapping:
return
params = cls.__parameters__
args = tuple(_apply_params(a, scope) for a in c.__args__)
assert len(params) == len(args)
mapping[cls] = new_scope = dict(zip(params, args))

# it's a built-in generic that has unresolved type vars. in this case,
# parameters and args are stored on the generic, not the __origin__
if isinstance(c, types.GenericAlias) or (
isinstance(c, typing._GenericAlias)
and not hasattr(cls, "__parameters__")
):
new_scope = dict(zip(c.__parameters__, typing.get_args(c)))
else:
params = cls.__parameters__
args = tuple(_apply_params(a, scope) for a in typing.get_args(c))
assert len(params) == len(args)
new_scope = dict(zip(params, args))
mapping[cls] = new_scope

if issubclass(cls, typing.Generic):
bases = getattr(cls, "__orig_bases__", cls.__bases__)
Expand Down Expand Up @@ -133,6 +146,11 @@ def get_class_annotations(obj):

mapping = typevar_mappings.get(cls)
cls_locals = dict(vars(cls))

if PY_31PLUS:
# resolve type parameters (e.g. class Foo[T]: pass)
cls_locals.update({p.__name__: p for p in cls.__type_params__})

cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})

ann = _get_class_annotations(cls)
Expand Down
Loading
Loading