Skip to content

Commit a7391bd

Browse files
committed
feat: add action decoders
1 parent 294670a commit a7391bd

3 files changed

Lines changed: 344 additions & 1 deletion

File tree

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
from typing import Any, Literal, TypeVar, get_args, get_origin # noqa: UP035
2+
3+
import msgpack
4+
import pyarrow as pa
5+
import pyarrow.flight as flight
6+
from pydantic import BaseModel, ConfigDict, Field, field_validator
7+
8+
9+
def deserialize_record_batch(cls: Any, value: Any) -> pa.RecordBatch:
    """Coerce ``value`` into a ``pyarrow.RecordBatch``.

    Intended for use as a Pydantic ``mode="before"`` field validator, which is
    why the first positional argument is the model class (unused here).

    Args:
        cls: The model class supplied by Pydantic; ignored.
        value: Either an existing ``pa.RecordBatch`` (returned unchanged) or
            ``bytes`` holding an Arrow IPC stream, from which the first record
            batch is read.

    Returns:
        The record batch.

    Raises:
        ValueError: If ``value`` is of an unsupported type, the bytes are not a
            readable IPC stream, or the stream holds no record batch.
    """
    if isinstance(value, pa.RecordBatch):
        return value

    # pa.RecordBatch cannot be built by calling the class directly, so any
    # other input type is rejected up front with an explicit message.
    if not isinstance(value, bytes):
        raise ValueError(
            "Invalid Arrow record batch: expected bytes or pyarrow.RecordBatch, "
            f"got {type(value).__name__}"
        )

    try:
        reader = pa.ipc.open_stream(pa.BufferReader(value))
        # Only the first batch of the stream is used.
        return reader.read_next_batch()
    except StopIteration as e:
        # An empty stream signals end-of-stream immediately; StopIteration
        # carries no message, so report the condition explicitly.
        raise ValueError(
            "Invalid Arrow record batch: IPC stream contains no record batches"
        ) from e
    except Exception as e:
        raise ValueError(f"Invalid Arrow record batch: {e}") from e
26+
27+
28+
def deserialize_schema(cls: Any, value: Any) -> pa.Schema:
    """Coerce ``value`` into a ``pyarrow.Schema``.

    Written to be wrapped by a Pydantic ``mode="before"`` field validator; the
    ``cls`` argument is accepted for that calling convention and is not used.

    Args:
        cls: The model class supplied by Pydantic; ignored.
        value: A ``pa.Schema`` (returned as-is), ``bytes`` holding a serialized
            schema, or anything else accepted by ``pa.schema``.

    Returns:
        The resulting schema.

    Raises:
        ValueError: If the conversion fails for any reason.
    """
    if not isinstance(value, pa.Schema):
        try:
            value = (
                # Serialized form: read the schema back out of the IPC bytes.
                pa.ipc.read_schema(pa.BufferReader(value))
                if isinstance(value, bytes)
                # Anything else is handed to the pa.schema factory.
                else pa.schema(value)
            )
        except Exception as exc:
            raise ValueError(f"Invalid Arrow schema: {exc}") from exc
    return value
39+
40+
41+
def deserialize_flight_descriptor(cls: Any, value: Any) -> flight.FlightDescriptor:
    """Coerce ``value`` into a ``pyarrow.flight.FlightDescriptor``.

    Intended for use as a Pydantic ``mode="before"`` field validator; ``cls``
    is the model class supplied by Pydantic and is not used.

    Args:
        cls: The model class supplied by Pydantic; ignored.
        value: Either an existing ``FlightDescriptor`` (returned unchanged) or
            ``bytes`` holding a serialized descriptor.

    Returns:
        The flight descriptor.

    Raises:
        ValueError: If ``value`` is of an unsupported type or the bytes cannot
            be deserialized.
    """
    if isinstance(value, flight.FlightDescriptor):
        return value

    # Previously any other input type fell through and the function returned
    # None implicitly; reject it explicitly so the error names the real cause.
    if not isinstance(value, bytes):
        raise ValueError(
            "Invalid Flight descriptor: expected bytes or FlightDescriptor, "
            f"got {type(value).__name__}"
        )

    try:
        return flight.FlightDescriptor.deserialize(value)
    except Exception as e:
        raise ValueError(f"Invalid Flight descriptor: {e}") from e
