Skip to content

Commit 294670a

Browse files
committed
fix: improve server
1 parent 3795d01 commit 294670a

2 files changed

Lines changed: 20 additions & 25 deletions

File tree

src/query_farm_server_base/middleware.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import dataclass
12
from typing import Any, Generic, TypeVar
23

34
import pyarrow.flight as flight
@@ -26,6 +27,7 @@ def start_call(self, info: Any, headers: dict[str, Any]) -> SaveHeadersMiddlewar
2627
TokenType = TypeVar("TokenType", bound=auth.AccountToken)
2728

2829

30+
@dataclass
2931
class SuppliedCredentials(Generic[AccountType, TokenType]):
3032
def __init__(self, token: TokenType, account: AccountType) -> None:
3133
assert token

src/query_farm_server_base/server.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,10 @@
88

99
log = structlog.get_logger()
1010

11-
1211
AccountType = TypeVar("AccountType", bound=auth.Account)
1312
TokenType = TypeVar("TokenType", bound=auth.AccountToken)
1413

1514

16-
class Caller(Generic[AccountType, TokenType]):
17-
def __init__(self, *, account: AccountType, token: TokenType) -> None:
18-
self.account = account
19-
self.token = token
20-
21-
2215
class BasicFlightServer(flight.FlightServerBase, Generic[AccountType, TokenType]):
2316
def __init__(
2417
self,
@@ -38,21 +31,21 @@ def auth_middleware(
3831
assert isinstance(auth_middleware, middleware.SaveCredentialsMiddleware)
3932
return auth_middleware
4033

41-
def caller_from_context_(
34+
def credentials_from_context_(
4235
self, context: flight.ServerCallContext
43-
) -> Caller[AccountType, TokenType]:
36+
) -> middleware.SuppliedCredentials[AccountType, TokenType] | None | None:
4437
auth_middleware = self.auth_middleware(context)
45-
return Caller(account=auth_middleware.account, token=auth_middleware.token)
38+
return auth_middleware.credentials
4639

4740
def auth_logging_items(
4841
self,
4942
context: flight.ServerCallContext,
50-
caller: Caller[AccountType, TokenType],
43+
credentials: middleware.SuppliedCredentials[AccountType, TokenType] | None,
5144
) -> dict[str, Any]:
5245
"""Return the items that will be bound to the logger."""
5346
return {
54-
"token": caller.token,
55-
"account": caller.account.account_id,
47+
"token": None if credentials is None else credentials.token,
48+
"account": None if credentials is None else credentials.account.account_id,
5649
"address": context.peer(),
5750
}
5851

@@ -61,15 +54,15 @@ def impl_list_flights(
6154
*,
6255
context: flight.ServerCallContext,
6356
criteria: bytes,
64-
caller: Caller[AccountType, TokenType],
57+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
6558
logger: structlog.BoundLogger,
6659
) -> Iterator[flight.FlightInfo]:
6760
raise NotImplementedError("impl_list_flights not implemented")
6861

6962
def list_flights(
7063
self, context: flight.ServerCallContext, criteria: bytes
7164
) -> Iterator[flight.FlightInfo]:
72-
caller = self.caller_from_context_(context)
65+
caller = self.credentials_from_context_(context)
7366

7467
logger = log.bind(
7568
**self.auth_logging_items(context, caller),
@@ -90,7 +83,7 @@ def impl_get_flight_info(
9083
*,
9184
context: flight.ServerCallContext,
9285
descriptor: flight.FlightDescriptor,
93-
caller: Caller[AccountType, TokenType],
86+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
9487
logger: structlog.BoundLogger,
9588
) -> flight.FlightInfo:
9689
raise NotImplementedError("impl_get_flight_info not implemented")
@@ -100,7 +93,7 @@ def get_flight_info(
10093
context: flight.ServerCallContext,
10194
descriptor: flight.FlightDescriptor,
10295
) -> flight.FlightInfo:
103-
caller = self.caller_from_context_(context)
96+
caller = self.credentials_from_context_(context)
10497

10598
logger = log.bind(
10699
**self.auth_logging_items(context, caller),
@@ -123,15 +116,15 @@ def impl_do_action(
123116
*,
124117
context: flight.ServerCallContext,
125118
action: flight.Action,
126-
caller: Caller[AccountType, TokenType],
119+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
127120
logger: structlog.BoundLogger,
128121
) -> Iterator[bytes]:
129122
raise NotImplementedError("impl_do_action not implemented")
130123

131124
def do_action(
132125
self, context: flight.ServerCallContext, action: flight.Action
133126
) -> Iterator[bytes]:
134-
caller = self.caller_from_context_(context)
127+
caller = self.credentials_from_context_(context)
135128

136129
logger = log.bind(
137130
**self.auth_logging_items(context, caller),
@@ -154,7 +147,7 @@ def impl_do_exchange(
154147
descriptor: flight.FlightDescriptor,
155148
reader: flight.MetadataRecordBatchReader,
156149
writer: flight.MetadataRecordBatchWriter,
157-
caller: Caller[AccountType, TokenType],
150+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
158151
logger: structlog.BoundLogger,
159152
) -> None:
160153
raise NotImplementedError("impl_do_exchange not implemented")
@@ -166,7 +159,7 @@ def do_exchange(
166159
reader: flight.MetadataRecordBatchReader,
167160
writer: flight.MetadataRecordBatchWriter,
168161
) -> None:
169-
caller = self.caller_from_context_(context)
162+
caller = self.credentials_from_context_(context)
170163

171164
logger = log.bind(
172165
**self.auth_logging_items(context, caller),
@@ -187,15 +180,15 @@ def impl_do_get(
187180
*,
188181
context: flight.ServerCallContext,
189182
ticket: flight.Ticket,
190-
caller: Caller[AccountType, TokenType],
183+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
191184
logger: structlog.BoundLogger,
192185
) -> flight.RecordBatchStream:
193186
raise NotImplementedError("impl_do_get not implemented")
194187

195188
def do_get(
196189
self, context: flight.ServerCallContext, ticket: flight.Ticket
197190
) -> flight.RecordBatchStream:
198-
caller = self.caller_from_context_(context)
191+
caller = self.credentials_from_context_(context)
199192

200193
logger = log.bind(
201194
**self.auth_logging_items(context, caller),
@@ -214,7 +207,7 @@ def impl_do_put(
214207
self,
215208
*,
216209
context: flight.ServerCallContext,
217-
caller: Caller[AccountType, TokenType],
210+
caller: middleware.SuppliedCredentials[AccountType, TokenType] | None,
218211
logger: structlog.BoundLogger,
219212
descriptor: flight.FlightDescriptor,
220213
reader: flight.MetadataRecordBatchReader,
@@ -229,7 +222,7 @@ def do_put(
229222
reader: flight.MetadataRecordBatchReader,
230223
writer: flight.FlightMetadataWriter,
231224
) -> None:
232-
caller = self.caller_from_context_(context)
225+
caller = self.credentials_from_context_(context)
233226

234227
logger = log.bind(
235228
**self.auth_logging_items(context, caller),

0 commit comments

Comments
 (0)