From b2119f2214b3f61a8a06c78d0da596e88b56b4a7 Mon Sep 17 00:00:00 2001 From: emher Date: Thu, 8 Jul 2021 14:48:41 +0200 Subject: [PATCH 1/2] First take at reactive implementation --- dash_extensions/enrich.py | 46 ++++++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/dash_extensions/enrich.py b/dash_extensions/enrich.py index 46c98e7..07c900a 100644 --- a/dash_extensions/enrich.py +++ b/dash_extensions/enrich.py @@ -9,7 +9,6 @@ import dash_html_components as html import dash.dependencies as dd import plotly - from dash.dependencies import Input, Output, State, MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction from dash.development.base_component import Component from flask import session @@ -29,9 +28,11 @@ class DashProxy(dash.Dash): def __init__(self, *args, transforms=None, **kwargs): super().__init__(*args, **kwargs) self.callbacks = [] + self.reactive_variables = [] self.clientside_callbacks = [] self.arg_types = [dd.Output, dd.Input, dd.State] self.transforms = transforms if transforms is not None else [] + self.layout_extension = LayoutExtension() # Do the transform initialization. for transform in self.transforms: transform.init(self) @@ -84,6 +85,23 @@ def clientside_callback(self, clientside_function, *args, **kwargs): callback["f"] = clientside_function self.clientside_callbacks.append(callback) + def reactive(self, *args, **kwargs): + """ + This method saves the callbacks on the DashTransformer object. It acts as a proxy for the Dash app callback. + """ + output = Output(None, "data") + callback = self._collect_callback(output, *args, **kwargs) + self.callbacks.append(callback) + + def wrapper(f): + component_id = f.__name__ + output.component_id = component_id + self.reactive_variables.append(component_id) + self.layout_extension.components.append(dcc.Store(component_id, "data")) + callback["f"] = f + + return wrapper + def _register_callbacks(self, app=None): callbacks, clientside_callbacks = self._resolve_callbacks() app = super() if app is None else app @@ -98,7 +116,7 @@ def _layout_value(self): layout = self._layout() if self._layout_is_function else self._layout for transform in self.transforms: layout = transform.layout(layout, self._layout_is_function) - return layout + return self.layout_extension.layout(layout, self._layout_is_function) def _setup_server(self): """ @@ -193,6 +211,19 @@ def layout(self, layout, layout_is_function): return layout +class LayoutExtension: + + def __init__(self): + self.initialized = False + self.components = [] + + def layout(self, layout, layout_is_function): + if layout_is_function or not self.initialized: + children = _as_list(layout.children) + self.components + layout.children = children + self.initialized = True + return layout + # endregion # region Prefix ID transform @@ -662,15 +693,10 @@ def get(self, key, ignore_expired=False): class NoOutputTransform(DashTransform): def __init__(self): - self.initialized = False - self.hidden_divs = [] + self.layout_extension = LayoutExtension() def layout(self, layout, layout_is_function): - if layout_is_function or not self.initialized: - children = _as_list(layout.children) + self.hidden_divs - layout.children = children - self.initialized = True - return layout + return self.layout_extension.layout(layout, layout_is_function) def _apply(self, callbacks): for callback in callbacks: @@ -678,7 +704,7 @@ def _apply(self, callbacks): output_id = _get_output_id(callback) hidden_div = html.Div(id=output_id, style={"display": "none"}) callback[dd.Output] = [dd.Output(output_id, "children")] - self.hidden_divs.append(hidden_div) + self.layout_extension.components.append(hidden_div) return callbacks def apply_serverside(self, callbacks): From aabfe05609b0900249229f34a27fe4eb0ade5793 Mon Sep 17 00:00:00 2001 From: emher Date: Sat, 7 Aug 2021 19:08:40 +0200 Subject: [PATCH 2/2] update --- dash_extensions/enrich.py | 81 +++++++++++++++++++++++++++++++-------- reactive_example.py | 26 +++++++++++++ reactive_example2.py | 30 +++++++++++++++ 3 files changed, 121 insertions(+), 16 deletions(-) create mode 100644 reactive_example.py create mode 100644 reactive_example2.py diff --git a/dash_extensions/enrich.py b/dash_extensions/enrich.py index 07c900a..62af214 100644 --- a/dash_extensions/enrich.py +++ b/dash_extensions/enrich.py @@ -9,7 +9,7 @@ import dash_html_components as html import dash.dependencies as dd import plotly -from dash.dependencies import Input, Output, State, MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction +from dash.dependencies import MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction from dash.development.base_component import Component from flask import session from flask_caching.backends import FileSystemCache, RedisCache @@ -85,23 +85,38 @@ def clientside_callback(self, clientside_function, *args, **kwargs): callback["f"] = clientside_function self.clientside_callbacks.append(callback) - def reactive(self, *args, **kwargs): - """ - This method saves the callbacks on the DashTransformer object. It acts as a proxy for the Dash app callback. - """ - output = Output(None, "data") + def _collect_reactive(self, name): + self.reactive_variables.append(name) + self.layout_extension.components.append(dcc.Store(name, "data")) + + def reactive(self, *args, serverside=None, output=None, **kwargs): + # If the output is not specified, create it. Per default, use serverside if available. + serverside = True if serverside is None else serverside + if output is None: + if serverside and any([isinstance(t, ServersideOutputTransform) for t in self.transforms]): + output = ServersideOutput(None, None) + else: + output = Output(None) + # Collect the callback, delay binding of output id. callback = self._collect_callback(output, *args, **kwargs) self.callbacks.append(callback) def wrapper(f): component_id = f.__name__ output.component_id = component_id - self.reactive_variables.append(component_id) - self.layout_extension.components.append(dcc.Store(component_id, "data")) + output.component_property = "data" + self._collect_reactive(component_id) callback["f"] = f return wrapper - + + def clientside_reactive(self, name, clientside_function, *args, **kwargs): + output = Output(name, "data") + self._collect_reactive(name) + callback = self._collect_callback(output, *args, **kwargs) + callback["f"] = clientside_function + self.clientside_callbacks.append(callback) + def _register_callbacks(self, app=None): callbacks, clientside_callbacks = self._resolve_callbacks() app = super() if app is None else app @@ -130,11 +145,25 @@ def _setup_server(self): if not self.server.secret_key: self.server.secret_key = secrets.token_urlsafe(16) + def _resolve_reactive_variables(self, callbacks): + for callback in callbacks: + for item in callback[dd.Input]: + if item.component_id in self.reactive_variables: + item.component_property = "data" + for item in callback[dd.State]: + if item.component_id in self.reactive_variables: + item.component_property = "data" + return callbacks + def _resolve_callbacks(self): """ This method resolves the callbacks, i.e. it applies the callback injections. """ callbacks, clientside_callbacks = self.callbacks, self.clientside_callbacks + # Resolve reactive variables. + callbacks = self._resolve_reactive_variables(callbacks) + clientside_callbacks = self._resolve_reactive_variables(clientside_callbacks) + # Apply transforms. for transform in self.transforms: callbacks, clientside_callbacks = transform.apply(callbacks, clientside_callbacks) return callbacks, clientside_callbacks @@ -211,6 +240,7 @@ def layout(self, layout, layout_is_function): return layout + class LayoutExtension: def __init__(self): @@ -226,6 +256,25 @@ def layout(self, layout, layout_is_function): # endregion +# region Default component property values + +class Input(dd.Input): + def __init__(self, component_id, component_property=None): + component_property = "value" if component_property is None else component_property + super().__init__(component_id, component_property) + +class State(dd.State): + def __init__(self, component_id, component_property=None): + component_property = "value" if component_property is None else component_property + super().__init__(component_id, component_property) + +class Output(dd.Output): + def __init__(self, component_id, component_property=None): + component_property = "children" if component_property is None else component_property + super().__init__(component_id, component_property) + +# endregion + # region Prefix ID transform class PrefixIdTransform(DashTransform): @@ -429,7 +478,7 @@ def apply(self, callbacks, clientside_callbacks): # Group by output. output_map = defaultdict(list) for callback in all_callbacks: - for output in callback[Output]: + for output in callback[dd.Output]: output_map[output].append(callback) # Apply multiplexer where needed. for output in output_map: @@ -449,7 +498,7 @@ def _apply_multiplexer(self, output, callbacks): # Create proxy element. proxies.append(_mp_element(mp_id_escaped)) # Assign proxy element as output. - callback[Output][callback[Output].index(output)] = Output(mp_id_escaped, _mp_prop()) + callback[dd.Output][callback[dd.Output].index(output)] = Output(mp_id_escaped, _mp_prop()) # Create proxy input. inputs.append(Input(mp_id, _mp_prop())) # Collect proxy elements to add to layout. @@ -579,7 +628,7 @@ def decorated_function(*args): # Figure out if an update is necessary. unique_ids = [] update_needed = False - for i, output in enumerate(callback[Output]): + for i, output in enumerate(callback[dd.Output]): # Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere). is_trigger = trigger_filter(callback["sorted_args"]) filtered_args = [arg for i, arg in enumerate(args) if not is_trigger[i]] @@ -591,20 +640,20 @@ def decorated_function(*args): break # If not update is needed, just return the ids (or values, if not serverside output). if not update_needed: - results = [uid if isinstance(callback[Output][i], ServersideOutput) else - callback[Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)] + results = [uid if isinstance(callback[dd.Output][i], ServersideOutput) else + callback[dd.Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)] return results if multi_output else results[0] # Do the update. data = f(*args) data = list(data) if multi_output else [data] if callable(memoize): data = memoize(data) - for i, output in enumerate(callback[Output]): + for i, output in enumerate(callback[dd.Output]): # Skip no_update updates. if isinstance(data[i], type(dash.no_update)): continue # Replace only for server side outputs. - serverside_output = isinstance(callback[Output][i], ServersideOutput) + serverside_output = isinstance(callback[dd.Output][i], ServersideOutput) if serverside_output or memoize: # Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere). is_trigger = trigger_filter(callback["sorted_args"]) diff --git a/reactive_example.py b/reactive_example.py new file mode 100644 index 0000000..b9efad4 --- /dev/null +++ b/reactive_example.py @@ -0,0 +1,26 @@ +import dash_core_components as dcc +import dash_html_components as html +import plotly.graph_objects as go +from dash_extensions.enrich import Dash, Output, Input + +app = Dash() +app.layout = html.Div([dcc.Input(value=1, id='x', type='number'), + dcc.Input(value=1, id='power', type='number'), + html.Div(id='result'), dcc.Graph(id='graph')]) + +@app.reactive(Input('x'), Input('power')) +def z(x, y): + return x ** y if (x and y) else None + +#app.clientside_reactive("z", "function(x,y){return x**y}", Input('x'), Input('power')) ?? + +@app.callback(Output('result'), Input('x'), Input('power'), Input('z')) +def display_result(x, y, z): + return f"{x}^{y} is {z}" + +@app.callback(Output('graph', 'figure'), Input('x'), Input('power'), Input('z')) +def plot_result(x, y, z): + return go.Figure([go.Bar(x=['x', 'y', 'x**y'], y=[x, y, z])]) + +if __name__ == "__main__": + app.run_server() \ No newline at end of file diff --git a/reactive_example2.py b/reactive_example2.py new file mode 100644 index 0000000..56e93fe --- /dev/null +++ b/reactive_example2.py @@ -0,0 +1,30 @@ +import dash_core_components as dcc +import dash_html_components as html +import plotly.express as px +import dash_table +from dash_extensions.enrich import DashProxy, Output, Input, ServersideOutputTransform + +# Read the full, complex dataset here (for the sake of simplicify, a small px dataset is used). ?? +df_all = px.data.gapminder() +# Example app demonstrating how to share state between callbacks via a reactive variable. +app = DashProxy(transforms=[ServersideOutputTransform()]) +app.layout = html.Div([ + dcc.Dropdown(options=[dict(value=x, label=x) for x in df_all.country.unique()], id="country", value="Denmark"), + dcc.Graph(id='graph'), + dash_table.DataTable(id='table', columns=[{"name": i, "id": i} for i in df_all.columns]) +]) + +@app.reactive(Input('country')) # default prop for input/state is "value" +def df_filtered(country): # reactive variable name = function name + return df_all[df_all.country == country] # defaults to serverside output, i.e. json serialization is not needed + +@app.callback(Output('table', 'data'), Input('df_filtered')) # access reactive variable via it's ID +def update_table(df): + return df.to_dict('records') # the reactive variable was stored serverside, i.e. deserialize is not needed + +@app.callback(Output('graph', 'figure'), Input('df_filtered')) +def update_graph(df): + return px.bar(df, x='year', y='pop', color='gdpPercap') + +if __name__ == "__main__": + app.run_server() \ No newline at end of file