50+
51+
52+
class CreateTableActionParameters(BaseModel):
    """Parameters for the action decoded by ``create_table``."""

    # pa.Schema is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    catalog_name: str
    schema_name: str
    table_name: str

    arrow_schema: pa.Schema
    # Runs before normal field validation so that a serialized schema (bytes)
    # is turned into a pa.Schema by deserialize_schema.
    _validate_arrow_schema = field_validator("arrow_schema", mode="before")(deserialize_schema)

    # Must be exactly one of the three listed strings.
    on_conflict: Literal["error", "ignore", "replace"]

    # NOTE(review): the integer lists are presumably column indexes into
    # arrow_schema — confirm against the sender of this payload.
    not_null_constraints: list[int]
    unique_constraints: list[int]
    check_constraints: list[str]
66+
67+
68+
T = TypeVar("T", bound=BaseModel)


def _is_text_annotation(annotation: Any) -> bool:
    """Return True for ``str``, ``list[str]`` and ``Literal[...]`` annotations."""
    if annotation is str:
        return True
    origin = get_origin(annotation)
    if origin is Literal:
        return True
    # get_args returns a tuple, so compare against (str,) rather than str.
    return origin is list and get_args(annotation) == (str,)


def _decode_text(value: Any) -> Any:
    """Decode UTF-8 ``bytes`` (recursing into lists); pass anything else through."""
    if isinstance(value, bytes):
        try:
            return value.decode("utf8")
        except UnicodeDecodeError:
            # Leave undecodable bytes untouched so model validation reports
            # the problem against the offending field.
            return value
    if isinstance(value, list):
        return [_decode_text(item) for item in value]
    return value


def unpack_with_model(action: flight.Action, model_cls: type[T]) -> T:
    """Decode a msgpack action body and validate it against ``model_cls``.

    The body is unpacked with ``raw=True`` so that binary payloads stay as
    ``bytes``. Map keys are always decoded to ``str``. Values are decoded to
    ``str`` only for fields annotated ``str``, ``list[str]`` or ``Literal[...]``;
    every other value is passed to the model untouched.

    Args:
        action: The Flight action whose ``body`` holds the msgpack payload.
        model_cls: The Pydantic model class describing the payload.

    Returns:
        A validated instance of ``model_cls``.

    Raises:
        pydantic.ValidationError: If the decoded payload does not match the model.
    """
    decode_fields: set[str] = set()
    for name, field in model_cls.model_fields.items():
        if _is_text_annotation(field.annotation):
            decode_fields.add(name)
            # Payloads are keyed by alias when one is declared
            # (e.g. "schema" for schema_name), so match on that as well.
            if field.alias is not None:
                decode_fields.add(field.alias)

    def decode_map(raw_map: dict[Any, Any]) -> dict[str, Any]:
        # msgpack calls this hook for every map, nested ones included; keys are
        # matched by name only against the top-level model's fields.
        decoded: dict[str, Any] = {}
        for raw_key, raw_value in raw_map.items():
            key = raw_key.decode("utf8")
            decoded[key] = _decode_text(raw_value) if key in decode_fields else raw_value
        return decoded

    unpacked = msgpack.unpackb(
        action.body.to_pybytes(),
        raw=True,
        object_hook=decode_map,
    )
    return model_cls.model_validate(unpacked)
90+
91+
92+
class DropObjectParameters(BaseModel):
    """Parameters for the actions decoded by ``drop_schema`` and ``drop_table``."""

    # Which kind of object the action refers to; only these two values validate.
    type: Literal["table", "schema"]
    catalog_name: str
    schema_name: str
    name: str
    # NOTE(review): presumably suppresses the error when the object is missing —
    # the flag is only carried here, not acted on.
    ignore_not_found: bool
98+
99+
100+
class AlterBase(BaseModel):
    """Fields shared by the alter-style action parameter models in this module."""

    catalog: str
    # Required, and read from the payload key "schema". The first positional
    # argument of Field is the default value, so the earlier
    # Field("schema_name", alias="schema") silently defaulted a missing schema
    # to the literal string "schema_name".
    schema_name: str = Field(alias="schema")
    name: str
    # NOTE(review): presumably suppresses the error when the target is missing —
    # the flag is only carried here, not acted on.
    ignore_not_found: bool
