diff --git a/dash_extensions/enrich.py b/dash_extensions/enrich.py index 46c98e7..62af214 100644 --- a/dash_extensions/enrich.py +++ b/dash_extensions/enrich.py @@ -9,8 +9,7 @@ import dash_html_components as html import dash.dependencies as dd import plotly - -from dash.dependencies import Input, Output, State, MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction +from dash.dependencies import MATCH, ALL, ALLSMALLER, _Wildcard, ClientsideFunction from dash.development.base_component import Component from flask import session from flask_caching.backends import FileSystemCache, RedisCache @@ -29,9 +28,11 @@ class DashProxy(dash.Dash): def __init__(self, *args, transforms=None, **kwargs): super().__init__(*args, **kwargs) self.callbacks = [] + self.reactive_variables = [] self.clientside_callbacks = [] self.arg_types = [dd.Output, dd.Input, dd.State] self.transforms = transforms if transforms is not None else [] + self.layout_extension = LayoutExtension() # Do the transform initialization. for transform in self.transforms: transform.init(self) @@ -84,6 +85,38 @@ def clientside_callback(self, clientside_function, *args, **kwargs): callback["f"] = clientside_function self.clientside_callbacks.append(callback) + def _collect_reactive(self, name): + self.reactive_variables.append(name) + self.layout_extension.components.append(dcc.Store(name, "data")) + + def reactive(self, *args, serverside=None, output=None, **kwargs): + # If the output is not specified, create it. Per default, use serverside if available. + serverside = True if serverside is None else serverside + if output is None: + if serverside and any([isinstance(t, ServersideOutputTransform) for t in self.transforms]): + output = ServersideOutput(None, None) + else: + output = Output(None) + # Collect the callback, delay binding of output id. + callback = self._collect_callback(output, *args, **kwargs) + self.callbacks.append(callback) + + def wrapper(f): + component_id = f.__name__ + output.component_id = component_id + output.component_property = "data" + self._collect_reactive(component_id) + callback["f"] = f + + return wrapper + + def clientside_reactive(self, name, clientside_function, *args, **kwargs): + output = Output(name, "data") + self._collect_reactive(name) + callback = self._collect_callback(output, *args, **kwargs) + callback["f"] = clientside_function + self.clientside_callbacks.append(callback) + def _register_callbacks(self, app=None): callbacks, clientside_callbacks = self._resolve_callbacks() app = super() if app is None else app @@ -98,7 +131,7 @@ def _layout_value(self): layout = self._layout() if self._layout_is_function else self._layout for transform in self.transforms: layout = transform.layout(layout, self._layout_is_function) - return layout + return self.layout_extension.layout(layout, self._layout_is_function) def _setup_server(self): """ @@ -112,11 +145,25 @@ def _setup_server(self): if not self.server.secret_key: self.server.secret_key = secrets.token_urlsafe(16) + def _resolve_reactive_variables(self, callbacks): + for callback in callbacks: + for item in callback[dd.Input]: + if item.component_id in self.reactive_variables: + item.component_property = "data" + for item in callback[dd.State]: + if item.component_id in self.reactive_variables: + item.component_property = "data" + return callbacks + def _resolve_callbacks(self): """ This method resolves the callbacks, i.e. it applies the callback injections. """ callbacks, clientside_callbacks = self.callbacks, self.clientside_callbacks + # Resolve reactive variables. + callbacks = self._resolve_reactive_variables(callbacks) + clientside_callbacks = self._resolve_reactive_variables(clientside_callbacks) + # Apply transforms. for transform in self.transforms: callbacks, clientside_callbacks = transform.apply(callbacks, clientside_callbacks) return callbacks, clientside_callbacks @@ -193,6 +240,39 @@ def layout(self, layout, layout_is_function): return layout + +class LayoutExtension: + + def __init__(self): + self.initialized = False + self.components = [] + + def layout(self, layout, layout_is_function): + if layout_is_function or not self.initialized: + children = _as_list(layout.children) + self.components + layout.children = children + self.initialized = True + return layout + +# endregion + +# region Default component property values + +class Input(dd.Input): + def __init__(self, component_id, component_property=None): + component_property = "value" if component_property is None else component_property + super().__init__(component_id, component_property) + +class State(dd.State): + def __init__(self, component_id, component_property=None): + component_property = "value" if component_property is None else component_property + super().__init__(component_id, component_property) + +class Output(dd.Output): + def __init__(self, component_id, component_property=None): + component_property = "children" if component_property is None else component_property + super().__init__(component_id, component_property) + # endregion # region Prefix ID transform @@ -398,7 +478,7 @@ def apply(self, callbacks, clientside_callbacks): # Group by output. output_map = defaultdict(list) for callback in all_callbacks: - for output in callback[Output]: + for output in callback[dd.Output]: output_map[output].append(callback) # Apply multiplexer where needed. for output in output_map: @@ -418,7 +498,7 @@ def _apply_multiplexer(self, output, callbacks): # Create proxy element. proxies.append(_mp_element(mp_id_escaped)) # Assign proxy element as output. - callback[Output][callback[Output].index(output)] = Output(mp_id_escaped, _mp_prop()) + callback[dd.Output][callback[dd.Output].index(output)] = Output(mp_id_escaped, _mp_prop()) # Create proxy input. inputs.append(Input(mp_id, _mp_prop())) # Collect proxy elements to add to layout. @@ -548,7 +628,7 @@ def decorated_function(*args): # Figure out if an update is necessary. unique_ids = [] update_needed = False - for i, output in enumerate(callback[Output]): + for i, output in enumerate(callback[dd.Output]): # Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere). is_trigger = trigger_filter(callback["sorted_args"]) filtered_args = [arg for i, arg in enumerate(args) if not is_trigger[i]] @@ -560,20 +640,20 @@ def decorated_function(*args): break # If not update is needed, just return the ids (or values, if not serverside output). if not update_needed: - results = [uid if isinstance(callback[Output][i], ServersideOutput) else - callback[Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)] + results = [uid if isinstance(callback[dd.Output][i], ServersideOutput) else + callback[dd.Output][i].backend.get(uid) for i, uid in enumerate(unique_ids)] return results if multi_output else results[0] # Do the update. data = f(*args) data = list(data) if multi_output else [data] if callable(memoize): data = memoize(data) - for i, output in enumerate(callback[Output]): + for i, output in enumerate(callback[dd.Output]): # Skip no_update updates. if isinstance(data[i], type(dash.no_update)): continue # Replace only for server side outputs. - serverside_output = isinstance(callback[Output][i], ServersideOutput) + serverside_output = isinstance(callback[dd.Output][i], ServersideOutput) if serverside_output or memoize: # Filter out Triggers (a little ugly to do here, should ideally be handled elsewhere). is_trigger = trigger_filter(callback["sorted_args"]) @@ -662,15 +742,10 @@ def get(self, key, ignore_expired=False): class NoOutputTransform(DashTransform): def __init__(self): - self.initialized = False - self.hidden_divs = [] + self.layout_extension = LayoutExtension() def layout(self, layout, layout_is_function): - if layout_is_function or not self.initialized: - children = _as_list(layout.children) + self.hidden_divs - layout.children = children - self.initialized = True - return layout + return self.layout_extension.layout(layout, layout_is_function) def _apply(self, callbacks): for callback in callbacks: @@ -678,7 +753,7 @@ def _apply(self, callbacks): output_id = _get_output_id(callback) hidden_div = html.Div(id=output_id, style={"display": "none"}) callback[dd.Output] = [dd.Output(output_id, "children")] - self.hidden_divs.append(hidden_div) + self.layout_extension.components.append(hidden_div) return callbacks def apply_serverside(self, callbacks): diff --git a/reactive_example.py b/reactive_example.py new file mode 100644 index 0000000..b9efad4 --- /dev/null +++ b/reactive_example.py @@ -0,0 +1,26 @@ +import dash_core_components as dcc +import dash_html_components as html +import plotly.graph_objects as go +from dash_extensions.enrich import Dash, Output, Input + +app = Dash() +app.layout = html.Div([dcc.Input(value=1, id='x', type='number'), + dcc.Input(value=1, id='power', type='number'), + html.Div(id='result'), dcc.Graph(id='graph')]) + +@app.reactive(Input('x'), Input('power')) +def z(x, y): + return x ** y if (x and y) else None + +#app.clientside_reactive("z", "function(x,y){return x**y}", Input('x'), Input('power')) ?? + +@app.callback(Output('result'), Input('x'), Input('power'), Input('z')) +def display_result(x, y, z): + return f"{x}^{y} is {z}" + +@app.callback(Output('graph', 'figure'), Input('x'), Input('power'), Input('z')) +def plot_result(x, y, z): + return go.Figure([go.Bar(x=['x', 'y', 'x**y'], y=[x, y, z])]) + +if __name__ == "__main__": + app.run_server() \ No newline at end of file diff --git a/reactive_example2.py b/reactive_example2.py new file mode 100644 index 0000000..56e93fe --- /dev/null +++ b/reactive_example2.py @@ -0,0 +1,30 @@ +import dash_core_components as dcc +import dash_html_components as html +import plotly.express as px +import dash_table +from dash_extensions.enrich import DashProxy, Output, Input, ServersideOutputTransform + +# Read the full, complex dataset here (for the sake of simplicify, a small px dataset is used). ?? +df_all = px.data.gapminder() +# Example app demonstrating how to share state between callbacks via a reactive variable. +app = DashProxy(transforms=[ServersideOutputTransform()]) +app.layout = html.Div([ + dcc.Dropdown(options=[dict(value=x, label=x) for x in df_all.country.unique()], id="country", value="Denmark"), + dcc.Graph(id='graph'), + dash_table.DataTable(id='table', columns=[{"name": i, "id": i} for i in df_all.columns]) +]) + +@app.reactive(Input('country')) # default prop for input/state is "value" +def df_filtered(country): # reactive variable name = function name + return df_all[df_all.country == country] # defaults to serverside output, i.e. json serialization is not needed + +@app.callback(Output('table', 'data'), Input('df_filtered')) # access reactive variable via it's ID +def update_table(df): + return df.to_dict('records') # the reactive variable was stored serverside, i.e. deserialize is not needed + +@app.callback(Output('graph', 'figure'), Input('df_filtered')) +def update_graph(df): + return px.bar(df, x='year', y='pop', color='gdpPercap') + +if __name__ == "__main__": + app.run_server() \ No newline at end of file