77import typing
88from collections .abc import (
99 AsyncGenerator ,
10+ Callable ,
1011 Coroutine ,
1112 Sequence ,
1213)
1314from typing import (
1415 Any ,
15- Callable ,
16- Optional ,
16+ cast ,
1717)
1818
1919from eth_typing import ChecksumAddress , HexStr
4848 GolemBaseUpdate ,
4949 QueryEntitiesResult ,
5050 UpdateEntityReturnType ,
51+ WatchLogsHandle ,
5152)
5253from .utils import parse_legacy_btl_extended_log , rlp_encode_transaction
5354
@@ -126,12 +127,16 @@ def __init__(self, rpc_url: str):
126127 async def get_storage_value (self , entity_key : EntityKey ) -> bytes :
127128 """Get the storage value stored in the given entity."""
128129 return base64 .b64decode (
129- await self .eth .get_storage_value (entity_key .as_hex_string ()) # type: ignore
130+ await self .eth .get_storage_value ( # type: ignore[attr-defined]
131+ entity_key .as_hex_string ()
132+ )
130133 )
131134
132135 async def get_entity_metadata (self , entity_key : EntityKey ) -> EntityMetadata :
133136 """Get the metadata of the given entity."""
134- metadata = await self .eth .get_entity_metadata (entity_key .as_hex_string ()) # type: ignore
137+ metadata = await self .eth .get_entity_metadata ( # type: ignore[attr-defined]
138+ entity_key .as_hex_string ()
139+ )
135140
136141 return EntityMetadata (
137142 entity_key = entity_key ,
@@ -158,20 +163,22 @@ async def get_entities_to_expire_at_block(
158163 return list (
159164 map (
160165 lambda e : EntityKey (GenericBytes .from_hex_string (e )),
161- await self .eth .get_entities_to_expire_at_block (block_number ), # type: ignore
166+ await self .eth .get_entities_to_expire_at_block ( # type: ignore[attr-defined]
167+ block_number
168+ ),
162169 )
163170 )
164171
165172 async def get_entity_count (self ) -> int :
166173 """Get the total entity count in Golem Base."""
167- return await self .eth .get_entity_count () # type: ignore
174+ return cast ( int , await self .eth .get_entity_count ()) # type: ignore[attr-defined]
168175
169176 async def get_all_entity_keys (self ) -> Sequence [EntityKey ]:
170177 """Get all entity keys in Golem Base."""
171178 return list (
172179 map (
173180 lambda e : EntityKey (GenericBytes .from_hex_string (e )),
174- await self .eth .get_all_entity_keys (), # type: ignore
181+ await self .eth .get_all_entity_keys (), # type: ignore[attr-defined]
175182 )
176183 )
177184
@@ -182,9 +189,7 @@ async def get_entities_of_owner(
182189 return list (
183190 map (
184191 lambda e : EntityKey (GenericBytes .from_hex_string (e )),
185- # https://github.com/pylint-dev/pylint/issues/3162
186- # pylint: disable=no-member
187- await self .eth .get_entities_of_owner (owner ), # type: ignore
192+ await self .eth .get_entities_of_owner (owner ), # type: ignore[attr-defined]
188193 )
189194 )
190195
@@ -195,7 +200,7 @@ async def query_entities(self, query: str) -> Sequence[QueryEntitiesResult]:
195200 lambda result : QueryEntitiesResult (
196201 entity_key = result .key , storage_value = base64 .b64decode (result .value )
197202 ),
198- await self .eth .query_entities (query ), # type: ignore
203+ await self .eth .query_entities (query ), # type: ignore[attr-defined]
199204 )
200205 )
201206
@@ -228,6 +233,24 @@ async def create(
228233 ws_client = await AsyncWeb3 (WebSocketProvider (ws_url ))
229234 return GolemBaseClient (rpc_url , ws_client , private_key )
230235
236+ async def _start_subscription_loop (self ) -> None :
237+ """Create a long running task to handle subscriptions."""
238+ # The loop will finish when there are no subscriptions left, so this method
239+ # gets called every time a subscription is created, and we'll check
240+ # whether we need to make a new task or whether one is already running.
241+ if not self ._background_tasks :
242+ # Start the asyncio event loop
243+ task = asyncio .create_task (
244+ self .ws_client ().subscription_manager .handle_subscriptions ()
245+ )
246+ self ._background_tasks .add (task )
247+
248+ def task_done (task : asyncio .Task [None ]) -> None :
249+ logger .info ("Subscription background task done, removing..." )
250+ self ._background_tasks .discard (task )
251+
252+ task .add_done_callback (task_done )
253+
231254 def __init__ (self , rpc_url : str , ws_client : AsyncWeb3 , private_key : bytes ) -> None :
232255 """Initialise the GolemBaseClient instance."""
233256 self ._http_client = GolemBaseHttpClient (rpc_url )
@@ -262,7 +285,11 @@ async def inner(show_traceback: bool) -> bool:
262285 # The method on the provider is usually not called directly, instead you
263286 # can call the eponymous method on the client, which will delegate to the
264287 # provider.
265- self .http_client ().provider .is_connected = is_connected (self .http_client ()) # type: ignore
288+ object .__setattr__ (
289+ self .http_client ().provider ,
290+ "is_connected" ,
291+ is_connected (self .http_client ()),
292+ )
266293
267294 # Allow caching of certain methods to improve performance
268295 self .http_client ().provider .cache_allowed_requests = True
@@ -303,7 +330,7 @@ def ws_client(self) -> AsyncWeb3:
303330
304331 async def is_connected (self ) -> bool :
305332 """Check whether the client's underlying http client is connected."""
306- return await self .http_client ().is_connected ()
333+ return cast ( bool , await self .http_client ().is_connected ()) # type: ignore[redundant-cast]
307334
308335 async def disconnect (self ) -> None :
309336 """
@@ -318,7 +345,7 @@ async def disconnect(self) -> None:
318345
319346 def get_account_address (self ) -> ChecksumAddress :
320347 """Get the address associated with the private key of this client."""
321- return self .account .address # type: ignore
348+ return cast ( ChecksumAddress , self .account .address )
322349
323350 async def get_storage_value (self , entity_key : EntityKey ) -> bytes :
324351 """Get the storage value stored in the given entity."""
@@ -382,10 +409,10 @@ async def extend_entities(
382409
383410 async def send_transaction (
384411 self ,
385- creates : Optional [ Sequence [GolemBaseCreate ]] = None ,
386- updates : Optional [ Sequence [GolemBaseUpdate ]] = None ,
387- deletes : Optional [ Sequence [GolemBaseDelete ]] = None ,
388- extensions : Optional [ Sequence [GolemBaseExtend ]] = None ,
412+ creates : Sequence [GolemBaseCreate ] | None = None ,
413+ updates : Sequence [GolemBaseUpdate ] | None = None ,
414+ deletes : Sequence [GolemBaseDelete ] | None = None ,
415+ extensions : Sequence [GolemBaseExtend ] | None = None ,
389416 ) -> GolemBaseTransactionReceipt :
390417 """
391418 Send a generic transaction to Golem Base.
@@ -529,11 +556,13 @@ async def _send_gb_transaction(
529556
530557 async def watch_logs (
531558 self ,
532- create_callback : Callable [[CreateEntityReturnType ], None ],
533- update_callback : Callable [[UpdateEntityReturnType ], None ],
534- delete_callback : Callable [[EntityKey ], None ],
535- extend_callback : Callable [[ExtendEntityReturnType ], None ],
536- ) -> None :
559+ * ,
560+ label : str ,
561+ create_callback : Callable [[CreateEntityReturnType ], None ] | None = None ,
562+ update_callback : Callable [[UpdateEntityReturnType ], None ] | None = None ,
563+ delete_callback : Callable [[EntityKey ], None ] | None = None ,
564+ extend_callback : Callable [[ExtendEntityReturnType ], None ] | None = None ,
565+ ) -> WatchLogsHandle :
537566 """
538567 Subscribe to events on Golem Base.
539568
@@ -550,44 +579,62 @@ async def log_handler(
550579 logger .debug ("New log: %s" , log_receipt )
551580 res = await self ._process_golem_base_log_receipt (log_receipt )
552581
553- for create in res .creates :
554- create_callback (create )
555- for update in res .updates :
556- update_callback (update )
557- for key in res .deletes :
558- delete_callback (key )
559- for extension in res .extensions :
560- extend_callback (extension )
582+ if create_callback :
583+ for create in res .creates :
584+ create_callback (create )
585+ if update_callback :
586+ for update in res .updates :
587+ update_callback (update )
588+ if delete_callback :
589+ for key in res .deletes :
590+ delete_callback (key )
591+ if extend_callback :
592+ for extension in res .extensions :
593+ extend_callback (extension )
561594
562595 def create_subscription (topic : HexStr ) -> LogsSubscription :
563596 return LogsSubscription (
564- label = f"Golem Base subscription to topic { topic } " ,
597+ label = f"Golem Base subscription to topic { topic } with label { label } " ,
565598 address = self .golem_base_contract .address ,
566599 topics = [topic ],
567600 handler = log_handler ,
568601 # optional `handler_context` args to help parse a response
569602 handler_context = {},
570603 )
571604
572- async def handle_subscriptions () -> None :
573- await self ._ws_client .subscription_manager .subscribe (
574- list (
575- map (
576- lambda event : create_subscription (event .topic ),
577- self .golem_base_contract .all_events (),
578- )
605+ event_names = []
606+ if create_callback :
607+ event_names .append ("GolemBaseStorageEntityCreated" )
608+ if update_callback :
609+ event_names .append ("GolemBaseStorageEntityUpdated" )
610+ if delete_callback :
611+ event_names .append ("GolemBaseStorageEntityDeleted" )
612+ if extend_callback :
613+ event_names .extend (
614+ [
615+ "GolemBaseStorageEntityBTLExtended" ,
616+ "GolemBaseStorageEntityBTLExptended" ,
617+ "GolemBaseStorageEntityTTLExptended" ,
618+ ]
619+ )
620+
621+ events = list (
622+ map (
623+ lambda event_name : create_subscription (
624+ self .golem_base_contract .get_event_by_name (event_name ).topic
579625 ),
626+ event_names ,
580627 )
581- # handle subscriptions via configured handlers:
582- await self .ws_client ().subscription_manager .handle_subscriptions ()
628+ )
629+ subscription_ids = await self ._ws_client .subscription_manager .subscribe (
630+ events ,
631+ )
632+ logger .info ("Sub ID: %s" , subscription_ids )
583633
584- # Create a long running task to handle subscriptions that we can run on
585- # the asyncio event loop
586- task = asyncio .create_task (handle_subscriptions ())
587- self ._background_tasks .add (task )
634+ # Start a subscription loop in case there is none running
635+ await self ._start_subscription_loop ()
588636
589- def task_done (task : asyncio .Task [None ]) -> None :
590- logger .info ("Subscription background task done, removing..." )
591- self ._background_tasks .discard (task )
637+ async def unsubscribe () -> None :
638+ await self ._ws_client .subscription_manager .unsubscribe (subscription_ids )
592639
593- task . add_done_callback ( task_done )
640+ return WatchLogsHandle ( _unsubscribe = unsubscribe )
0 commit comments