Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions src/msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ typedef struct {
PyObject *get_class_annotations;
PyObject *get_typeddict_info;
PyObject *get_dataclass_info;
PyObject *convert_generic_alias;
PyObject *rebuild;
PyObject *types_uniontype;
#if PY312_PLUS
Expand Down Expand Up @@ -4984,6 +4985,29 @@ is_dataclass_or_attrs_class(TypeNodeCollectState *state, PyObject *t) {
);
}

static MS_INLINE int
normalize_types_generic_alias(TypeNodeCollectState *state, PyObject **t, PyObject *origin, PyObject *args) {
// if '*t' is a 'types.GenericAlias', replace it (in place) with an equivalent
// 'typing._GenericAlias'. 'types.GenericAlias' has __slots__ and forwards
// attribute access to its origin, so we can't cache type info (as
// '__msgspec_cache__') on it; a 'typing._GenericAlias' can.
//
// only the branches that cache on the alias (struct/dataclass/TypedDict/
// NamedTuple/Literal) need to call this. a 'types.GenericAlias' here only arises
// when subclassing a builtin container generic (e.g. 'collections.abc.Mapping') or
// from a manually constructed 'types.GenericAlias' (e.g. wrapping 'typing.Literal').
//
// replace '*t' with a new reference on success (dropping the original), or
// return -1 with an exception set.
if (MS_LIKELY(Py_TYPE(*t) != &Py_GenericAliasType)) return 0;
PyObject *converted = PyObject_CallFunctionObjArgs(
state->mod->convert_generic_alias, origin, args, NULL
);
if (converted == NULL) return -1;
Py_SETREF(*t, converted);
return 0;
}

