|
1 | 1 | from abc import ABC, abstractmethod |
2 | | -from collections.abc import Callable, Iterator |
| 2 | +from collections.abc import Callable, Generator, Iterator |
3 | 3 | from dataclasses import dataclass |
4 | 4 | from enum import Enum |
5 | | -from typing import Any, Generator, Generic, NoReturn, ParamSpec, TypeVar |
| 5 | +from typing import Any, Generic, NoReturn, ParamSpec, TypeVar |
6 | 6 |
|
7 | 7 | import msgpack |
8 | 8 | import pyarrow as pa |
@@ -708,102 +708,106 @@ def do_exchange( |
708 | 708 | ) |
709 | 709 |
|
710 | 710 | header_middleware = context.get_middleware("headers") |
| 711 | + assert header_middleware |
711 | 712 | airport_operation_headers = header_middleware.client_headers.get("airport-operation") |
712 | | - if airport_operation_headers is not None and len(airport_operation_headers) > 0: |
713 | | - airport_operation = airport_operation_headers[0] |
714 | | - |
715 | | - logger.debug("do_exchange", airport_operation=airport_operation) |
716 | | - |
717 | | - return_chunks_headers = header_middleware.client_headers.get("return-chunks") |
718 | | - if return_chunks_headers is None or len(return_chunks_headers) == 0: |
719 | | - raise flight.FlightServerError( |
720 | | - "The return-chunks header is required for this operation." |
721 | | - ) |
722 | | - return_chunks: bool = return_chunks_headers[0] == "1" |
723 | | - |
724 | | - last_metadata: Any = None |
725 | | - if airport_operation == ExchangeOperation.INSERT: |
726 | | - keys_inserted = self.exchange_insert( |
727 | | - context=call_context, |
728 | | - descriptor=descriptor, |
729 | | - reader=reader, |
730 | | - writer=writer, |
731 | | - return_chunks=return_chunks, |
732 | | - ) |
733 | | - last_metadata = {"total_inserted": keys_inserted} |
734 | | - elif airport_operation == ExchangeOperation.UPDATE: |
735 | | - keys_updated = self.exchange_update( |
736 | | - context=call_context, |
737 | | - descriptor=descriptor, |
738 | | - reader=reader, |
739 | | - writer=writer, |
740 | | - return_chunks=return_chunks, |
741 | | - ) |
742 | | - last_metadata = {"total_updated": keys_updated} |
743 | | - elif airport_operation == ExchangeOperation.DELETE: |
744 | | - keys_deleted = self.exchange_delete( |
745 | | - context=call_context, |
746 | | - descriptor=descriptor, |
747 | | - reader=reader, |
748 | | - writer=writer, |
749 | | - return_chunks=return_chunks, |
750 | | - ) |
751 | | - last_metadata = {"total_deleted": keys_deleted} |
752 | | - elif airport_operation == ExchangeOperation.SCALAR_FUNCTION: |
753 | | - self.exchange_scalar_function( |
754 | | - context=call_context, |
755 | | - descriptor=descriptor, |
756 | | - reader=reader, |
757 | | - writer=writer, |
758 | | - ) |
759 | | - elif airport_operation == ExchangeOperation.TABLE_FUNCTION_IN_OUT: |
760 | | - # The parameters are sent as the first chunk of the read stream |
761 | | - # as part of the metadata. |
762 | | - chunk = next(reader) |
763 | | - assert chunk.data is None |
764 | | - assert chunk.app_metadata is not None |
765 | | - |
766 | | - parameters = parameter_types.table_function_parameters(chunk.app_metadata) |
767 | | - |
768 | | - output_schema, generator = self.exchange_table_function_in_out( |
769 | | - context=call_context, |
770 | | - descriptor=descriptor, |
771 | | - parameters=parameters, |
772 | | - input_schema=reader.schema, |
773 | | - ) |
774 | | - |
775 | | - writer.begin(output_schema) |
776 | | - # Prime the generator |
777 | | - generator.send(None) |
778 | 713 |
|
779 | | - for item in reader: |
780 | | - assert item.data is not None |
781 | | - result = generator.send(item.data) |
782 | | - writer.write_batch(result) |
783 | | - |
784 | | - try: |
785 | | - generator.send(None) |
786 | | - except StopIteration as e: |
787 | | - if e.value is not None: |
788 | | - writer.write_batch(e.value) |
789 | | - writer.write_metadata(b"finished") |
790 | | - writer.close() |
791 | | - |
792 | | - else: |
793 | | - raise flight.FlightServerError( |
794 | | - f"Unknown airport-operation header: {airport_operation}" |
795 | | - ) |
796 | | - if airport_operation != ExchangeOperation.SCALAR_FUNCTION: |
797 | | - writer.write_metadata(msgpack.packb(last_metadata)) |
798 | | - writer.close() |
799 | | - return |
800 | | - |
801 | | - return self.impl_do_exchange( |
802 | | - context=call_context, |
803 | | - descriptor=descriptor, |
804 | | - reader=reader, |
805 | | - writer=writer, |
806 | | - ) |
| 714 | + if airport_operation_headers is None or len(airport_operation_headers) == 0: |
| 715 | + return self.impl_do_exchange( |
| 716 | + context=call_context, |
| 717 | + descriptor=descriptor, |
| 718 | + reader=reader, |
| 719 | + writer=writer, |
| 720 | + ) |
| 721 | + |
| 722 | + airport_operation = airport_operation_headers[0] |
| 723 | + logger.debug("do_exchange", airport_operation=airport_operation) |
| 724 | + |
| 725 | + return_chunks_headers = header_middleware.client_headers.get("return-chunks") |
| 726 | + if return_chunks_headers is None or len(return_chunks_headers) == 0: |
| 727 | + raise flight.FlightServerError( |
| 728 | + "The return-chunks header is required for this operation." |
| 729 | + ) |
| 730 | + return_chunks: bool = return_chunks_headers[0] == "1" |
| 731 | + |
| 732 | + last_metadata: Any = None |
| 733 | + if airport_operation == ExchangeOperation.INSERT: |
| 734 | + keys_inserted = self.exchange_insert( |
| 735 | + context=call_context, |
| 736 | + descriptor=descriptor, |
| 737 | + reader=reader, |
| 738 | + writer=writer, |
| 739 | + return_chunks=return_chunks, |
| 740 | + ) |
| 741 | + last_metadata = {"total_inserted": keys_inserted} |
| 742 | + elif airport_operation == ExchangeOperation.UPDATE: |
| 743 | + keys_updated = self.exchange_update( |
| 744 | + context=call_context, |
| 745 | + descriptor=descriptor, |
| 746 | + reader=reader, |
| 747 | + writer=writer, |
| 748 | + return_chunks=return_chunks, |
| 749 | + ) |
| 750 | + last_metadata = {"total_updated": keys_updated} |
| 751 | + elif airport_operation == ExchangeOperation.DELETE: |
| 752 | + keys_deleted = self.exchange_delete( |
| 753 | + context=call_context, |
| 754 | + descriptor=descriptor, |
| 755 | + reader=reader, |
| 756 | + writer=writer, |
| 757 | + return_chunks=return_chunks, |
| 758 | + ) |
| 759 | + last_metadata = {"total_deleted": keys_deleted} |
| 760 | + elif airport_operation == ExchangeOperation.SCALAR_FUNCTION: |
| 761 | + self.exchange_scalar_function( |
| 762 | + context=call_context, |
| 763 | + descriptor=descriptor, |
| 764 | + reader=reader, |
| 765 | + writer=writer, |
| 766 | + ) |
| 767 | + elif airport_operation == ExchangeOperation.TABLE_FUNCTION_IN_OUT: |
| 768 | + # The parameters are sent as the first chunk of the read stream |
| 769 | + # as part of the metadata. |
| 770 | + chunk = next(reader) |
| 771 | + assert chunk.data is None |
| 772 | + assert chunk.app_metadata is not None |
| 773 | + |
| 774 | + parameters = parameter_types.table_function_parameters(chunk.app_metadata) |
| 775 | + |
| 776 | + output_schema, generator = self.exchange_table_function_in_out( |
| 777 | + context=call_context, |
| 778 | + descriptor=descriptor, |
| 779 | + parameters=parameters, |
| 780 | + input_schema=reader.schema, |
| 781 | + ) |
| 782 | + |
| 783 | + writer.begin(output_schema) |
| 784 | + # Prime the generator |
| 785 | + generator.send(None) |
| 786 | + |
| 787 | + for item in reader: |
| 788 | + assert item.data is not None |
| 789 | + result = generator.send(item.data) |
| 790 | + writer.write_batch(result) |
| 791 | + |
| 792 | + try: |
| 793 | + generator.send(None) |
| 794 | + except StopIteration as e: |
| 795 | + if e.value is not None: |
| 796 | + writer.write_batch(e.value) |
| 797 | + else: |
| 798 | + raise flight.FlightServerError( |
| 799 | + f"Unknown airport-operation header: {airport_operation}" |
| 800 | + ) |
| 801 | + if airport_operation not in ( |
| 802 | + ExchangeOperation.SCALAR_FUNCTION, |
| 803 | + ExchangeOperation.TABLE_FUNCTION_IN_OUT, |
| 804 | + ): |
| 805 | + writer.write_metadata(msgpack.packb(last_metadata)) |
| 806 | + elif airport_operation == ExchangeOperation.TABLE_FUNCTION_IN_OUT: |
| 807 | + # The last metadata is the end of the stream |
| 808 | + writer.write_metadata(b"finished") |
| 809 | + |
| 810 | + writer.close() |
807 | 811 | except Exception as e: |
808 | 812 | logger.exception("do_exchange", error=str(e)) |
809 | 813 | raise |
|
0 commit comments