1- import io
21import json
3- import struct
42from collections .abc import Callable , Generator
53from dataclasses import dataclass
64from typing import Any , TypeVar
1614from duckdb_query_tools import duckdb_serialized_expression , sql_statement_analyzer
1715from pydantic import BaseModel
1816
19- ticket_with_metadata_indicator = b"<TICKET_WITH_METADATA>"
20-
2117
2218class FlightTicketData (BaseModel ):
2319 flight_name : str
@@ -36,6 +32,12 @@ def unpack(src: bytes) -> "FlightTicketData":
3632 return FlightTicketData .model_validate (unpacked )
3733
3834
# Envelope model for a ticket that carries compressed metadata next to the
# packed ticket itself. decode_ticket validates the msgpack-decoded ticket
# bytes against this model when is_augmented_ticket is true.
class AugmentedTicketData(BaseModel):
    # Packed ticket body; decode_ticket passes it to FlightTicketData.unpack
    # and to the caller-supplied model selector.
    ticket: bytes

    # Length that decode_ticket checks against its 2 MiB limit before touching
    # the metadata.
    # NOTE(review): decode_ticket reads this value as the *decompressed*
    # length even though the field is named "compressed" - confirm with the
    # code that produces these tickets before renaming (the name is also the
    # serialized key).
    metadata_compressed_length: int

    # Metadata blob; the decoder's comments describe it as zstd compressed.
    metadata: bytes
39+
40+
# Type variable for concrete ticket models; the bound restricts it to
# FlightTicketData and its subclasses.
T = TypeVar ("T" , bound = FlightTicketData )
4042
4143
@@ -73,33 +75,32 @@ def endpoint(*, ticket_data: T, allow_metadata: bool, supports_predicate_pushdow
7375 )
7476
7577
76- def decode_ticket (* , ticket : flight .Ticket , model_selector : Callable [[str , bytes ], T ]) -> tuple [T , dict [str , str ]]:
78+ def decode_ticket (
79+ * , ticket : flight .Ticket , model_selector : Callable [[str , bytes ], T ], is_augmented_ticket : bool
80+ ) -> tuple [T , dict [str , str ]]:
7781 """
7882 Decode a ticket that has embedded and compressed metadata.
7983
8084 There is no concept of multiple headers handled here, headers are strings.
8185 """
82- if (
83- ticket .ticket [0 : min (len (ticket_with_metadata_indicator ), len (ticket .ticket ))]
84- == ticket_with_metadata_indicator
85- ):
86- # We have a ticket with metadata.
87- stream = io .BytesIO (ticket .ticket )
88- stream .seek (len (ticket_with_metadata_indicator ))
89- # Unpack the byte string as a uint32 ('I' is the format code for uint32)
90- ticket_data_length = struct .unpack ("<I" , stream .read (4 ))[0 ]
91-
92- # The ticket itself is a msgpack message.
93- msgpack_ticket_contents = stream .read (ticket_data_length )
94-
95- basic_data = FlightTicketData .unpack (msgpack_ticket_contents )
96- decoded_ticket_data = model_selector (basic_data .flight_name , msgpack_ticket_contents )
97-
98- metadata_decompressed_length = struct .unpack ("<I" , stream .read (4 ))[0 ]
86+
87+ if is_augmented_ticket :
88+ augmented_ticket = AugmentedTicketData .model_validate (
89+ msgpack .unpack (
90+ ticket .ticket ,
91+ raw = True ,
92+ object_hook = lambda s : {k .decode ("utf8" ): v for k , v in s .items ()},
93+ )
94+ )
95+
96+ basic_data = FlightTicketData .unpack (augmented_ticket .ticket )
97+ decoded_ticket_data = model_selector (basic_data .flight_name , augmented_ticket .ticket )
98+
99+ metadata_decompressed_length = augmented_ticket .metadata_compressed_length
99100 if metadata_decompressed_length > 1024 * 1024 * 2 :
100101 raise flight .FlightUnavailableError ("Decompressed Flight metadata is too large limit is 2mb." )
101102
102- metadata = stream . read ()
103+ metadata = augmented_ticket . metadata
103104 parsed_headers : dict [str , str ] = {}
104105 try :
105106 # That metadata is zstd compressed, so we need to decompress it.
0 commit comments