105+
106+
107+
class AddColumnParameters(AlterBase):
    """Parameters for the action decoded by ``add_column``; extends ``AlterBase``."""

    # pa.Schema is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    column_schema: pa.Schema
    if_column_not_exists: bool

    # Runs before normal field validation so serialized bytes become a pa.Schema.
    _validate_column_schema = field_validator("column_schema", mode="before")(deserialize_schema)
113+
114+
115+
class AddConstraintParameters(AlterBase):
    """Parameters for the action decoded by ``add_constraint``; extends ``AlterBase``."""

    constraint: str
117+
118+
119+
class AddFieldParameters(AlterBase):
    """Parameters for the action decoded by ``add_field``; extends ``AlterBase``."""

    # pa.Schema is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    column_schema: pa.Schema
    if_field_not_exists: bool

    # Runs before normal field validation so serialized bytes become a pa.Schema.
    _validate_field_schema = field_validator("column_schema", mode="before")(deserialize_schema)
125+
126+
127+
class ChangeColumnTypeParameters(AlterBase):
    """Parameters for the action decoded by ``change_column_type``; extends ``AlterBase``."""

    # pa.Schema is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    column_schema: pa.Schema
    # NOTE(review): presumably the conversion expression applied to existing
    # values — only stored as a string here; confirm with the consumer.
    expression: str

    # Runs before normal field validation so serialized bytes become a pa.Schema.
    _validate_column_schema = field_validator("column_schema", mode="before")(deserialize_schema)
133+
134+
135+
class ColumnStatisticsParameters(AlterBase):
    """Parameters for the action decoded by ``column_statistics``; extends ``AlterBase``."""

    # FlightDescriptor is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    flight_descriptor: flight.FlightDescriptor
    column_name: str
    type: str

    # Runs before normal field validation so serialized bytes become a
    # FlightDescriptor.
    _validate_flight_descriptor = field_validator("flight_descriptor", mode="before")(
        deserialize_flight_descriptor
    )
144+
145+
146+
class CreateSchemaParameters(BaseModel):
    """Parameters for the action decoded by ``create_schema``."""

    catalog_name: str
    # Required, and read from the payload key "schema". The first positional
    # argument of Field is the default value, so the earlier
    # Field("schema_name", alias="schema") silently defaulted a missing schema
    # to the literal string "schema_name".
    schema_name: str = Field(alias="schema")

    # Optional; absent or null yields None.
    comment: str | None = None
    tags: dict[str, str]
152+
153+
154+
class CreateTransactionParameters(BaseModel):
    """Parameters for the action decoded by ``create_transaction``."""

    # No default is declared, so the key must be present; its value may be None.
    identifier: str | None
156+
157+
158+
class DropNotNullParameters(AlterBase):
    """Parameters for the action decoded by ``drop_not_null``; extends ``AlterBase``."""

    column_name: str
160+
161+
162+
class EndpointsParametersParameters(BaseModel):
    """Nested model used as the ``parameters`` field of ``EndpointsParameters``."""

    # NOTE(review): presumably a JSON-encoded filter document — it is kept as a
    # plain string here and not parsed.
    json_filters: str
    column_ids: list[int]
165+
166+
167+
class EndpointsParameters(BaseModel):
    """Parameters for the action decoded by ``endpoints``."""

    # FlightDescriptor is not a Pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    descriptor: flight.FlightDescriptor
    # Runs before normal field validation so serialized bytes become a
    # FlightDescriptor.
    _validate_descriptor = field_validator("descriptor", mode="before")(
        deserialize_flight_descriptor
    )
    # Nested map validated against EndpointsParametersParameters.
    parameters: EndpointsParametersParameters
174+
175+
176+
class ListSchemasParameters(BaseModel):
    """Parameters for the action decoded by ``list_schemas``."""

    catalog_name: str
178+
179+
180+
class RemoveColumnParameters(AlterBase):
    """Parameters for the action decoded by ``remove_column``; extends ``AlterBase``."""

    removed_column: str
    if_column_exists: bool
    cascade: bool
