1010import pyarrow .flight as flight
1111import sqlglot
1212import structlog
13- import zstandard as zstd
1413from duckdb_query_tools import (
1514 duckdb_serialized_expression ,
1615 sql_statement_analyzer ,
2019
2120class FlightTicketData (BaseModel ):
2221 flight_name : str
22+ json_filters : str | None = None
23+ column_ids : list [int ] | None = None
2324
2425 @staticmethod
2526 def unpack (src : bytes ) -> "FlightTicketData" :
26- decode_fields = {"flight_name" }
27+ decode_fields = {"flight_name" , "json_filters" , "column_ids" }
2728 unpacked = msgpack .unpackb (
2829 src ,
2930 raw = True ,
@@ -35,18 +36,6 @@ def unpack(src: bytes) -> "FlightTicketData":
3536
3637 return FlightTicketData .model_validate (unpacked )
3738
38-
39- class AugmentedTicketData (BaseModel ):
40- ticket : bytes
41- metadata_uncompressed_length : int
42- metadata : bytes
43-
44-
45- class FilterMetadata (BaseModel ):
46- json_filters : str
47- column_ids : list [int ]
48-
49-
5039T = TypeVar ("T" , bound = FlightTicketData )
5140
5241
@@ -72,7 +61,6 @@ def generate_record_batches_for_used_fields(
7261def endpoint (
7362 * ,
7463 ticket_data : T ,
75- supports_predicate_pushdown : bool ,
7664 locations : list [str ] | None = ["arrow-flight-reuse-connection://?" ],
7765) -> flight .FlightEndpoint :
7866 """Create a FlightEndpoint that allows metadata filtering to be passed
@@ -82,74 +70,33 @@ def endpoint(
8270 return flight .FlightEndpoint (
8371 packed_data ,
8472 locations ,
85- None ,
86- msgpack .packb ({"supports_predicate_pushdown" : supports_predicate_pushdown }),
8773 )
8874
8975
def decode_ticket(
    *,
    ticket: flight.Ticket,
    model_selector: Callable[[str, bytes], T],
) -> tuple[T, dict[str, str]]:
    """
    Decode a Flight ticket and extract any filter headers it carries.

    The raw ticket bytes are unpacked once as a ``FlightTicketData`` to read
    the flight name plus the optional ``json_filters`` and ``column_ids``
    fields. The same raw bytes are then handed to ``model_selector`` together
    with the flight name so the caller can build its own ticket model.

    There is no concept of multiple headers handled here, headers are strings.

    Args:
        ticket: The Flight ticket whose ``ticket`` bytes are decoded.
        model_selector: Callable receiving ``(flight_name, raw_ticket_bytes)``
            and returning the decoded ticket model.

    Returns:
        A tuple of the model returned by ``model_selector`` and a dict of
        headers. The dict holds ``airport-duckdb-json-filters`` when the
        ticket has a non-empty ``json_filters`` value, and
        ``airport-duckdb-column-ids`` (comma-separated ids) when the ticket
        has a non-empty ``column_ids`` list; otherwise it is empty.
    """
    basic_data = FlightTicketData.unpack(ticket.ticket)

    parsed_headers: dict[str, str] = {}

    # Both fields are optional; None and empty values produce no header.
    if basic_data.json_filters:
        parsed_headers["airport-duckdb-json-filters"] = basic_data.json_filters

    if basic_data.column_ids:
        parsed_headers["airport-duckdb-column-ids"] = ",".join(
            map(str, basic_data.column_ids)
        )

    # The selector gets the untouched ticket bytes so it can unpack any
    # subclass-specific fields itself.
    decoded_ticket_data = model_selector(basic_data.flight_name, ticket.ticket)
    return decoded_ticket_data, parsed_headers
153100
154101
155102@dataclass
0 commit comments