diff --git a/dash/_callback.py b/dash/_callback.py index 071c209dec..80557550b9 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -3,6 +3,7 @@ from functools import wraps from typing import Callable, Optional, Any +import asyncio import flask from .dependencies import ( @@ -39,6 +40,16 @@ from ._callback_context import context_value +async def _async_invoke_callback( + func, *args, **kwargs +): # used to mark the frame for the debugger + # Check if the function is a coroutine function + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) # %% callback invoked %% + # If the function is not a coroutine, call it directly + return func(*args, **kwargs) # %% callback invoked %% + + def _invoke_callback(func, *args, **kwargs): # used to mark the frame for the debugger return func(*args, **kwargs) # %% callback invoked %% @@ -353,6 +364,230 @@ def wrap_func(func): ) @wraps(func) + async def async_add_context(*args, **kwargs): + output_spec = kwargs.pop("outputs_list") + app_callback_manager = kwargs.pop("long_callback_manager", None) + + callback_ctx = kwargs.pop( + "callback_context", AttributeDict({"updated_props": {}}) + ) + app = kwargs.pop("app", None) + callback_manager = long and long.get("manager", app_callback_manager) + error_handler = on_error or kwargs.pop("app_on_error", None) + original_packages = set(ComponentRegistry.registry) + + if has_output: + _validate.validate_output_spec(insert_output, output_spec, Output) + + context_value.set(callback_ctx) + + func_args, func_kwargs = _validate.validate_and_group_input_args( + args, inputs_state_indices + ) + + response: dict = {"multi": True} + has_update = False + + if long is not None: + if not callback_manager: + raise MissingLongCallbackManagerError( + "Running `long` callbacks requires a manager to be installed.\n" + "Available managers:\n" + "- Diskcache (`pip install dash[diskcache]`) to run callbacks in a separate Process" + " and store results on the local filesystem.\n" + "- Celery (`pip install dash[celery]`) to run callbacks in a celery worker" + " and store results on redis.\n" + ) + + progress_outputs = long.get("progress") + cache_key = flask.request.args.get("cacheKey") + job_id = flask.request.args.get("job") + old_job = flask.request.args.getlist("oldJob") + + current_key = callback_manager.build_cache_key( + func, + # Inputs provided as dict is kwargs. + func_args if func_args else func_kwargs, + long.get("cache_args_to_ignore", []), + ) + + if old_job: + for job in old_job: + callback_manager.terminate_job(job) + + if not cache_key: + cache_key = current_key + + job_fn = callback_manager.func_registry.get(long_key) + + ctx_value = AttributeDict(**context_value.get()) + ctx_value.ignore_register_page = True + ctx_value.pop("background_callback_manager") + ctx_value.pop("dash_response") + + job = callback_manager.call_job_fn( + cache_key, + job_fn, + func_args if func_args else func_kwargs, + ctx_value, + ) + + data = { + "cacheKey": cache_key, + "job": job, + } + + cancel = long.get("cancel") + if cancel: + data["cancel"] = cancel + + progress_default = long.get("progressDefault") + if progress_default: + data["progressDefault"] = { + str(o): x + for o, x in zip(progress_outputs, progress_default) + } + return to_json(data) + if progress_outputs: + # Get the progress before the result as it would be erased after the results. + progress = callback_manager.get_progress(cache_key) + if progress: + response["progress"] = { + str(x): progress[i] for i, x in enumerate(progress_outputs) + } + + output_value = callback_manager.get_result(cache_key, job_id) + # Must get job_running after get_result since get_results terminates it. + job_running = callback_manager.job_running(job_id) + if not job_running and output_value is callback_manager.UNDEFINED: + # Job canceled -> no output to close the loop. + output_value = NoUpdate() + + elif ( + isinstance(output_value, dict) + and "long_callback_error" in output_value + ): + error = output_value.get("long_callback_error", {}) + exc = LongCallbackError( + f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}" + ) + if error_handler: + output_value = error_handler(exc) + + if output_value is None: + output_value = NoUpdate() + # set_props from the error handler uses the original ctx + # instead of manager.get_updated_props since it runs in the + # request process. + has_update = ( + _set_side_update(callback_ctx, response) + or output_value is not None + ) + else: + raise exc + + if job_running and output_value is not callback_manager.UNDEFINED: + # cached results. + callback_manager.terminate_job(job_id) + + if multi and isinstance(output_value, (list, tuple)): + output_value = [ + NoUpdate() if NoUpdate.is_no_update(r) else r + for r in output_value + ] + updated_props = callback_manager.get_updated_props(cache_key) + if len(updated_props) > 0: + response["sideUpdate"] = updated_props + has_update = True + + if output_value is callback_manager.UNDEFINED: + return to_json(response) + else: + try: + output_value = await _async_invoke_callback( + func, *func_args, **func_kwargs + ) + except PreventUpdate as err: + raise err + except Exception as err: # pylint: disable=broad-exception-caught + if error_handler: + output_value = error_handler(err) + + # If the error returns nothing, automatically puts NoUpdate for response. + if output_value is None and has_output: + output_value = NoUpdate() + else: + raise err + + component_ids = collections.defaultdict(dict) + + if has_output: + if not multi: + output_value, output_spec = [output_value], [output_spec] + flat_output_values = output_value + else: + if isinstance(output_value, (list, tuple)): + # For multi-output, allow top-level collection to be + # list or tuple + output_value = list(output_value) + + if NoUpdate.is_no_update(output_value): + flat_output_values = [output_value] + else: + # Flatten grouping and validate grouping structure + flat_output_values = flatten_grouping(output_value, output) + + if not NoUpdate.is_no_update(output_value): + _validate.validate_multi_return( + output_spec, flat_output_values, callback_id + ) + + for val, spec in zip(flat_output_values, output_spec): + if NoUpdate.is_no_update(val): + continue + for vali, speci in ( + zip(val, spec) if isinstance(spec, list) else [[val, spec]] + ): + if not NoUpdate.is_no_update(vali): + has_update = True + id_str = stringify_id(speci["id"]) + prop = clean_property_name(speci["property"]) + component_ids[id_str][prop] = vali + else: + if output_value is not None: + raise InvalidCallbackReturnValue( + f"No-output callback received return value: {output_value}" + ) + output_value = [] + flat_output_values = [] + + if not long: + has_update = _set_side_update(callback_ctx, response) or has_update + + if not has_update: + raise PreventUpdate + + response["response"] = component_ids + + if len(ComponentRegistry.registry) != len(original_packages): + diff_packages = list( + set(ComponentRegistry.registry).difference(original_packages) + ) + if not allow_dynamic_callbacks: + raise ImportedInsideCallbackError( + f"Component librar{'y' if len(diff_packages) == 1 else 'ies'} was imported during callback.\n" + "You can set `_allow_dynamic_callbacks` to allow for development purpose only." + ) + dist = app.get_dist(diff_packages) + response["dist"] = dist + + try: + jsonResponse = to_json(response) + except TypeError: + _validate.fail_callback_output(output_value, output) + + return jsonResponse + def add_context(*args, **kwargs): output_spec = kwargs.pop("outputs_list") app_callback_manager = kwargs.pop("long_callback_manager", None) @@ -575,7 +810,10 @@ def add_context(*args, **kwargs): return jsonResponse - callback_map[callback_id]["callback"] = add_context + if asyncio.iscoroutinefunction(func): + callback_map[callback_id]["callback"] = async_add_context + else: + callback_map[callback_id]["callback"] = add_context return func diff --git a/dash/dash.py b/dash/dash.py index 40f65dff5f..a942a6d3fb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -18,6 +18,7 @@ from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, List +import asyncio import flask from importlib_metadata import version as _get_distribution_version @@ -149,6 +150,7 @@ def _get_traceback(secret, error: Exception): def _get_skip(error): from dash._callback import ( # pylint: disable=import-outside-toplevel _invoke_callback, + _async_invoke_callback, ) tb = error.__traceback__ @@ -156,7 +158,10 @@ def _get_skip(error): while tb.tb_next is not None: skip += 1 tb = tb.tb_next - if tb.tb_frame.f_code is _invoke_callback.__code__: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: return skip return skip @@ -164,11 +169,15 @@ def _get_skip(error): def _do_skip(error): from dash._callback import ( # pylint: disable=import-outside-toplevel _invoke_callback, + _async_invoke_callback, ) tb = error.__traceback__ while tb.tb_next is not None: - if tb.tb_frame.f_code is _invoke_callback.__code__: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: return tb.tb_next tb = tb.tb_next return error.__traceback__ @@ -192,6 +201,14 @@ def _do_skip(error): no_update = _callback.NoUpdate() # pylint: disable=protected-access +async def execute_async_function(func, *args, **kwargs): + # Check if the function is a coroutine function + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + # If the function is not a coroutine, call it directly + return func(*args, **kwargs) + + # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-arguments, too-many-locals class Dash: @@ -375,6 +392,10 @@ class Dash: an exception is raised. Receives the exception object as first argument. The callback_context can be used to access the original callback inputs, states and output. + + :param use_async: When True, the app will create async endpoints, as a dev, + they will be responsible for installing the `flask[async]` dependency. + :type use_async: boolean """ _plotlyjs_url: str @@ -422,8 +443,25 @@ def __init__( # pylint: disable=too-many-statements routing_callback_inputs: Optional[Dict[str, Union[Input, State]]] = None, description: Optional[str] = None, on_error: Optional[Callable[[Exception], Any]] = None, + use_async: Optional[bool] = None, **obsolete, ): + + if use_async is None: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + + use_async = True + except ImportError: + pass + elif use_async: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + except ImportError as exc: + raise Exception( + "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" + ) from exc + _validate.check_obsolete(obsolete) caller_name = None if name else get_caller_name() @@ -535,6 +573,7 @@ def __init__( # pylint: disable=too-many-statements self.validation_layout = None self._on_error = on_error self._extra_components = [] + self._use_async = use_async self._setup_dev_tools() self._hot_reload = AttributeDict( @@ -672,7 +711,10 @@ def _setup_routes(self): ) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - self._add_url("_dash-update-component", self.dispatch, ["POST"]) + if self._use_async: + self._add_url("_dash-update-component", self.async_dispatch, ["POST"]) + else: + self._add_url("_dash-update-component", self.dispatch, ["POST"]) self._add_url("_reload-hash", self.serve_reload_hash) self._add_url("_favicon.ico", self._serve_default_favicon) self._add_url("", self.index) @@ -1267,7 +1309,7 @@ def long_callback( ) # pylint: disable=R0915 - def dispatch(self): + async def async_dispatch(self): body = flask.request.get_json() g = AttributeDict({}) @@ -1371,20 +1413,154 @@ def dispatch(self): raise KeyError(msg) from missing_callback_function ctx = copy_context() + # Create a partial function with the necessary arguments # noinspection PyArgumentList - response.set_data( - ctx.run( - functools.partial( - func, - *args, - outputs_list=outputs_list, - long_callback_manager=self._background_manager, - callback_context=g, - app=self, - app_on_error=self._on_error, + partial_func = functools.partial( + execute_async_function, + func, + *args, + outputs_list=outputs_list, + long_callback_manager=self._background_manager, + callback_context=g, + app=self, + app_on_error=self._on_error, + app_use_async=self._use_async, + ) + + response_data = await ctx.run(partial_func) + + # Check if the response is a coroutine + if asyncio.iscoroutine(response_data): + response_data = await response_data + + response.set_data(response_data) + return response + + def dispatch(self): + body = flask.request.get_json() + + g = AttributeDict({}) + + g.inputs_list = inputs = body.get( # pylint: disable=assigning-non-slot + "inputs", [] + ) + g.states_list = state = body.get( # pylint: disable=assigning-non-slot + "state", [] + ) + output = body["output"] + outputs_list = body.get("outputs") + g.outputs_list = outputs_list # pylint: disable=assigning-non-slot + + g.input_values = ( # pylint: disable=assigning-non-slot + input_values + ) = inputs_to_dict(inputs) + g.state_values = inputs_to_dict(state) # pylint: disable=assigning-non-slot + changed_props = body.get("changedPropIds", []) + g.triggered_inputs = [ # pylint: disable=assigning-non-slot + {"prop_id": x, "value": input_values.get(x)} for x in changed_props + ] + + response = ( + g.dash_response # pylint: disable=assigning-non-slot + ) = flask.Response(mimetype="application/json") + + args = inputs_to_vals(inputs + state) + + try: + cb = self.callback_map[output] + func = cb["callback"] + g.background_callback_manager = ( + cb.get("manager") or self._background_manager + ) + g.ignore_register_page = cb.get("long", False) + + # Add args_grouping + inputs_state_indices = cb["inputs_state_indices"] + inputs_state = inputs + state + inputs_state = convert_to_AttributeDict(inputs_state) + + if cb.get("no_output"): + outputs_list = [] + elif not outputs_list: + # FIXME Old renderer support? + split_callback_id(output) + + # update args_grouping attributes + for s in inputs_state: + # check for pattern matching: list of inputs or state + if isinstance(s, list): + for pattern_match_g in s: + update_args_group(pattern_match_g, changed_props) + update_args_group(s, changed_props) + + args_grouping = map_grouping( + lambda ind: inputs_state[ind], inputs_state_indices + ) + + g.args_grouping = args_grouping # pylint: disable=assigning-non-slot + g.using_args_grouping = ( # pylint: disable=assigning-non-slot + not isinstance(inputs_state_indices, int) + and ( + inputs_state_indices + != list(range(grouping_len(inputs_state_indices))) ) ) + + # Add outputs_grouping + outputs_indices = cb["outputs_indices"] + if not isinstance(outputs_list, list): + flat_outputs = [outputs_list] + else: + flat_outputs = outputs_list + + if len(flat_outputs) > 0: + outputs_grouping = map_grouping( + lambda ind: flat_outputs[ind], outputs_indices + ) + g.outputs_grouping = ( + outputs_grouping # pylint: disable=assigning-non-slot + ) + g.using_outputs_grouping = ( # pylint: disable=assigning-non-slot + not isinstance(outputs_indices, int) + and outputs_indices != list(range(grouping_len(outputs_indices))) + ) + else: + g.outputs_grouping = [] + g.using_outputs_grouping = [] + g.updated_props = {} + + g.cookies = dict(**flask.request.cookies) + g.headers = dict(**flask.request.headers) + g.path = flask.request.full_path + g.remote = flask.request.remote_addr + g.origin = flask.request.origin + + except KeyError as missing_callback_function: + msg = f"Callback function not found for output '{output}', perhaps you forgot to prepend the '@'?" + raise KeyError(msg) from missing_callback_function + + ctx = copy_context() + # Create a partial function with the necessary arguments + # noinspection PyArgumentList + partial_func = functools.partial( + func, + *args, + outputs_list=outputs_list, + long_callback_manager=self._background_manager, + callback_context=g, + app=self, + app_on_error=self._on_error, + app_use_async=self._use_async, ) + + response_data = ctx.run(partial_func) + + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async], please install the dependencies via `pip install dash[async]` and make sure you arent passing `use_async=False` to the app." + ) + + response.set_data(response_data) return response def _setup_server(self): @@ -2227,65 +2403,133 @@ def router(): } inputs.update(self.routing_callback_inputs) - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - ) - def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ + if self._use_async: - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, ) + async def update(pathname_, search_, **states): + """ + Updates dash.page_container layout on page navigation. + Updates the stored page title which will trigger the clientside callback to update the app title + """ + + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) + ) - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break + # get layout + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break + else: + layout = html.H1("404 - Page not found") + title = self.title else: - layout = html.H1("404 - Page not found") - title = self.title - else: - layout = page.get("layout", "") - title = page["title"] - - if callable(layout): - layout = ( - layout(**path_variables, **query_parameters, **states) - if path_variables - else layout(**query_parameters, **states) - ) - if callable(title): - title = title(**path_variables) if path_variables else title() + layout = page.get("layout", "") + title = page["title"] - return layout, {"title": title} + if callable(layout): + layout = await execute_async_function( + layout, + **{**(path_variables or {}), **query_parameters, **states}, + ) + if callable(title): + title = await execute_async_function( + title, **(path_variables or {}) + ) - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) + return layout, {"title": title} + + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) + + # Set validation_layout + if not self.config.suppress_callback_exceptions: + self.validation_layout = html.Div( + [ + asyncio.run(execute_async_function(page["layout"])) + if callable(page["layout"]) + else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + + [ + # pylint: disable=not-callable + self.layout() + if callable(self.layout) + else self.layout + ] + ) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") + else: - # Set validation_layout - if not self.config.suppress_callback_exceptions: - self.validation_layout = html.Div( - [ - page["layout"]() if callable(page["layout"]) else page["layout"] - for page in _pages.PAGE_REGISTRY.values() - ] - + [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, ) + def update(pathname_, search_, **states): + """ + Updates dash.page_container layout on page navigation. + Updates the stored page title which will trigger the clientside callback to update the app title + """ + + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) + ) + + # get layout + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break + else: + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] + + if callable(layout): + layout = layout( + **{**(path_variables or {}), **query_parameters, **states} + ) + if callable(title): + title = title(**(path_variables or {})) + + return layout, {"title": title} + + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) + + # Set validation_layout + if not self.config.suppress_callback_exceptions: + self.validation_layout = html.Div( + [ + page["layout"]() + if callable(page["layout"]) + else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + + [ + # pylint: disable=not-callable + self.layout() + if callable(self.layout) + else self.layout + ] + ) if _ID_CONTENT not in self.validation_layout: raise Exception("`dash.page_container` not found in the layout") diff --git a/requirements/async.txt b/requirements/async.txt new file mode 100644 index 0000000000..fafa8e7e6e --- /dev/null +++ b/requirements/async.txt @@ -0,0 +1 @@ +flask[async] diff --git a/setup.py b/setup.py index ea616e2a18..7ed781c20d 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ def read_req_file(req_type): install_requires=read_req_file("install"), python_requires=">=3.8", extras_require={ + "async": read_req_file("async"), "ci": read_req_file("ci"), "dev": read_req_file("dev"), "testing": read_req_file("testing"), diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index fa51cda9d3..3161b8d7ea 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -72,14 +72,14 @@ def test_dveh001_python_errors(dash_duo): assert "Special 2 clicks exception" in error0 assert "in bad_sub" not in error0 # dash and flask part of the traceback not included - assert "%% callback invoked %%" not in error0 + assert "dash.py" not in error0 assert "self.wsgi_app" not in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "%% callback invoked %%" not in error1 + assert "dash.py" not in error1 assert "self.wsgi_app" not in error1 @@ -108,14 +108,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "%% callback invoked %%" in error0 + assert "dash.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "%% callback invoked %%" in error1 + assert "dash.py" in error1 assert "self.wsgi_app" in error1