1616
1717import dataclasses
1818import datetime
19- import threading
20- from typing import List , Optional
21- import weakref
19+ from typing import Any , Callable , Optional , Set
20+ import uuid
2221
2322import google .cloud .bigquery ._job_helpers
2423import google .cloud .bigquery .job .query
2524import google .cloud .bigquery .table
2625
27- import bigframes .formatting_helpers
2826import bigframes .session .executor
2927
3028
31- @dataclasses .dataclass (frozen = True )
3229class Subscriber :
33- callback_ref : weakref .ref
34- # TODO(tswast): Add block_id to allow filter in context managers.
30+ def __init__ (self , callback : Callable [[Event ], None ], * , publisher : Publisher ):
31+ self ._publisher = publisher
32+ self ._callback = callback
33+ self ._subscriber_id = str (uuid .uuid4 ())
34+
35+ def __call__ (self , * args , ** kwargs ):
36+ return self ._callback (* args , ** kwargs )
37+
38+ def __hash__ (self ) -> int :
39+ return hash (self ._subscriber_id )
40+
41+ def __eq__ (self , value : object ):
42+ if not isinstance (value , Subscriber ):
43+ return NotImplemented
44+ return value ._subscriber_id == self ._subscriber_id
45+
46+ def close (self ):
47+ self ._publisher .unsubscribe (self )
48+ del self ._publisher
49+ del self ._callback
50+
51+ def __enter__ (self ):
52+ return self
53+
54+ def __exit__ (self , exc_type , exc_value , traceback ):
55+ if exc_value is not None :
56+ self (
57+ UnknownErrorEvent (
58+ exc_type = exc_type ,
59+ exc_value = exc_value ,
60+ traceback = traceback ,
61+ )
62+ )
63+ self .close ()
3564
3665
3766class Publisher :
3867 def __init__ (self ):
39- self ._subscribers : List [Subscriber ] = []
40- self ._subscribers_lock = threading .Lock ()
41-
42- def subscribe (self , callback ):
43- subscriber = Subscriber (callback_ref = weakref .ref (callback ))
44-
45- with self ._subscribers_lock :
46- # TODO(tswast): Add block_id to allow filter in context managers.
47- self ._subscribers .append (subscriber )
48-
49- def send (self , event : Event ):
50- to_delete = []
51- to_call = []
68+ self ._subscribers : Set [Subscriber ] = set ()
5269
53- with self ._subscribers_lock :
54- for sid , subscriber in enumerate (self ._subscribers ):
55- callback = subscriber .callback_ref ()
70+ def subscribe (self , callback : Callable [[Event ], None ]) -> Subscriber :
71+ # TODO(b/448176657): figure out how to handle subscribers/publishers in
72+ # a background thread. Maybe subscribers should be thread-local?
73+ subscriber = Subscriber (callback , publisher = self )
74+ self ._subscribers .add (subscriber )
75+ return subscriber
5676
57- if callback is None :
58- to_delete .append (sid )
59- else :
60- # TODO(tswast): Add if statement for block_id to allow filter
61- # in context managers.
62- to_call .append (callback )
77+ def unsubscribe (self , subscriber : Subscriber ):
78+ self ._subscribers .remove (subscriber )
6379
64- for sid in reversed (to_delete ):
65- del self ._subscribers [sid ]
66-
67- for callback in to_call :
68- callback (event )
80+ def publish (self , event : Event ):
81+ for subscriber in self ._subscribers :
82+ subscriber (event )
6983
7084
7185class Event :
@@ -90,6 +104,13 @@ class ExecutionFinished(Event):
90104 result : Optional [bigframes .session .executor .ExecuteResult ] = None
91105
92106
107+ @dataclasses .dataclass (frozen = True )
108+ class UnknownErrorEvent (Event ):
109+ exc_type : Any
110+ exc_value : Any
111+ traceback : Any
112+
113+
93114@dataclasses .dataclass (frozen = True )
94115class BigQuerySentEvent (ExecutionRunning ):
95116 """Query sent to BigQuery."""
0 commit comments