Skip to content

Commit ae71ed5

Browse files
authored
Merge pull request #7 from Golem-Base/rvdp/unsubscribe
feat: allow to unsubscribe from events
2 parents 56d0552 + e8f4bc0 commit ae71ed5

8 files changed

Lines changed: 170 additions & 98 deletions

File tree

example/golem_base_sdk_example/__init__.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
}
6060

6161

62-
async def run_example(instance: str) -> None:
62+
async def run_example(instance: str) -> None: # noqa: PLR0915
6363
"""Run the example."""
6464
async with await anyio.open_file(
6565
BaseDirectory.xdg_config_home + "/golembase/private.key",
@@ -73,26 +73,31 @@ async def run_example(instance: str) -> None:
7373
private_key=key_bytes,
7474
)
7575

76-
await client.watch_logs(
77-
lambda create: logger.info(
76+
watch_logs_handle = await client.watch_logs(
77+
label="first",
78+
create_callback=lambda create: logger.info(
7879
"""\n
7980
Got create event: %s
8081
""",
8182
create,
8283
),
83-
lambda update: logger.info(
84+
update_callback=lambda update: logger.info(
8485
"""\n
8586
Got update event: %s
8687
""",
8788
update,
8889
),
89-
lambda deleted_key: logger.info(
90+
)
91+
92+
await client.watch_logs(
93+
label="second",
94+
delete_callback=lambda deleted_key: logger.info(
9095
"""\n
9196
Got delete event: %s
9297
""",
9398
deleted_key,
9499
),
95-
lambda extension: logger.info(
100+
extend_callback=lambda extension: logger.info(
96101
"""\n
97102
Got extend event: %s
98103
""",
@@ -170,11 +175,11 @@ async def run_example(instance: str) -> None:
170175
logger.info(
171176
"block number: %s", await client.http_client().eth.get_block_number()
172177
)
173-
update_receipt = await client.update_entities(
178+
[update_receipt] = await client.update_entities(
174179
[GolemBaseUpdate(entity_key, b"hello", 60, [Annotation("app", "demo")], [])]
175180
)
176181
logger.info("receipt: %s", update_receipt)
177-
entity_key = update_receipt[0].entity_key
182+
entity_key = update_receipt.entity_key
178183

179184
logger.info("entity metadata: %s", await client.get_entity_metadata(entity_key))
180185

@@ -210,6 +215,8 @@ async def run_example(instance: str) -> None:
210215
"All entities: %s",
211216
await client.get_all_entity_keys(),
212217
)
218+
219+
await watch_logs_handle.unsubscribe()
213220
else:
214221
logger.warning("Could not connect to the API...")
215222

example/nix/packages/golem-base-sdk-example.nix

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ let
2727

2828
in
2929

30-
pkgs.python3Packages.buildPythonPackage rec {
30+
pkgs.python3Packages.buildPythonPackage {
3131
inherit pname;
3232
version = "0.0.1";
3333

@@ -65,9 +65,8 @@ pkgs.python3Packages.buildPythonPackage rec {
6565
];
6666

6767
checkPhase = ''
68-
mypy --config ${../../../mypy.ini} ${src}/golem_base_sdk_example
69-
ruff check --no-cache ${src}/golem_base_sdk_example
70-
PYLINTHOME="$TMPDIR" pylint ${src}/golem_base_sdk_example
68+
mypy golem_base_sdk_example
69+
ruff check --no-cache golem_base_sdk_example
7170
'';
7271

7372
meta = with lib; {

golem_base_sdk/__init__.py

Lines changed: 96 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import typing
88
from collections.abc import (
99
AsyncGenerator,
10+
Callable,
1011
Coroutine,
1112
Sequence,
1213
)
1314
from typing import (
1415
Any,
15-
Callable,
16-
Optional,
16+
cast,
1717
)
1818

1919
from eth_typing import ChecksumAddress, HexStr
@@ -48,6 +48,7 @@
4848
GolemBaseUpdate,
4949
QueryEntitiesResult,
5050
UpdateEntityReturnType,
51+
WatchLogsHandle,
5152
)
5253
from .utils import parse_legacy_btl_extended_log, rlp_encode_transaction
5354

@@ -126,12 +127,16 @@ def __init__(self, rpc_url: str):
126127
async def get_storage_value(self, entity_key: EntityKey) -> bytes:
127128
"""Get the storage value stored in the given entity."""
128129
return base64.b64decode(
129-
await self.eth.get_storage_value(entity_key.as_hex_string()) # type: ignore
130+
await self.eth.get_storage_value( # type: ignore[attr-defined]
131+
entity_key.as_hex_string()
132+
)
130133
)
131134

132135
async def get_entity_metadata(self, entity_key: EntityKey) -> EntityMetadata:
133136
"""Get the metadata of the given entity."""
134-
metadata = await self.eth.get_entity_metadata(entity_key.as_hex_string()) # type: ignore
137+
metadata = await self.eth.get_entity_metadata( # type: ignore[attr-defined]
138+
entity_key.as_hex_string()
139+
)
135140

136141
return EntityMetadata(
137142
entity_key=entity_key,
@@ -158,20 +163,22 @@ async def get_entities_to_expire_at_block(
158163
return list(
159164
map(
160165
lambda e: EntityKey(GenericBytes.from_hex_string(e)),
161-
await self.eth.get_entities_to_expire_at_block(block_number), # type: ignore
166+
await self.eth.get_entities_to_expire_at_block( # type: ignore[attr-defined]
167+
block_number
168+
),
162169
)
163170
)
164171

165172
async def get_entity_count(self) -> int:
166173
"""Get the total entity count in Golem Base."""
167-
return await self.eth.get_entity_count() # type: ignore
174+
return cast(int, await self.eth.get_entity_count()) # type: ignore[attr-defined]
168175

169176
async def get_all_entity_keys(self) -> Sequence[EntityKey]:
170177
"""Get all entity keys in Golem Base."""
171178
return list(
172179
map(
173180
lambda e: EntityKey(GenericBytes.from_hex_string(e)),
174-
await self.eth.get_all_entity_keys(), # type: ignore
181+
await self.eth.get_all_entity_keys(), # type: ignore[attr-defined]
175182
)
176183
)
177184

@@ -182,9 +189,7 @@ async def get_entities_of_owner(
182189
return list(
183190
map(
184191
lambda e: EntityKey(GenericBytes.from_hex_string(e)),
185-
# https://github.com/pylint-dev/pylint/issues/3162
186-
# pylint: disable=no-member
187-
await self.eth.get_entities_of_owner(owner), # type: ignore
192+
await self.eth.get_entities_of_owner(owner), # type: ignore[attr-defined]
188193
)
189194
)
190195

@@ -195,7 +200,7 @@ async def query_entities(self, query: str) -> Sequence[QueryEntitiesResult]:
195200
lambda result: QueryEntitiesResult(
196201
entity_key=result.key, storage_value=base64.b64decode(result.value)
197202
),
198-
await self.eth.query_entities(query), # type: ignore
203+
await self.eth.query_entities(query), # type: ignore[attr-defined]
199204
)
200205
)
201206

@@ -228,6 +233,24 @@ async def create(
228233
ws_client = await AsyncWeb3(WebSocketProvider(ws_url))
229234
return GolemBaseClient(rpc_url, ws_client, private_key)
230235

236+
async def _start_subscription_loop(self) -> None:
237+
"""Create a long running task to handle subscriptions."""
238+
# The loop will finish when there are no subscriptions left, so this method
239+
# gets called every time a subscription is created, and we'll check
240+
# whether we need to make a new task or whether one is already running.
241+
if not self._background_tasks:
242+
# Start the asyncio event loop
243+
task = asyncio.create_task(
244+
self.ws_client().subscription_manager.handle_subscriptions()
245+
)
246+
self._background_tasks.add(task)
247+
248+
def task_done(task: asyncio.Task[None]) -> None:
249+
logger.info("Subscription background task done, removing...")
250+
self._background_tasks.discard(task)
251+
252+
task.add_done_callback(task_done)
253+
231254
def __init__(self, rpc_url: str, ws_client: AsyncWeb3, private_key: bytes) -> None:
232255
"""Initialise the GolemBaseClient instance."""
233256
self._http_client = GolemBaseHttpClient(rpc_url)
@@ -262,7 +285,11 @@ async def inner(show_traceback: bool) -> bool:
262285
# The method on the provider is usually not called directly, instead you
263286
# can call the eponymous method on the client, which will delegate to the
264287
# provider.
265-
self.http_client().provider.is_connected = is_connected(self.http_client()) # type: ignore
288+
object.__setattr__(
289+
self.http_client().provider,
290+
"is_connected",
291+
is_connected(self.http_client()),
292+
)
266293

267294
# Allow caching of certain methods to improve performance
268295
self.http_client().provider.cache_allowed_requests = True
@@ -303,7 +330,7 @@ def ws_client(self) -> AsyncWeb3:
303330

304331
async def is_connected(self) -> bool:
305332
"""Check whether the client's underlying http client is connected."""
306-
return await self.http_client().is_connected()
333+
return cast(bool, await self.http_client().is_connected()) # type: ignore[redundant-cast]
307334

308335
async def disconnect(self) -> None:
309336
"""
@@ -318,7 +345,7 @@ async def disconnect(self) -> None:
318345

319346
def get_account_address(self) -> ChecksumAddress:
320347
"""Get the address associated with the private key of this client."""
321-
return self.account.address # type: ignore
348+
return cast(ChecksumAddress, self.account.address)
322349

323350
async def get_storage_value(self, entity_key: EntityKey) -> bytes:
324351
"""Get the storage value stored in the given entity."""
@@ -382,10 +409,10 @@ async def extend_entities(
382409

383410
async def send_transaction(
384411
self,
385-
creates: Optional[Sequence[GolemBaseCreate]] = None,
386-
updates: Optional[Sequence[GolemBaseUpdate]] = None,
387-
deletes: Optional[Sequence[GolemBaseDelete]] = None,
388-
extensions: Optional[Sequence[GolemBaseExtend]] = None,
412+
creates: Sequence[GolemBaseCreate] | None = None,
413+
updates: Sequence[GolemBaseUpdate] | None = None,
414+
deletes: Sequence[GolemBaseDelete] | None = None,
415+
extensions: Sequence[GolemBaseExtend] | None = None,
389416
) -> GolemBaseTransactionReceipt:
390417
"""
391418
Send a generic transaction to Golem Base.
@@ -529,11 +556,13 @@ async def _send_gb_transaction(
529556

530557
async def watch_logs(
531558
self,
532-
create_callback: Callable[[CreateEntityReturnType], None],
533-
update_callback: Callable[[UpdateEntityReturnType], None],
534-
delete_callback: Callable[[EntityKey], None],
535-
extend_callback: Callable[[ExtendEntityReturnType], None],
536-
) -> None:
559+
*,
560+
label: str,
561+
create_callback: Callable[[CreateEntityReturnType], None] | None = None,
562+
update_callback: Callable[[UpdateEntityReturnType], None] | None = None,
563+
delete_callback: Callable[[EntityKey], None] | None = None,
564+
extend_callback: Callable[[ExtendEntityReturnType], None] | None = None,
565+
) -> WatchLogsHandle:
537566
"""
538567
Subscribe to events on Golem Base.
539568
@@ -550,44 +579,62 @@ async def log_handler(
550579
logger.debug("New log: %s", log_receipt)
551580
res = await self._process_golem_base_log_receipt(log_receipt)
552581

553-
for create in res.creates:
554-
create_callback(create)
555-
for update in res.updates:
556-
update_callback(update)
557-
for key in res.deletes:
558-
delete_callback(key)
559-
for extension in res.extensions:
560-
extend_callback(extension)
582+
if create_callback:
583+
for create in res.creates:
584+
create_callback(create)
585+
if update_callback:
586+
for update in res.updates:
587+
update_callback(update)
588+
if delete_callback:
589+
for key in res.deletes:
590+
delete_callback(key)
591+
if extend_callback:
592+
for extension in res.extensions:
593+
extend_callback(extension)
561594

562595
def create_subscription(topic: HexStr) -> LogsSubscription:
563596
return LogsSubscription(
564-
label=f"Golem Base subscription to topic {topic}",
597+
label=f"Golem Base subscription to topic {topic} with label {label}",
565598
address=self.golem_base_contract.address,
566599
topics=[topic],
567600
handler=log_handler,
568601
# optional `handler_context` args to help parse a response
569602
handler_context={},
570603
)
571604

572-
async def handle_subscriptions() -> None:
573-
await self._ws_client.subscription_manager.subscribe(
574-
list(
575-
map(
576-
lambda event: create_subscription(event.topic),
577-
self.golem_base_contract.all_events(),
578-
)
605+
event_names = []
606+
if create_callback:
607+
event_names.append("GolemBaseStorageEntityCreated")
608+
if update_callback:
609+
event_names.append("GolemBaseStorageEntityUpdated")
610+
if delete_callback:
611+
event_names.append("GolemBaseStorageEntityDeleted")
612+
if extend_callback:
613+
event_names.extend(
614+
[
615+
"GolemBaseStorageEntityBTLExtended",
616+
"GolemBaseStorageEntityBTLExptended",
617+
"GolemBaseStorageEntityTTLExptended",
618+
]
619+
)
620+
621+
events = list(
622+
map(
623+
lambda event_name: create_subscription(
624+
self.golem_base_contract.get_event_by_name(event_name).topic
579625
),
626+
event_names,
580627
)
581-
# handle subscriptions via configured handlers:
582-
await self.ws_client().subscription_manager.handle_subscriptions()
628+
)
629+
subscription_ids = await self._ws_client.subscription_manager.subscribe(
630+
events,
631+
)
632+
logger.info("Sub ID: %s", subscription_ids)
583633

584-
# Create a long running task to handle subscriptions that we can run on
585-
# the asyncio event loop
586-
task = asyncio.create_task(handle_subscriptions())
587-
self._background_tasks.add(task)
634+
# Start a subscription loop in case there is none running
635+
await self._start_subscription_loop()
588636

589-
def task_done(task: asyncio.Task[None]) -> None:
590-
logger.info("Subscription background task done, removing...")
591-
self._background_tasks.discard(task)
637+
async def unsubscribe() -> None:
638+
await self._ws_client.subscription_manager.unsubscribe(subscription_ids)
592639

593-
task.add_done_callback(task_done)
640+
return WatchLogsHandle(_unsubscribe=unsubscribe)

0 commit comments

Comments
 (0)