88
99log = structlog .get_logger ()
1010
11-
1211AccountType = TypeVar ("AccountType" , bound = auth .Account )
1312TokenType = TypeVar ("TokenType" , bound = auth .AccountToken )
1413
1514
16- class Caller (Generic [AccountType , TokenType ]):
17- def __init__ (self , * , account : AccountType , token : TokenType ) -> None :
18- self .account = account
19- self .token = token
20-
21-
2215class BasicFlightServer (flight .FlightServerBase , Generic [AccountType , TokenType ]):
2316 def __init__ (
2417 self ,
@@ -38,21 +31,21 @@ def auth_middleware(
3831 assert isinstance (auth_middleware , middleware .SaveCredentialsMiddleware )
3932 return auth_middleware
4033
41- def caller_from_context_ (
34+ def credentials_from_context_ (
4235 self , context : flight .ServerCallContext
43- ) -> Caller [AccountType , TokenType ]:
36+ ) -> middleware . SuppliedCredentials [AccountType , TokenType ] | None | None :
4437 auth_middleware = self .auth_middleware (context )
45- return Caller ( account = auth_middleware .account , token = auth_middleware . token )
38+ return auth_middleware .credentials
4639
4740 def auth_logging_items (
4841 self ,
4942 context : flight .ServerCallContext ,
50- caller : Caller [AccountType , TokenType ],
43+ credentials : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
5144 ) -> dict [str , Any ]:
5245 """Return the items that will be bound to the logger."""
5346 return {
54- "token" : caller .token ,
55- "account" : caller .account .account_id ,
47+ "token" : None if credentials is None else credentials .token ,
48+ "account" : None if credentials is None else credentials .account .account_id ,
5649 "address" : context .peer (),
5750 }
5851
@@ -61,15 +54,15 @@ def impl_list_flights(
6154 * ,
6255 context : flight .ServerCallContext ,
6356 criteria : bytes ,
64- caller : Caller [AccountType , TokenType ],
57+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
6558 logger : structlog .BoundLogger ,
6659 ) -> Iterator [flight .FlightInfo ]:
6760 raise NotImplementedError ("impl_list_flights not implemented" )
6861
6962 def list_flights (
7063 self , context : flight .ServerCallContext , criteria : bytes
7164 ) -> Iterator [flight .FlightInfo ]:
72- caller = self .caller_from_context_ (context )
65+ caller = self .credentials_from_context_ (context )
7366
7467 logger = log .bind (
7568 ** self .auth_logging_items (context , caller ),
@@ -90,7 +83,7 @@ def impl_get_flight_info(
9083 * ,
9184 context : flight .ServerCallContext ,
9285 descriptor : flight .FlightDescriptor ,
93- caller : Caller [AccountType , TokenType ],
86+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
9487 logger : structlog .BoundLogger ,
9588 ) -> flight .FlightInfo :
9689 raise NotImplementedError ("impl_get_flight_info not implemented" )
@@ -100,7 +93,7 @@ def get_flight_info(
10093 context : flight .ServerCallContext ,
10194 descriptor : flight .FlightDescriptor ,
10295 ) -> flight .FlightInfo :
103- caller = self .caller_from_context_ (context )
96+ caller = self .credentials_from_context_ (context )
10497
10598 logger = log .bind (
10699 ** self .auth_logging_items (context , caller ),
@@ -123,15 +116,15 @@ def impl_do_action(
123116 * ,
124117 context : flight .ServerCallContext ,
125118 action : flight .Action ,
126- caller : Caller [AccountType , TokenType ],
119+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
127120 logger : structlog .BoundLogger ,
128121 ) -> Iterator [bytes ]:
129122 raise NotImplementedError ("impl_do_action not implemented" )
130123
131124 def do_action (
132125 self , context : flight .ServerCallContext , action : flight .Action
133126 ) -> Iterator [bytes ]:
134- caller = self .caller_from_context_ (context )
127+ caller = self .credentials_from_context_ (context )
135128
136129 logger = log .bind (
137130 ** self .auth_logging_items (context , caller ),
@@ -154,7 +147,7 @@ def impl_do_exchange(
154147 descriptor : flight .FlightDescriptor ,
155148 reader : flight .MetadataRecordBatchReader ,
156149 writer : flight .MetadataRecordBatchWriter ,
157- caller : Caller [AccountType , TokenType ],
150+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
158151 logger : structlog .BoundLogger ,
159152 ) -> None :
160153 raise NotImplementedError ("impl_do_exchange not implemented" )
@@ -166,7 +159,7 @@ def do_exchange(
166159 reader : flight .MetadataRecordBatchReader ,
167160 writer : flight .MetadataRecordBatchWriter ,
168161 ) -> None :
169- caller = self .caller_from_context_ (context )
162+ caller = self .credentials_from_context_ (context )
170163
171164 logger = log .bind (
172165 ** self .auth_logging_items (context , caller ),
@@ -187,15 +180,15 @@ def impl_do_get(
187180 * ,
188181 context : flight .ServerCallContext ,
189182 ticket : flight .Ticket ,
190- caller : Caller [AccountType , TokenType ],
183+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
191184 logger : structlog .BoundLogger ,
192185 ) -> flight .RecordBatchStream :
193186 raise NotImplementedError ("impl_do_get not implemented" )
194187
195188 def do_get (
196189 self , context : flight .ServerCallContext , ticket : flight .Ticket
197190 ) -> flight .RecordBatchStream :
198- caller = self .caller_from_context_ (context )
191+ caller = self .credentials_from_context_ (context )
199192
200193 logger = log .bind (
201194 ** self .auth_logging_items (context , caller ),
@@ -214,7 +207,7 @@ def impl_do_put(
214207 self ,
215208 * ,
216209 context : flight .ServerCallContext ,
217- caller : Caller [AccountType , TokenType ],
210+ caller : middleware . SuppliedCredentials [AccountType , TokenType ] | None ,
218211 logger : structlog .BoundLogger ,
219212 descriptor : flight .FlightDescriptor ,
220213 reader : flight .MetadataRecordBatchReader ,
@@ -229,7 +222,7 @@ def do_put(
229222 reader : flight .MetadataRecordBatchReader ,
230223 writer : flight .FlightMetadataWriter ,
231224 ) -> None :
232- caller = self .caller_from_context_ (context )
225+ caller = self .credentials_from_context_ (context )
233226
234227 logger = log .bind (
235228 ** self .auth_logging_items (context , caller ),
0 commit comments