184+
185+
186+
class RemoveFieldParameters(AlterBase):
    """Parameters for the action decoded by ``remove_field``; extends ``AlterBase``."""

    # NOTE(review): presumably the path from the top-level column down to the
    # nested field, one name per level — confirm with the sender.
    column_path: list[str]
    if_column_exists: bool
    cascade: bool
190+
191+
192+
class RenameTableParameters(AlterBase):
    """Parameters for the action decoded by ``rename_table``; extends ``AlterBase``."""

    new_table_name: str
194+
195+
196+
class SetDefaultParameters(AlterBase):
    """Parameters for the action decoded by ``set_default``; extends ``AlterBase``."""

    column_name: str
    # NOTE(review): presumably the default-value expression for the column —
    # only stored as a string here.
    expression: str
199+
200+
201+
class SetNotNullParameters(AlterBase):
    """Parameters for the action decoded by ``set_not_null``; extends ``AlterBase``."""

    column_name: str
203+
204+
205+
class TableFunctionFlightInfoParameters(BaseModel):
    """Payload model carrying a record batch of parameters and an input schema.

    NOTE(review): no decoder helper for this model is visible in this part of
    the file; the name suggests it belongs to a table-function flight-info
    action — confirm.
    """

    # pa.RecordBatch / pa.Schema are not Pydantic-native types, so arbitrary
    # types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)  # for Pydantic v2
    catalog: str
    schema_name: str
    action_name: str
    parameters: pa.RecordBatch
    table_input_schema: pa.Schema

    # Runs before normal field validation so serialized bytes become a
    # pa.RecordBatch.
    _validate_parameters = field_validator("parameters", mode="before")(deserialize_record_batch)

    # Runs before normal field validation so serialized bytes become a pa.Schema.
    _validate_table_input_schema = field_validator("table_input_schema", mode="before")(
        deserialize_schema
    )
218+
219+
220+
def add_column(action: flight.Action) -> AddColumnParameters:
    """Decode the body of ``action`` into an ``AddColumnParameters`` instance."""
    return unpack_with_model(action, AddColumnParameters)
222+
223+
224+
def add_constraint(action: flight.Action) -> AddConstraintParameters:
    """Decode the body of ``action`` into an ``AddConstraintParameters`` instance."""
    return unpack_with_model(action, AddConstraintParameters)
226+
227+
228+
def add_field(action: flight.Action) -> AddFieldParameters:
    """Decode the body of ``action`` into an ``AddFieldParameters`` instance."""
    return unpack_with_model(action, AddFieldParameters)
230+
231+
232+
def change_column_type(action: flight.Action) -> ChangeColumnTypeParameters:
    """Decode the body of ``action`` into a ``ChangeColumnTypeParameters`` instance."""
    return unpack_with_model(action, ChangeColumnTypeParameters)
234+
235+
236+
def create_table(action: flight.Action) -> CreateTableActionParameters:
    """Decode the body of ``action`` into a ``CreateTableActionParameters`` instance."""
    return unpack_with_model(action, CreateTableActionParameters)
238+
239+
240+
def column_statistics(action: flight.Action) -> ColumnStatisticsParameters:
    """Decode the body of ``action`` into a ``ColumnStatisticsParameters`` instance."""
    return unpack_with_model(action, ColumnStatisticsParameters)
242+
243+
244+
def create_schema(action: flight.Action) -> CreateSchemaParameters:
    """Decode the body of ``action`` into a ``CreateSchemaParameters`` instance."""
    return unpack_with_model(action, CreateSchemaParameters)
246+
247+
248+
def create_transaction(action: flight.Action) -> CreateTransactionParameters:
    """Decode the body of ``action`` into a ``CreateTransactionParameters`` instance."""
    return unpack_with_model(action, CreateTransactionParameters)
250+
251+
252+
def drop_not_null(action: flight.Action) -> DropNotNullParameters:
    """Decode the body of ``action`` into a ``DropNotNullParameters`` instance."""
    return unpack_with_model(action, DropNotNullParameters)
254+
255+
256+
def drop_schema(action: flight.Action) -> DropObjectParameters:
    """Decode the body of ``action`` into a ``DropObjectParameters`` instance.

    Uses the same model as ``drop_table``; the payload's ``type`` field is not
    checked against this function's name here.
    """
    return unpack_with_model(action, DropObjectParameters)
