55from typing import Any , Generic , NoReturn , ParamSpec , TypeVar
66
77import msgpack
8- import pyarrow as pa
98import pyarrow .flight as flight
109import structlog
1110import zstandard as zstd
12- from pydantic import BaseModel , ConfigDict , field_validator
11+ from pydantic import BaseModel
1312
1413from query_farm_server_base import action_decoders
1514
@@ -72,6 +71,7 @@ class ExchangeOperation(str, Enum):
7271 INSERT = "insert"
7372 UPDATE = "update"
7473 DELETE = "delete"
74+ SCALAR_FUNCTION = "scalar_function"
7575
7676
7777class ActionType (str , Enum ):
@@ -118,6 +118,7 @@ class ActionHandlerSpec:
118118
119119def compress_list_schemas_result (result : AirportSerializedCatalogRoot ) -> list [Any ]:
120120 packed_data = msgpack .packb (result .model_dump ())
121+ assert packed_data
121122 compressor = zstd .ZstdCompressor (level = SCHEMA_TOP_LEVEL_COMPRESSION_LEVEL )
122123 compressed_data = compressor .compress (packed_data )
123124 return [len (packed_data ), compressed_data ]
@@ -310,14 +311,14 @@ def impl_list_flights(
310311 def list_flights (
311312 self , context : flight .ServerCallContext , criteria : bytes
312313 ) -> Iterator [flight .FlightInfo ]:
313- try :
314- caller = self .credentials_from_context_ (context )
314+ caller = self .credentials_from_context_ (context )
315315
316- logger = log .bind (
317- ** self .auth_logging_items (context , caller ),
318- criteria = criteria ,
319- )
316+ logger = log .bind (
317+ ** self .auth_logging_items (context , caller ),
318+ criteria = criteria ,
319+ )
320320
321+ try :
321322 logger .info ("list_flights" , criteria = criteria )
322323
323324 call_context = CallContext (
@@ -347,14 +348,13 @@ def get_flight_info(
347348 context : flight .ServerCallContext ,
348349 descriptor : flight .FlightDescriptor ,
349350 ) -> flight .FlightInfo :
350- try :
351- caller = self .credentials_from_context_ (context )
352-
353- logger = log .bind (
354- ** self .auth_logging_items (context , caller ),
355- descriptor = descriptor ,
356- )
351+ caller = self .credentials_from_context_ (context )
357352
353+ logger = log .bind (
354+ ** self .auth_logging_items (context , caller ),
355+ descriptor = descriptor ,
356+ )
357+ try :
358358 logger .info (
359359 "get_flight_info" ,
360360 descriptor = descriptor ,
@@ -567,18 +567,20 @@ def action_drop_schema(
567567 self ._unimplemented_action (ActionType .DROP_SCHEMA )
568568
569569 def pack_result (self , value : Any ) -> Iterator [bytes ]:
570- return iter ([msgpack .packb (value )])
570+ result = msgpack .packb (value )
571+ assert result
572+ return iter ([result ])
571573
572574 def do_action (
573575 self , context : flight .ServerCallContext , action : flight .Action
574576 ) -> Iterator [bytes ]:
575- try :
576- caller = self .credentials_from_context_ (context )
577+ caller = self .credentials_from_context_ (context )
577578
578- logger = log .bind (
579- ** self .auth_logging_items (context , caller ),
580- )
579+ logger = log .bind (
580+ ** self .auth_logging_items (context , caller ),
581+ )
581582
583+ try :
582584 call_context = CallContext (
583585 context = context ,
584586 caller = caller ,
@@ -641,6 +643,17 @@ def exchange_delete(
641643 ) -> int :
642644 self ._unimplemented_exchange_operation (ExchangeOperation .DELETE )
643645
646+ def exchange_scalar_function (
647+ self ,
648+ * ,
649+ context : CallContext [AccountType , TokenType ],
650+ descriptor : flight .FlightDescriptor ,
651+ reader : flight .MetadataRecordBatchReader ,
652+ writer : flight .MetadataRecordBatchWriter ,
653+ return_chunks : bool ,
654+ ) -> int :
655+ self ._unimplemented_exchange_operation (ExchangeOperation .SCALAR_FUNCTION )
656+
644657 def exchange_update (
645658 self ,
646659 * ,
@@ -659,14 +672,13 @@ def do_exchange(
659672 reader : flight .MetadataRecordBatchReader ,
660673 writer : flight .MetadataRecordBatchWriter ,
661674 ) -> None :
662- try :
663- caller = self .credentials_from_context_ (context )
664-
665- logger = log .bind (
666- ** self .auth_logging_items (context , caller ),
667- descriptor = descriptor ,
668- )
675+ caller = self .credentials_from_context_ (context )
669676
677+ logger = log .bind (
678+ ** self .auth_logging_items (context , caller ),
679+ descriptor = descriptor ,
680+ )
681+ try :
670682 call_context = CallContext (
671683 context = context ,
672684 caller = caller ,
@@ -687,7 +699,7 @@ def do_exchange(
687699 )
688700 return_chunks : bool = return_chunks_headers [0 ] == "1"
689701
690- last_metadata : Any
702+ last_metadata : Any = None
691703 if airport_operation == ExchangeOperation .INSERT :
692704 keys_inserted = self .exchange_insert (
693705 context = call_context ,
@@ -715,11 +727,19 @@ def do_exchange(
715727 return_chunks = return_chunks ,
716728 )
717729 last_metadata = {"total_deleted" : keys_deleted }
730+ elif airport_operation == ExchangeOperation .SCALAR_FUNCTION :
731+ self .scalar_function (
732+ context = call_context ,
733+ descriptor = descriptor ,
734+ reader = reader ,
735+ writer = writer ,
736+ )
718737 else :
719738 raise flight .FlightServerError (
720739 f"Unknown airport-operation header: { airport_operation } "
721740 )
722- writer .write_metadata (msgpack .packb (last_metadata ))
741+ if airport_operation != ExchangeOperation .SCALAR_FUNCTION :
742+ writer .write_metadata (msgpack .packb (last_metadata ))
723743 writer .close ()
724744 return
725745
@@ -744,13 +764,12 @@ def impl_do_get(
744764 def do_get (
745765 self , context : flight .ServerCallContext , ticket : flight .Ticket
746766 ) -> flight .RecordBatchStream :
747- try :
748- caller = self .credentials_from_context_ (context )
749-
750- logger = log .bind (
751- ** self .auth_logging_items (context , caller ),
752- )
767+ caller = self .credentials_from_context_ (context )
768+ logger = log .bind (
769+ ** self .auth_logging_items (context , caller ),
770+ )
753771
772+ try :
754773 logger .info ("do_get" , ticket = ticket )
755774
756775 call_context = CallContext (
@@ -784,13 +803,12 @@ def do_put(
784803 reader : flight .MetadataRecordBatchReader ,
785804 writer : flight .FlightMetadataWriter ,
786805 ) -> None :
787- try :
788- caller = self .credentials_from_context_ (context )
789-
790- logger = log .bind (
791- ** self .auth_logging_items (context , caller ),
792- )
806+ caller = self .credentials_from_context_ (context )
807+ logger = log .bind (
808+ ** self .auth_logging_items (context , caller ),
809+ )
793810
811+ try :
794812 logger .info ("do_put" , descriptor = descriptor )
795813
796814 call_context = CallContext (
0 commit comments