Skip to content

Commit 90e104e

Browse files
committed
fix: cleanup ticket handling
1 parent b1f5e9c commit 90e104e

3 files changed

Lines changed: 25 additions & 24 deletions

File tree

src/query_farm_server_base/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22
from datetime import UTC, datetime
33
from decimal import Decimal
4+
45
from pydantic import BaseModel, ConfigDict, Field
56

67

src/query_farm_server_base/flight_handling.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import io
21
import json
3-
import struct
42
from collections.abc import Callable, Generator
53
from dataclasses import dataclass
64
from typing import Any, TypeVar
@@ -16,8 +14,6 @@
1614
from duckdb_query_tools import duckdb_serialized_expression, sql_statement_analyzer
1715
from pydantic import BaseModel
1816

19-
ticket_with_metadata_indicator = b"<TICKET_WITH_METADATA>"
20-
2117

2218
class FlightTicketData(BaseModel):
2319
flight_name: str
@@ -36,6 +32,12 @@ def unpack(src: bytes) -> "FlightTicketData":
3632
return FlightTicketData.model_validate(unpacked)
3733

3834

35+
class AugmentedTicketData(BaseModel):
36+
ticket: bytes
37+
metadata_compressed_length: int
38+
metadata: bytes
39+
40+
3941
T = TypeVar("T", bound=FlightTicketData)
4042

4143

@@ -73,33 +75,32 @@ def endpoint(*, ticket_data: T, allow_metadata: bool, supports_predicate_pushdow
7375
)
7476

7577

76-
def decode_ticket(*, ticket: flight.Ticket, model_selector: Callable[[str, bytes], T]) -> tuple[T, dict[str, str]]:
78+
def decode_ticket(
79+
*, ticket: flight.Ticket, model_selector: Callable[[str, bytes], T], is_augmented_ticket: bool
80+
) -> tuple[T, dict[str, str]]:
7781
"""
7882
Decode a ticket that has embedded and compressed metadata.
7983
8084
There is no concept of multiple headers handled here, headers are strings.
8185
"""
82-
if (
83-
ticket.ticket[0 : min(len(ticket_with_metadata_indicator), len(ticket.ticket))]
84-
== ticket_with_metadata_indicator
85-
):
86-
# We have a ticket with metadata.
87-
stream = io.BytesIO(ticket.ticket)
88-
stream.seek(len(ticket_with_metadata_indicator))
89-
# Unpack the byte string as a uint32 ('I' is the format code for uint32)
90-
ticket_data_length = struct.unpack("<I", stream.read(4))[0]
91-
92-
# The ticket itself is a msgpack message.
93-
msgpack_ticket_contents = stream.read(ticket_data_length)
94-
95-
basic_data = FlightTicketData.unpack(msgpack_ticket_contents)
96-
decoded_ticket_data = model_selector(basic_data.flight_name, msgpack_ticket_contents)
97-
98-
metadata_decompressed_length = struct.unpack("<I", stream.read(4))[0]
86+
87+
if is_augmented_ticket:
88+
augmented_ticket = AugmentedTicketData.model_validate(
89+
msgpack.unpack(
90+
ticket.ticket,
91+
raw=True,
92+
object_hook=lambda s: {k.decode("utf8"): v for k, v in s.items()},
93+
)
94+
)
95+
96+
basic_data = FlightTicketData.unpack(augmented_ticket.ticket)
97+
decoded_ticket_data = model_selector(basic_data.flight_name, augmented_ticket.ticket)
98+
99+
metadata_decompressed_length = augmented_ticket.metadata_compressed_length
99100
if metadata_decompressed_length > 1024 * 1024 * 2:
100101
raise flight.FlightUnavailableError("Decompressed Flight metadata is too large limit is 2mb.")
101102

102-
metadata = stream.read()
103+
metadata = augmented_ticket.metadata
103104
parsed_headers: dict[str, str] = {}
104105
try:
105106
# That metadata is zstd compressed, so we need to decompress it.

src/query_farm_server_base/flight_inventory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import struct
21
from dataclasses import dataclass
32
from typing import Any, Literal
43

0 commit comments

Comments
 (0)