258+
259+
260+
def drop_table(action: flight.Action) -> DropObjectParameters:
    """Decode the body of ``action`` into a ``DropObjectParameters`` instance.

    Uses the same model as ``drop_schema``; the payload's ``type`` field is not
    checked against this function's name here.
    """
    return unpack_with_model(action, DropObjectParameters)
262+
263+
264+
def endpoints(action: flight.Action) -> EndpointsParameters:
    """Decode the body of ``action`` into an ``EndpointsParameters`` instance."""
    return unpack_with_model(action, EndpointsParameters)
266+
267+
268+
def list_schemas(action: flight.Action) -> ListSchemasParameters:
    """Decode the body of ``action`` into a ``ListSchemasParameters`` instance."""
    return unpack_with_model(action, ListSchemasParameters)
270+
271+
272+
def remove_column(action: flight.Action) -> RemoveColumnParameters:
    """Decode the body of ``action`` into a ``RemoveColumnParameters`` instance."""
    return unpack_with_model(action, RemoveColumnParameters)
274+
275+
276+
def remove_field(action: flight.Action) -> RemoveFieldParameters:
    """Decode the body of ``action`` into a ``RemoveFieldParameters`` instance."""
    return unpack_with_model(action, RemoveFieldParameters)
278+
279+
280+
def rename_table(action: flight.Action) -> RenameTableParameters:
    """Decode the body of ``action`` into a ``RenameTableParameters`` instance."""
    return unpack_with_model(action, RenameTableParameters)
282+
283+
284+
def set_default(action: flight.Action) -> SetDefaultParameters:
    """Decode the body of ``action`` into a ``SetDefaultParameters`` instance."""
    return unpack_with_model(action, SetDefaultParameters)
286+
287+
288+
def set_not_null(action: flight.Action) -> SetNotNullParameters:
    """Decode the body of ``action`` into a ``SetNotNullParameters`` instance."""
    return unpack_with_model(action, SetNotNullParameters)

src/query_farm_server_base/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def auth_logging_items(
4444
) -> dict[str, Any]:
4545
"""Return the items that will be bound to the logger."""
4646
return {
47-
"token": None if credentials is None else credentials.token,
47+
"token": None if credentials is None else credentials.token.token,
4848
"account": None if credentials is None else credentials.account.account_id,
4949
"address": context.peer(),
5050
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import msgpack
2+
import pyarrow.flight as flight
3+
import pyarrow as pa
4+
from . import action_decoders
5+
6+
7+
def test_decode_drop_table() -> None:
    """A msgpack-encoded drop-table payload decodes field-for-field."""
    payload = {
        "type": "table",
        "catalog_name": "test_catalog",
        "schema_name": "test_schema",
        "name": "test_table",
        "ignore_not_found": True,
    }
    action = flight.Action("drop_table", msgpack.packb(payload))

    decoded = action_decoders.drop_table(action)

    # Text fields must come back as the exact strings that were packed.
    for field_name in ("type", "catalog_name", "schema_name", "name"):
        assert getattr(decoded, field_name) == payload[field_name]
    # Identity check: the flag must be the bool True, not merely truthy.
    assert decoded.ignore_not_found is True
23+
24+
25+
def test_decode_create_table() -> None:
    """A create-table payload with a serialized Arrow schema decodes correctly."""
    real_schema = pa.schema([("column1", pa.int32()), ("column2", pa.string())])
    payload = {
        "catalog_name": "test_catalog",
        "schema_name": "test_schema",
        "table_name": "test_table",
        # The schema travels as serialized IPC bytes inside the msgpack body.
        "arrow_schema": real_schema.serialize().to_pybytes(),
        "on_conflict": "error",
        "not_null_constraints": [],
        "unique_constraints": [],
        "check_constraints": ["test1"],
    }

    decoded = action_decoders.create_table(
        flight.Action("create_table", msgpack.packb(payload))
    )

    # Every field round-trips unchanged except arrow_schema, which must come
    # back as the original pa.Schema rather than as bytes.
    expected = {**payload, "arrow_schema": real_schema}
    for field_name, want in expected.items():
        assert getattr(decoded, field_name) == want

0 commit comments

Comments
 (0)