Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5960,7 +5960,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
else_map[unwrapped_subject] = else_map[named_subject]
pattern_map = self.propagate_up_typemap_info(pattern_map)
else_map = self.propagate_up_typemap_info(else_map)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.check_and_remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map, from_assignment=False)
if pattern_map:
for expr, typ in pattern_map.items():
Expand Down Expand Up @@ -6066,15 +6066,8 @@ def infer_variable_types_from_type_maps(
already_exists = True
if isinstance(expr.node, Var) and expr.node.is_final:
self.msg.cant_assign_to_final(expr.name, False, expr)
if self.check_subtype(
typ,
previous_type,
expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type",
):
inferred_types[var] = previous_type
# We'll check compatibility in check_and_remove_capture_conflicts
inferred_types[var] = previous_type

if not already_exists:
new_type = UnionType.make_union(types)
Expand All @@ -6086,15 +6079,24 @@ def infer_variable_types_from_type_maps(
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
return inferred_types

def remove_capture_conflicts(
def check_and_remove_capture_conflicts(
self, type_map: TypeMap, inferred_types: dict[SymbolNode, Type]
) -> None:
if not is_unreachable_map(type_map):
for expr, typ in list(type_map.items()):
if isinstance(expr, NameExpr):
node = expr.node
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]
if is_unreachable_map(type_map):
return
for expr, typ in list(type_map.items()):
if not isinstance(expr, NameExpr):
continue
node = expr.node
if node not in inferred_types or not self.check_subtype(
typ,
inferred_types[node],
expr,
msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE,
subtype_label="pattern captures type",
supertype_label="variable has type",
):
del type_map[expr]

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
if o.alias_node:
Expand Down
30 changes: 30 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,36 @@ reveal_type(a) # N: Revealed type is "builtins.bool"
a = 3
reveal_type(a) # N: Revealed type is "builtins.int"

[case testMatchCapturePatternAfterPreviousCase]
# flags: --strict-equality --warn-unreachable

def f1(x: int | None, y: int):
match x:
case None:
pass
case y:
reveal_type(y) # N: Revealed type is "builtins.int"

def f2(x: int | None, y: int, cond: bool):
match x:
case None if cond:
pass
case y: # E: Incompatible types in capture pattern (pattern captures type "int | None", variable has type "int")
reveal_type(y) # N: Revealed type is "builtins.int"

def f3(x: int | None, y: int):
match x:
case None if True:
pass
case y:
reveal_type(y) # N: Revealed type is "builtins.int"

match x:
case None if False:
pass # E: Statement is unreachable
case y: # E: Incompatible types in capture pattern (pattern captures type "int | None", variable has type "int")
reveal_type(y) # N: Revealed type is "builtins.int"

[case testMatchCapturePatternPreexistingIncompatible]
# flags: --strict-equality --warn-unreachable
a: str
Expand Down
Loading