static int
typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
int out = 0;
Expand Down Expand Up @@ -5065,7 +5089,12 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
ms_is_struct_cls(t) ||
(origin != NULL && ms_is_struct_cls(origin))
) {
out = typenode_collect_struct(state, t);
if (normalize_types_generic_alias(state, &t, origin, args) < 0) {
out = -1;
}
else {
out = typenode_collect_struct(state, t);
}
}
else if (PyType_IsSubtype(Py_TYPE(t), state->mod->EnumMetaType)) {
out = typenode_collect_enum(state, t);
Expand Down Expand Up @@ -5155,25 +5184,45 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
state->literals = PyList_New(0);
if (state->literals == NULL) goto done;
}
out = PyList_Append(state->literals, t);
if (normalize_types_generic_alias(state, &t, origin, args) < 0) {
out = -1;
}
else {
out = PyList_Append(state->literals, t);
}
}
else if (
is_typeddict_class(state, t) ||
(origin != NULL && is_typeddict_class(state, origin))
) {
out = typenode_collect_typeddict(state, t);
if (normalize_types_generic_alias(state, &t, origin, args) < 0) {
out = -1;
}
else {
out = typenode_collect_typeddict(state, t);
}
}
else if (
is_namedtuple_class(state, t) ||
(origin != NULL && is_namedtuple_class(state, origin))
) {
out = typenode_collect_namedtuple(state, t);
if (normalize_types_generic_alias(state, &t, origin, args) < 0) {
out = -1;
}
else {
out = typenode_collect_namedtuple(state, t);
}
}
else if (
is_dataclass_or_attrs_class(state, t) ||
(origin != NULL && is_dataclass_or_attrs_class(state, origin))
) {
out = typenode_collect_dataclass(state, t);
if (normalize_types_generic_alias(state, &t, origin, args) < 0) {
out = -1;
}
else {
out = typenode_collect_dataclass(state, t);
}
}
else {
if (origin != NULL) {
Expand Down Expand Up @@ -22428,6 +22477,7 @@ msgspec_clear(PyObject *m)
Py_CLEAR(st->get_dataclass_info);
Py_CLEAR(st->rebuild);
Py_CLEAR(st->types_uniontype);
Py_CLEAR(st->convert_generic_alias);
#if PY312_PLUS
Py_CLEAR(st->typing_typealiastype);
#endif
Expand Down Expand Up @@ -22501,6 +22551,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
Py_VISIT(st->get_dataclass_info);
Py_VISIT(st->rebuild);
Py_VISIT(st->types_uniontype);
Py_VISIT(st->convert_generic_alias);
#if PY312_PLUS
Py_VISIT(st->typing_typealiastype);
#endif
Expand Down Expand Up @@ -22702,6 +22753,7 @@ PyInit__core(void)
SET_REF(get_dataclass_info, "get_dataclass_info");
SET_REF(typing_annotated_alias, "_AnnotatedAlias");
SET_REF(rebuild, "rebuild");
SET_REF(convert_generic_alias, "convert_generic_alias");
Py_DECREF(temp_module);

temp_module = PyImport_ImportModule("types");
Expand Down
79 changes: 71 additions & 8 deletions src/msgspec/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# type: ignore
import collections
import sys
import types
import typing
from typing import _AnnotatedAlias # noqa: F401
from typing import _AnnotatedAlias, _GenericAlias # noqa: F401

try:
from typing_extensions import get_type_hints as _get_type_hints
Expand All @@ -22,6 +23,8 @@ def get_type_hints(obj):
return _get_type_hints(obj, include_extras=True)


PY_312PLUS = sys.version_info >= (3, 12)

# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
# check is to try it and see. This check can be removed when we drop support
Expand Down Expand Up @@ -79,23 +82,42 @@ def _apply_params(obj, mapping):
def _get_class_mro_and_typevar_mappings(obj):
mapping = {}

if isinstance(obj, type):
# in Python 3.10 a natively produced 'types.GenericAlias' (e.g. 'list[int]', or the
# 'Base[int]' produced when a 'Generic' subclass inherits a builtin's
# '__class_getitem__') satisfies 'isinstance(_, type)', unlike on 3.11+. we still
# want to treat those as parametrised aliases, not bare classes
if isinstance(obj, type) and not isinstance(obj, types.GenericAlias):
cls = obj
else:
cls = obj.__origin__

def inner(c, scope):
if isinstance(c, type):
if isinstance(c, type) and not isinstance(c, types.GenericAlias):
cls = c
new_scope = {}
else:
cls = getattr(c, "__origin__", None)
cls = typing.get_origin(c)
if cls in (None, object, typing.Generic) or cls in mapping:
return
params = cls.__parameters__
args = tuple(_apply_params(a, scope) for a in c.__args__)
assert len(params) == len(args)
mapping[cls] = new_scope = dict(zip(params, args))

if hasattr(cls, "__parameters__"):
# 'cls' carries its own type vars. This covers both ordinary
# 'typing._GenericAlias' bases and the 'types.GenericAlias' that
# get produced when a user-defined 'Generic' subclass inherits a
# builtin's '__class_getitem__' (e.g. 'class Base(Mapping[str, T])',
# whose 'Base[...]' is a 'types.GenericAlias'). Map cls's own
# type vars onto the resolved args, applying '_apply_params' so any
# outer bindings in 'scope' (e.g. 'U -> int') are propagated.
params = cls.__parameters__
args = tuple(_apply_params(a, scope) for a in typing.get_args(c))
assert len(params) == len(args)
new_scope = dict(zip(params, args))
else:
# a true built-in generic (e.g. 'collections.abc.Mapping[str, T]')
# whose '__origin__' has no '__parameters__'; the unresolved type
# vars and args live on the alias itself, not the origin.
new_scope = dict(zip(c.__parameters__, typing.get_args(c)))
mapping[cls] = new_scope

if issubclass(cls, typing.Generic):
bases = getattr(cls, "__orig_bases__", cls.__bases__)
Expand Down Expand Up @@ -133,6 +155,11 @@ def get_class_annotations(obj):

mapping = typevar_mappings.get(cls)
cls_locals = dict(vars(cls))

if PY_312PLUS:
# resolve type parameters (e.g. class Foo[T]: pass)
cls_locals.update({p.__name__: p for p in cls.__type_params__})

try:
cls_module = cls.__module__
except AttributeError:
Expand Down Expand Up @@ -315,3 +342,39 @@ def inner(obj):
def rebuild(cls, kwargs):
"""Used to unpickle Structs with keyword-only fields"""
return cls(**kwargs)


def convert_generic_alias(origin, args):
# subscribed typing._GenericAlias instances are cached within the typing module
# we make use of this fact, by storing a __msgspec_cache__ attribute on the
# subscribed instance. only subscribed types are cached, so
# 'typing._GenericAlias(list, int) is typing._GenericAlias(list, int)' would be
# false.
# to achieve the same behaviour when re-creating a typing._GenericAlias from a
# types.GenericAlias, we first construct a temporary *unbound*
# typing._GenericAlias, on which we then call __getattr__. effectively doing
# typing._GenericAlias(list, T)[int], for which
# 'typing._GenericAlias(list, T)[int] is typing._GenericAlias(list, T)[int]'
# holds true
try:
params = origin.__parameters__
except AttributeError:
if not isinstance(origin, type):
# a special form such as 'typing.Literal', whose args are values rather than
# type parameters. Ordinary subscription yields the canonical, interned
# alias of the right subclass (e.g. 'typing.Literal[...]' -> '_LiteralGenericAlias').
return origin[args]

# a non-generic class with type arguments. only reachable for e.g.
# manually-built 'types.GenericAlias' instances and is probably nonsense or at
# least not somthing we can meaningfully represent, so complain about it here,
# rather than silently dropping the arguments.
raise TypeError(f"{origin.__name__!r} is not a generic type")

# a regularly-parametrised generic. Create a new typing._GenericAlias with the
# origin's unbound type params (e.g. for a 'Mapping[str, int]' this is a
# '_GenericAlias(Mapping, (~K, ~V))'), then bind it to the concrete args by
# subscripting with the args *tuple* (i.e. 'alias[(int, str)]', not
# 'alias[int, str]'), so generics with more than one type var work
alias = _GenericAlias(origin, params)
return alias[args]
Loading
Loading