Skip to content

Commit 196ae5b

Browse files
committed
fix: simplify ticket handling, put predicates in the ticket by default
1 parent a8bd3ac commit 196ae5b

1 file changed

Lines changed: 13 additions & 66 deletions

File tree

src/query_farm_server_base/flight_handling.py

Lines changed: 13 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import pyarrow.flight as flight
1111
import sqlglot
1212
import structlog
13-
import zstandard as zstd
1413
from duckdb_query_tools import (
1514
duckdb_serialized_expression,
1615
sql_statement_analyzer,
@@ -20,10 +19,12 @@
2019

2120
class FlightTicketData(BaseModel):
2221
flight_name: str
22+
json_filters: str | None = None
23+
column_ids: list[int] | None = None
2324

2425
@staticmethod
2526
def unpack(src: bytes) -> "FlightTicketData":
26-
decode_fields = {"flight_name"}
27+
decode_fields = {"flight_name", "json_filters", "column_ids"}
2728
unpacked = msgpack.unpackb(
2829
src,
2930
raw=True,
@@ -35,18 +36,6 @@ def unpack(src: bytes) -> "FlightTicketData":
3536

3637
return FlightTicketData.model_validate(unpacked)
3738

38-
39-
class AugmentedTicketData(BaseModel):
40-
ticket: bytes
41-
metadata_uncompressed_length: int
42-
metadata: bytes
43-
44-
45-
class FilterMetadata(BaseModel):
46-
json_filters: str
47-
column_ids: list[int]
48-
49-
5039
T = TypeVar("T", bound=FlightTicketData)
5140

5241

@@ -72,7 +61,6 @@ def generate_record_batches_for_used_fields(
7261
def endpoint(
7362
*,
7463
ticket_data: T,
75-
supports_predicate_pushdown: bool,
7664
locations: list[str] | None = ["arrow-flight-reuse-connection://?"],
7765
) -> flight.FlightEndpoint:
7866
"""Create a FlightEndpoint that allows metadata filtering to be passed
@@ -82,74 +70,33 @@ def endpoint(
8270
return flight.FlightEndpoint(
8371
packed_data,
8472
locations,
85-
None,
86-
msgpack.packb({"supports_predicate_pushdown": supports_predicate_pushdown}),
8773
)
8874

8975

9076
def decode_ticket(
9177
*,
9278
ticket: flight.Ticket,
9379
model_selector: Callable[[str, bytes], T],
94-
is_augmented_ticket: bool,
9580
) -> tuple[T, dict[str, str]]:
9681
"""
9782
Decode a ticket that has embedded and compressed metadata.
9883
9984
There is no concept of multiple headers handled here, headers are strings.
10085
"""
86+
parsed_headers: dict[str, str] = {}
10187

102-
if is_augmented_ticket:
103-
augmented_ticket = AugmentedTicketData.model_validate(
104-
msgpack.unpackb(
105-
ticket.ticket,
106-
raw=True,
107-
object_hook=lambda s: {k.decode("utf8"): v for k, v in s.items()},
108-
)
109-
)
88+
basic_data = FlightTicketData.unpack(ticket.ticket)
11089

111-
basic_data = FlightTicketData.unpack(augmented_ticket.ticket)
112-
decoded_ticket_data = model_selector(basic_data.flight_name, augmented_ticket.ticket)
90+
if basic_data.json_filters and basic_data.json_filters != "":
91+
parsed_headers = {"airport-duckdb-json-filters": basic_data.json_filters}
11392

114-
if augmented_ticket.metadata_uncompressed_length > 1024 * 1024 * 2:
115-
raise flight.FlightUnavailableError(
116-
"Decompressed Flight metadata is too large limit is 2mb."
117-
)
118-
119-
metadata = augmented_ticket.metadata
120-
parsed_headers: dict[str, str] = {}
121-
try:
122-
# That metadata is zstd compressed, so we need to decompress it.
123-
decompressor = zstd.ZstdDecompressor()
124-
decompressed_metadata = decompressor.decompress(metadata)
125-
126-
decode_fields = {"json_filters"}
127-
unpacked_data = msgpack.unpackb(
128-
decompressed_metadata,
129-
raw=True,
130-
object_hook=lambda s: {
131-
k.decode("utf8"): v.decode("utf8") if k in decode_fields else v
132-
for k, v in s.items()
133-
},
134-
)
135-
136-
filter_metadata = FilterMetadata.model_validate(unpacked_data)
137-
138-
if filter_metadata.json_filters and filter_metadata.json_filters != "":
139-
parsed_headers["airport-duckdb-json-filters"] = filter_metadata.json_filters
140-
141-
if len(filter_metadata.column_ids) > 0:
142-
parsed_headers["airport-duckdb-column-ids"] = ",".join(
143-
map(str, filter_metadata.column_ids)
144-
)
93+
if basic_data.column_ids and len(basic_data.column_ids) > 0:
94+
parsed_headers["airport-duckdb-column-ids"] = ",".join(
95+
map(str, basic_data.column_ids)
96+
)
14597

146-
except Exception as e:
147-
raise flight.FlightUnavailableError("Unable to decompress metadata.") from e
148-
return decoded_ticket_data, parsed_headers
149-
else:
150-
basic_data = FlightTicketData.unpack(ticket.ticket)
151-
decoded_ticket_data = model_selector(basic_data.flight_name, ticket.ticket)
152-
return decoded_ticket_data, {}
98+
decoded_ticket_data = model_selector(basic_data.flight_name, ticket.ticket)
99+
return decoded_ticket_data, parsed_headers
153100

154101

155102
@dataclass

0 commit comments

Comments
 (0)