
Commit c7566b6

codegen: emit guards for every anyOf variant to fix mypy union-attr on array-containing unions (#182)
Why
===

The encoder generator for non-discriminated anyOf unions emits a chain of ternary expressions, with the last variant historically rendered as the unguarded `else` branch. That works for simple unions like `object | str | list` (mypy can negative-narrow `x` to `list` in the final branch), but it breaks for deeper unions where the array variant is last, e.g.

```
str | float | bool | list[scalar] | None
```

When mypy fails to fully narrow `x` to `list[...]` through the prior `isinstance` checks (`isinstance(x, (int, float))` plus `bool` subclassing `int` make this tricky), it complains that scalar items of the union have no `__iter__` attribute:

```
error: Item "float" of "str | float | bool | None | list[...]" has no attribute "__iter__" (not iterable)  [union-attr]
error: Item "bool" of ... has no attribute "__iter__"  [union-attr]
error: Item "object" of ... has no attribute "__iter__"  [union-attr]
```

This is the exact failure that has been blocking ai-infra's `codegen-latest-pid2-schema.yml` auto-update workflow since 2026-05-04, when repl-it-web#78355 widened `agentToolPostgreSQL.executeSqlCommand.params` from a flat scalar union to `array<scalar | array<scalar>>`. Every run since has failed on the regenerated `executeSqlCommand.py` at the `for y in x` iteration inside `encode_ExecutesqlcommandInputParams`. The committed pid2 client in ai-infra has been kept current by hand (see replit/ai-infra#12813), but the bot has been red for ~2.5 weeks.

What changed
============

`src/replit_river/codegen/client.py`: in the non-discriminated-anyOf branch of `encode_type`, emit an explicit `isinstance` / `is None` guard for every entry in `encoder_parts`, including the last one, and append a `cast(Any, x)` fallback. mypy no longer has to negative-narrow into the iterating branch, so deep unions with an array variant lint cleanly. `Any` and `cast` are already part of `FILE_HEADER` so no import bookkeeping changes.

Concretely, for the failing executeSqlCommand schema, the encoder now ends with:

```python
return (
    x
    if isinstance(x, str)
    else x
    if isinstance(x, (int, float))
    else x
    if isinstance(x, bool)
    else None
    if x is None
    else [encode_..._AnyOf_4(y) for y in x]
    if isinstance(x, list)
    else cast(Any, x)
)
```

Test plan
=========

- Existing `tests/v1/codegen/snapshot/test_anyof_mixed.py` snapshot updated to show the new `if isinstance(x, list) else cast(Any, x)` tail on its `obj | str | list[str]` encoder (the change is additive; the runtime behavior is unchanged).
- New snapshot test `tests/v1/codegen/snapshot/test_anyof_array_in_union.py` added with a schema that mirrors `executeSqlCommand.params` (`array<scalar | array<scalar>>`) and locks in the fixed output. This is the regression test for ai-infra's CI failure.
- `uv run pytest` is green (67 passed, including all v1 and v2 codegen tests).
- `make lint` is clean apart from a pre-existing `pyright` `grpc` import error in `tests/v1/test_communication.py` that also fails on `main` (unrelated).
- End-to-end verification against ai-infra: pointed ai-infra's `./pkgs/pid2_client/scripts/generate.sh` at this branch via `RIVER_CODEGEN_PATH=/tmp/opencode/river-python` and reran the full lint pipeline that the auto-update workflow runs in CI; `[mypy] completed in 15.19s` and the script exited `OK.` instead of the historical `union-attr` failure.

Once this is released (e.g. `v0.17.20`) ai-infra can bump `replit-river` in `pkgs/pid2_client/pyproject.toml` and the auto-update workflow will start producing green PRs again.

~ written by Zerg 👾 ([ascendant-goliath-6d2f](https://zerg.zergrush.dev/chat?id=ascendant-goliath-6d2f))
1 parent 008e539 commit c7566b6

7 files changed

Lines changed: 246 additions & 8 deletions


src/replit_river/codegen/client.py

Lines changed: 27 additions & 8 deletions
@@ -457,16 +457,35 @@ def {_field_name}(
                 case _:
                     encoder_parts.append((None, "x"))
 
-        # Build the ternary chain from encoder_parts
+        # Build the ternary chain from encoder_parts.
+        #
+        # Every entry that has a `type_check` (isinstance / `x is None`) gets
+        # its own guard, including the last one. Falling off the end means the
+        # input did not match any declared anyOf variant, which should not
+        # happen for a well-formed value; we emit a `cast(Any, x)` so mypy
+        # doesn't try to narrow the value through the chain.
+        #
+        # Previously the last entry was emitted unconditionally as the `else`
+        # branch. That works for simple unions (object | str | list), but
+        # breaks down when the last variant's encoder requires iteration
+        # (e.g. `[encode_X(y) for y in x]` when the variant is an array) and
+        # mypy fails to fully narrow `x` through the prior `isinstance`
+        # checks. The unguarded final branch then triggers `union-attr`
+        # errors like "Item 'float' has no attribute '__iter__'".
         typeddict_encoder = list[str]()
-        for i, (type_check, encoder_expr) in enumerate(encoder_parts):
-            is_last = i == len(encoder_parts) - 1
-            if is_last or type_check is None:
-                # Last item or no type check - just the expression
+        has_unguarded_terminal = False
+        for type_check, encoder_expr in encoder_parts:
+            if type_check is None:
+                # No type check available — emit the bare expression and stop;
+                # nothing after it could be reached anyway.
                 typeddict_encoder.append(encoder_expr)
-            else:
-                # Add expression with type check
-                typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
+                has_unguarded_terminal = True
+                break
+            typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
+        if not has_unguarded_terminal and encoder_parts:
+            # Unreachable in practice (every declared variant was guarded
+            # above), but mypy needs a concrete final expression.
+            typeddict_encoder.append("cast(Any, x)")
         if permit_unknown_members:
             union = _make_open_union_type_expr(any_of)
         else:
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# Code generated by river.codegen. DO NOT EDIT.
+from pydantic import BaseModel
+from typing import Literal
+
+import replit_river as river
+
+
+from .test_service import Test_ServiceService
+
+
+class AnyOfArrayInUnionClient:
+    def __init__(self, client: river.Client[Literal[None]]):
+        self.test_service = Test_ServiceService(client)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .exec_sql_method import (
13+
Exec_Sql_MethodInput,
14+
Exec_Sql_MethodOutput,
15+
Exec_Sql_MethodOutputTypeAdapter,
16+
encode_Exec_Sql_MethodInput,
17+
encode_Exec_Sql_MethodInputParams,
18+
)
19+
20+
21+
class Test_ServiceService:
22+
def __init__(self, client: river.Client[Any]):
23+
self.client = client
24+
25+
async def exec_sql_method(
26+
self,
27+
input: Exec_Sql_MethodInput,
28+
timeout: datetime.timedelta,
29+
) -> Exec_Sql_MethodOutput:
30+
return await self.client.send_rpc(
31+
"test_service",
32+
"exec_sql_method",
33+
input,
34+
encode_Exec_Sql_MethodInput,
35+
lambda x: Exec_Sql_MethodOutputTypeAdapter.validate_python(
36+
x # type: ignore[arg-type]
37+
),
38+
lambda x: RiverErrorTypeAdapter.validate_python(
39+
x # type: ignore[arg-type]
40+
),
41+
timeout,
42+
)
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Code generated by river.codegen. DO NOT EDIT.
+from collections.abc import AsyncIterable, AsyncIterator
+import datetime
+from typing import (
+    Any,
+    Literal,
+    Mapping,
+    NotRequired,
+    TypedDict,
+    cast,
+)
+from typing_extensions import Annotated
+
+from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
+from replit_river.error_schema import RiverError
+from replit_river.client import (
+    RiverUnknownError,
+    translate_unknown_error,
+    RiverUnknownValue,
+    translate_unknown_value,
+)
+
+import replit_river as river
+
+
+Exec_Sql_MethodInputParamsAnyOf_4 = str | float | bool | None
+
+
+def encode_Exec_Sql_MethodInputParamsAnyOf_4(
+    x: "Exec_Sql_MethodInputParamsAnyOf_4",
+) -> Any:
+    return x
+
+
+Exec_Sql_MethodInputParams = (
+    str | float | bool | list[Exec_Sql_MethodInputParamsAnyOf_4] | None
+)
+
+
+def encode_Exec_Sql_MethodInputParams(x: "Exec_Sql_MethodInputParams") -> Any:
+    return (
+        x
+        if isinstance(x, str)
+        else x
+        if isinstance(x, (int, float))
+        else x
+        if isinstance(x, bool)
+        else None
+        if x is None
+        else [encode_Exec_Sql_MethodInputParamsAnyOf_4(y) for y in x]
+        if isinstance(x, list)
+        else cast(Any, x)
+    )
+
+
+def encode_Exec_Sql_MethodInput(
+    x: "Exec_Sql_MethodInput",
+) -> Any:
+    return {
+        k: v
+        for (k, v) in (
+            {
+                "params": [encode_Exec_Sql_MethodInputParams(y) for y in x["params"]]
+                if "params" in x and x["params"] is not None
+                else None,
+            }
+        ).items()
+        if v is not None
+    }
+
+
+class Exec_Sql_MethodInput(TypedDict):
+    params: NotRequired[list[Exec_Sql_MethodInputParams] | None]
+
+
+class Exec_Sql_MethodOutput(BaseModel):
+    ok: bool
+
+
+Exec_Sql_MethodOutputTypeAdapter: TypeAdapter[Exec_Sql_MethodOutput] = TypeAdapter(
+    Exec_Sql_MethodOutput
+)
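A side note on guard order in the snapshot above: because `bool` is a subclass of `int` in Python, `True` and `False` are already caught by the `isinstance(x, (int, float))` guard, so the later `isinstance(x, bool)` guard never fires at runtime. Both branches return `x` unchanged, so the encoded output is unaffected. The sketch below walks the guards in the same order; `first_matching_guard` is a hypothetical helper written for illustration only.

```python
def first_matching_guard(x: object) -> str:
    """Return a label for the first guard, in the generated order, that matches x."""
    if isinstance(x, str):
        return "str"
    if isinstance(x, (int, float)):
        return "int/float"
    if isinstance(x, bool):
        return "bool"  # never reached: bool values match the guard above
    if x is None:
        return "None"
    if isinstance(x, list):
        return "list"
    return "fallback"


samples: list[object] = ["a", 1.5, True, None, [1], (1,)]
labels = {repr(v): first_matching_guard(v) for v in samples}
```

A tuple falls through to the `cast(Any, x)` fallback label, which is the branch the commit describes as not expected for well-formed input.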

tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ def encode_Anyof_Mixed_MethodInputRun_Command(
         else x
         if isinstance(x, str)
         else list(x)
+        if isinstance(x, list)
+        else cast(Any, x)
     )
 
 
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from pytest_snapshot.plugin import Snapshot
+
+from tests.fixtures.codegen_snapshot_fixtures import validate_codegen
+
+
+async def test_anyof_array_in_union(snapshot: Snapshot) -> None:
+    """Test codegen for an array field whose item type is a non-discriminated
+    anyOf union that itself contains an `array` variant.
+
+    Concretely this mirrors the PostgreSQL `executeSqlCommand.params` schema:
+    `array<scalar | array<scalar>>`. The inner union encoder ends in an
+    iteration over `x` (for the array variant), and historically that branch
+    was emitted as the unguarded `else` of a ternary chain. When mypy failed
+    to fully narrow `x` to `list[...]` through the preceding `isinstance`
+    checks, it complained that scalar items of the union have no
+    `__iter__` attribute (`union-attr`).
+
+    The fix emits an explicit `isinstance(x, list)` guard for the array
+    branch and a `cast(Any, x)` fallback, so mypy never has to negative-
+    narrow into the iterating branch.
+    """
+    validate_codegen(
+        snapshot=snapshot,
+        snapshot_dir="tests/v1/codegen/snapshot/snapshots",
+        read_schema=lambda: open(
+            "tests/v1/codegen/types/anyof_array_in_union_schema.json"
+        ),
+        target_path="test_anyof_array_in_union",
+        client_name="AnyOfArrayInUnionClient",
+        protocol_version="v1.1",
+    )
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+{
+  "services": {
+    "test_service": {
+      "procedures": {
+        "exec_sql_method": {
+          "input": {
+            "type": "object",
+            "properties": {
+              "params": {
+                "description": "Parameterized query values. Each entry is either a scalar or an array of scalars (for ANY($1::text[]) etc.).",
+                "type": "array",
+                "items": {
+                  "anyOf": [
+                    { "type": "string" },
+                    { "type": "number" },
+                    { "type": "boolean" },
+                    { "type": "null" },
+                    {
+                      "type": "array",
+                      "items": {
+                        "anyOf": [
+                          { "type": "string" },
+                          { "type": "number" },
+                          { "type": "boolean" },
+                          { "type": "null" }
+                        ]
+                      }
+                    }
+                  ]
+                }
+              }
+            }
+          },
+          "output": {
+            "type": "object",
+            "properties": {
+              "ok": { "type": "boolean" }
+            },
+            "required": ["ok"]
+          },
+          "errors": {
+            "not": {}
+          },
+          "type": "rpc"
+        }
+      }
+    }
+  }
+}
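To see which values the fixture's `params` item schema admits, here is a toy matcher. It is a sketch that handles only the keywords appearing in this fixture (`anyOf`, `type`, `items`) and is not a JSON Schema validator; `matches` and `ITEM_SCHEMA` are hypothetical names, with `ITEM_SCHEMA` transcribed from `params.items` above.

```python
from typing import Any

SCALARS = [{"type": "string"}, {"type": "number"}, {"type": "boolean"}, {"type": "null"}]
# Transcription of the fixture's params.items: scalar | array<scalar>.
ITEM_SCHEMA: dict[str, Any] = {
    "anyOf": [*SCALARS, {"type": "array", "items": {"anyOf": SCALARS}}]
}


def matches(value: Any, schema: dict[str, Any]) -> bool:
    """Check value against the small schema subset used by this fixture."""
    if "anyOf" in schema:
        return any(matches(value, sub) for sub in schema["anyOf"])
    kind = schema.get("type")
    if kind == "string":
        return isinstance(value, str)
    if kind == "number":
        # JSON numbers exclude booleans, although Python's bool subclasses int.
        return isinstance(value, (int, float)) and not isinstance(value, bool)
    if kind == "boolean":
        return isinstance(value, bool)
    if kind == "null":
        return value is None
    if kind == "array":
        items = schema.get("items")
        return isinstance(value, list) and (
            items is None or all(matches(v, items) for v in value)
        )
    return False


well_formed = ["a", 1, True, None, ["b", 2.5, None]]
all_ok = all(matches(v, ITEM_SCHEMA) for v in well_formed)
```

Values outside this shape, such as a doubly nested list or a dict, are the ones that would reach the `cast(Any, x)` fallback in the generated encoder.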
