Skip to content

Commit 7ebe4cf

Browse files
committed
fix: add support for scalar function
1 parent 3311248 commit 7ebe4cf

1 file changed

Lines changed: 60 additions & 42 deletions

File tree

src/query_farm_server_base/server.py

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from typing import Any, Generic, NoReturn, ParamSpec, TypeVar
66

77
import msgpack
8-
import pyarrow as pa
98
import pyarrow.flight as flight
109
import structlog
1110
import zstandard as zstd
12-
from pydantic import BaseModel, ConfigDict, field_validator
11+
from pydantic import BaseModel
1312

1413
from query_farm_server_base import action_decoders
1514

@@ -72,6 +71,7 @@ class ExchangeOperation(str, Enum):
7271
INSERT = "insert"
7372
UPDATE = "update"
7473
DELETE = "delete"
74+
SCALAR_FUNCTION = "scalar_function"
7575

7676

7777
class ActionType(str, Enum):
@@ -118,6 +118,7 @@ class ActionHandlerSpec:
118118

119119
def compress_list_schemas_result(result: AirportSerializedCatalogRoot) -> list[Any]:
120120
packed_data = msgpack.packb(result.model_dump())
121+
assert packed_data
121122
compressor = zstd.ZstdCompressor(level=SCHEMA_TOP_LEVEL_COMPRESSION_LEVEL)
122123
compressed_data = compressor.compress(packed_data)
123124
return [len(packed_data), compressed_data]
@@ -310,14 +311,14 @@ def impl_list_flights(
310311
def list_flights(
311312
self, context: flight.ServerCallContext, criteria: bytes
312313
) -> Iterator[flight.FlightInfo]:
313-
try:
314-
caller = self.credentials_from_context_(context)
314+
caller = self.credentials_from_context_(context)
315315

316-
logger = log.bind(
317-
**self.auth_logging_items(context, caller),
318-
criteria=criteria,
319-
)
316+
logger = log.bind(
317+
**self.auth_logging_items(context, caller),
318+
criteria=criteria,
319+
)
320320

321+
try:
321322
logger.info("list_flights", criteria=criteria)
322323

323324
call_context = CallContext(
@@ -347,14 +348,13 @@ def get_flight_info(
347348
context: flight.ServerCallContext,
348349
descriptor: flight.FlightDescriptor,
349350
) -> flight.FlightInfo:
350-
try:
351-
caller = self.credentials_from_context_(context)
352-
353-
logger = log.bind(
354-
**self.auth_logging_items(context, caller),
355-
descriptor=descriptor,
356-
)
351+
caller = self.credentials_from_context_(context)
357352

353+
logger = log.bind(
354+
**self.auth_logging_items(context, caller),
355+
descriptor=descriptor,
356+
)
357+
try:
358358
logger.info(
359359
"get_flight_info",
360360
descriptor=descriptor,
@@ -567,18 +567,20 @@ def action_drop_schema(
567567
self._unimplemented_action(ActionType.DROP_SCHEMA)
568568

569569
def pack_result(self, value: Any) -> Iterator[bytes]:
570-
return iter([msgpack.packb(value)])
570+
result = msgpack.packb(value)
571+
assert result
572+
return iter([result])
571573

572574
def do_action(
573575
self, context: flight.ServerCallContext, action: flight.Action
574576
) -> Iterator[bytes]:
575-
try:
576-
caller = self.credentials_from_context_(context)
577+
caller = self.credentials_from_context_(context)
577578

578-
logger = log.bind(
579-
**self.auth_logging_items(context, caller),
580-
)
579+
logger = log.bind(
580+
**self.auth_logging_items(context, caller),
581+
)
581582

583+
try:
582584
call_context = CallContext(
583585
context=context,
584586
caller=caller,
@@ -641,6 +643,17 @@ def exchange_delete(
641643
) -> int:
642644
self._unimplemented_exchange_operation(ExchangeOperation.DELETE)
643645

646+
def exchange_scalar_function(
647+
self,
648+
*,
649+
context: CallContext[AccountType, TokenType],
650+
descriptor: flight.FlightDescriptor,
651+
reader: flight.MetadataRecordBatchReader,
652+
writer: flight.MetadataRecordBatchWriter,
653+
return_chunks: bool,
654+
) -> int:
655+
self._unimplemented_exchange_operation(ExchangeOperation.SCALAR_FUNCTION)
656+
644657
def exchange_update(
645658
self,
646659
*,
@@ -659,14 +672,13 @@ def do_exchange(
659672
reader: flight.MetadataRecordBatchReader,
660673
writer: flight.MetadataRecordBatchWriter,
661674
) -> None:
662-
try:
663-
caller = self.credentials_from_context_(context)
664-
665-
logger = log.bind(
666-
**self.auth_logging_items(context, caller),
667-
descriptor=descriptor,
668-
)
675+
caller = self.credentials_from_context_(context)
669676

677+
logger = log.bind(
678+
**self.auth_logging_items(context, caller),
679+
descriptor=descriptor,
680+
)
681+
try:
670682
call_context = CallContext(
671683
context=context,
672684
caller=caller,
@@ -687,7 +699,7 @@ def do_exchange(
687699
)
688700
return_chunks: bool = return_chunks_headers[0] == "1"
689701

690-
last_metadata: Any
702+
last_metadata: Any = None
691703
if airport_operation == ExchangeOperation.INSERT:
692704
keys_inserted = self.exchange_insert(
693705
context=call_context,
@@ -715,11 +727,19 @@ def do_exchange(
715727
return_chunks=return_chunks,
716728
)
717729
last_metadata = {"total_deleted": keys_deleted}
730+
elif airport_operation == ExchangeOperation.SCALAR_FUNCTION:
731+
self.scalar_function(
732+
context=call_context,
733+
descriptor=descriptor,
734+
reader=reader,
735+
writer=writer,
736+
)
718737
else:
719738
raise flight.FlightServerError(
720739
f"Unknown airport-operation header: {airport_operation}"
721740
)
722-
writer.write_metadata(msgpack.packb(last_metadata))
741+
if airport_operation != ExchangeOperation.SCALAR_FUNCTION:
742+
writer.write_metadata(msgpack.packb(last_metadata))
723743
writer.close()
724744
return
725745

@@ -744,13 +764,12 @@ def impl_do_get(
744764
def do_get(
745765
self, context: flight.ServerCallContext, ticket: flight.Ticket
746766
) -> flight.RecordBatchStream:
747-
try:
748-
caller = self.credentials_from_context_(context)
749-
750-
logger = log.bind(
751-
**self.auth_logging_items(context, caller),
752-
)
767+
caller = self.credentials_from_context_(context)
768+
logger = log.bind(
769+
**self.auth_logging_items(context, caller),
770+
)
753771

772+
try:
754773
logger.info("do_get", ticket=ticket)
755774

756775
call_context = CallContext(
@@ -784,13 +803,12 @@ def do_put(
784803
reader: flight.MetadataRecordBatchReader,
785804
writer: flight.FlightMetadataWriter,
786805
) -> None:
787-
try:
788-
caller = self.credentials_from_context_(context)
789-
790-
logger = log.bind(
791-
**self.auth_logging_items(context, caller),
792-
)
806+
caller = self.credentials_from_context_(context)
807+
logger = log.bind(
808+
**self.auth_logging_items(context, caller),
809+
)
793810

811+
try:
794812
logger.info("do_put", descriptor=descriptor)
795813

796814
call_context = CallContext(

0 commit comments

Comments
 (0)