Skip to content

Commit b921130

Browse files
committed
Do not leak type variables in type[T]
1 parent 4434f73 commit b921130

4 files changed

Lines changed: 43 additions & 15 deletions

File tree

mypy/constraints.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,11 @@ def infer_constraints_for_callable(
275275

276276

277277
def infer_constraints(
278-
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
278+
template: Type,
279+
actual: Type,
280+
direction: int,
281+
skip_neg_op: bool = False,
282+
erase_types: bool = True,
279283
) -> list[Constraint]:
280284
"""Infer type constraints.
281285
@@ -312,14 +316,14 @@ def infer_constraints(
312316
# Return early on an empty branch.
313317
return []
314318
type_state.inferring.append((template, actual))
315-
res = _infer_constraints(template, actual, direction, skip_neg_op)
319+
res = _infer_constraints(template, actual, direction, skip_neg_op, erase_types)
316320
type_state.inferring.pop()
317321
return res
318-
return _infer_constraints(template, actual, direction, skip_neg_op)
322+
return _infer_constraints(template, actual, direction, skip_neg_op, erase_types)
319323

320324

321325
def _infer_constraints(
322-
template: Type, actual: Type, direction: int, skip_neg_op: bool
326+
template: Type, actual: Type, direction: int, skip_neg_op: bool, erase_types: bool
323327
) -> list[Constraint]:
324328
orig_template = template
325329
template = get_proper_type(template)
@@ -424,7 +428,7 @@ def _infer_constraints(
424428
return []
425429

426430
# Remaining cases are handled by ConstraintBuilderVisitor.
427-
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))
431+
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op, erase_types))
428432

429433

430434
def _is_type_type(tp: ProperType) -> TypeGuard[TypeType | UnionType]:
@@ -659,14 +663,20 @@ class ConstraintBuilderVisitor(TypeVisitor[list[Constraint]]):
659663
# TODO: The value may be None. Is that actually correct?
660664
actual: ProperType
661665

662-
def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
666+
def __init__(
667+
self, actual: ProperType, direction: int, skip_neg_op: bool, erase_types: bool
668+
) -> None:
663669
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
664670
self.actual = actual
665671
self.direction = direction
666672
# Whether to skip polymorphic inference (involves inference in opposite direction)
667673
# this is used to prevent infinite recursion when both template and actual are
668674
# generic callables.
669675
self.skip_neg_op = skip_neg_op
676+
# Normally we should erase generic actual type when inferring against type[T]
677+
# to avoid leaking type variables, see testGenericClassAsArgumentToType.
678+
# The only exception is self-types in generic classes, where we set this to False.
679+
self.erase_types = erase_types
670680

671681
# Trivial leaf types
672682

@@ -1376,15 +1386,17 @@ def visit_overloaded(self, template: Overloaded) -> list[Constraint]:
13761386
def visit_type_type(self, template: TypeType) -> list[Constraint]:
13771387
if isinstance(self.actual, CallableType):
13781388
if self.actual.is_type_obj():
1379-
return infer_constraints(
1380-
template.item, self.actual.get_instance_type(), self.direction
1381-
)
1389+
instance_type = self.actual.get_instance_type()
1390+
if self.erase_types:
1391+
instance_type = erase_typevars(instance_type)
1392+
return infer_constraints(template.item, instance_type, self.direction)
13821393
return infer_constraints(template.item, self.actual.ret_type, self.direction)
13831394
elif isinstance(self.actual, Overloaded):
13841395
if self.actual.is_type_obj():
1385-
return infer_constraints(
1386-
template.item, self.actual.items[0].get_instance_type(), self.direction
1387-
)
1396+
instance_type = self.actual.items[0].get_instance_type()
1397+
if self.erase_types:
1398+
instance_type = erase_typevars(instance_type)
1399+
return infer_constraints(template.item, instance_type, self.direction)
13881400
return infer_constraints(template.item, self.actual.items[0].ret_type, self.direction)
13891401
elif isinstance(self.actual, TypeType):
13901402
return infer_constraints(template.item, self.actual.item, self.direction)

mypy/infer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,11 @@ def infer_type_arguments(
7070
actual: Type,
7171
is_supertype: bool = False,
7272
skip_unsatisfied: bool = False,
73+
erase_types: bool = True,
7374
) -> list[Type | None]:
7475
# Like infer_function_type_arguments, but only match a single type
7576
# against a generic type.
76-
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF)
77+
constraints = infer_constraints(
78+
template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF, erase_types=erase_types
79+
)
7780
return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0]

mypy/typeops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ class B(A): pass
495495

496496
# Solve for these type arguments using the actual class or instance type.
497497
typeargs = infer_type_arguments(
498-
self_vars, self_param_type, original_type, is_supertype=True
498+
self_vars, self_param_type, original_type, is_supertype=True, erase_types=False
499499
)
500500
if (
501501
is_classmethod
@@ -504,7 +504,11 @@ class B(A): pass
504504
):
505505
# In case we call a classmethod through an instance x, fallback to type(x).
506506
typeargs = infer_type_arguments(
507-
self_vars, self_param_type, TypeType(original_type), is_supertype=True
507+
self_vars,
508+
self_param_type,
509+
TypeType(original_type),
510+
is_supertype=True,
511+
erase_types=False,
508512
)
509513

510514
# Update the method signature with the solutions found.

test-data/unit/check-generics.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3688,3 +3688,12 @@ reveal_type(ok3) # N: Revealed type is "tuple[()]"
36883688
bad1: list[()] = [] # E: "list" expects 1 type argument, but none given \
36893689
# E: Missing type arguments for generic type "list"
36903690
[builtins fixtures/tuple.pyi]
3691+
3692+
[case testGenericClassAsArgumentToType]
3693+
from typing import TypeVar, Generic
3694+
3695+
T = TypeVar("T")
3696+
def test(tp: type[T]) -> T: ...
3697+
3698+
class C(Generic[T]): ...
3699+
reveal_type(test(C)) # N: Revealed type is "__main__.C[Any]"

0 commit comments

Comments
 (0)