diff --git a/.changeset/breezy-bottles-hide.md b/.changeset/breezy-bottles-hide.md new file mode 100644 index 0000000000000..1fa595ceda9f4 --- /dev/null +++ b/.changeset/breezy-bottles-hide.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Fix type hints for render and on diff --git a/gradio/events.py b/gradio/events.py index 8dad643f35b47..5ec9e444ea78c 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -5,7 +5,17 @@ import dataclasses from functools import partial, wraps -from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Sequence, + Union, + cast, +) from gradio_client.documentation import document from jinja2 import Template @@ -145,6 +155,32 @@ class EventListenerMethod: event_name: str +if TYPE_CHECKING: + EventListenerCallable = Callable[ + [ + Union[Callable, None], + Union[Component, Sequence[Component], None], + Union[Block, Sequence[Block], Sequence[Component], Component, None], + Union[str, None, Literal[False]], + bool, + Literal["full", "minimal", "hidden"], + Union[bool, None], + bool, + int, + bool, + bool, + Union[Dict[str, Any], List[Dict[str, Any]], None], + Union[float, None], + Union[Literal["once", "multiple", "always_last"], None], + Union[str, None], + Union[int, None, Literal["default"]], + Union[str, None], + bool, + ], + Dependency, + ] + + class EventListener(str): def __new__(cls, event_name, *_args, **_kwargs): return super().__new__(cls, event_name) @@ -331,7 +367,7 @@ def inner(*args, **kwargs): def on( - triggers: Sequence[Any] | Any | None = None, + triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None, fn: Callable | None | Literal["decorator"] = "decorator", inputs: Component | list[Component] | set[Component] | None = None, outputs: Block | list[Block] | list[Component] | None = None, @@ -376,8 +412,10 @@ def on( """ from gradio.components.base import Component - if isinstance(triggers, EventListener): - triggers = [triggers] + triggers_typed = cast(EventListener, triggers) + + if isinstance(triggers_typed, EventListener): + triggers_typed = [triggers_typed] if isinstance(inputs, Component): inputs = [inputs] @@ -418,18 +456,18 @@ def inner(*args, **kwargs): if root_block is None: raise Exception("Cannot call on() outside of a gradio.Blocks context.") if triggers is None: - triggers = ( + methods = ( [EventListenerMethod(input, "change") for input in inputs] if inputs is not None else [] ) # type: ignore else: - triggers = [ - EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) - for t in triggers - ] # type: ignore + methods = [ + EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) # type: ignore + for t in triggers_typed + ] dep, dep_index = root_block.set_event_trigger( - triggers, + methods, fn, inputs, outputs, @@ -448,7 +486,7 @@ def inner(*args, **kwargs): show_api=show_api, trigger_mode=trigger_mode, ) - set_cancel_events(triggers, cancels) + set_cancel_events(methods, cancels) return Dependency(None, dep.get_config(), dep_index, fn) diff --git a/gradio/renderable.py b/gradio/renderable.py index 52bd85d034dbc..aa41af68cdaf7 100644 --- a/gradio/renderable.py +++ b/gradio/renderable.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Literal +from typing import TYPE_CHECKING, Callable, List, Literal, Sequence, Union, cast from gradio_client.documentation import document @@ -10,6 +10,9 @@ from gradio.events import EventListener, EventListenerMethod from gradio.layouts import Column, Row +if TYPE_CHECKING: + from gradio.events import EventListenerCallable + class Renderable: def __init__( @@ -76,8 +79,8 @@ def apply(self, *args, **kwargs): @document() def render( - inputs: list[Component] | None = None, - triggers: list[EventListener] | EventListener | None = None, + inputs: list[Component] | Component | None = None, + triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None, *, queue: bool = True, trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last", @@ -116,6 +119,8 @@ def show_split(text): btn = gr.Button("Clear") btn.click(lambda: gr.Textbox(value=""), None, text) """ + new_triggers = cast(Union[List[EventListener], EventListener, None], triggers) + if Context.root_block is None: raise ValueError("Reactive render must be inside a Blocks context.") @@ -123,16 +128,18 @@ def show_split(text): [inputs] if isinstance(inputs, Component) else [] if inputs is None else inputs ) _triggers: list[tuple[Block | None, str]] = [] - if triggers is None: + if new_triggers is None: _triggers = [(Context.root_block, "load")] for input in inputs: if hasattr(input, "change"): _triggers.append((input, "change")) else: - triggers = [triggers] if isinstance(triggers, EventListener) else triggers + new_triggers = ( + [new_triggers] if isinstance(new_triggers, EventListener) else new_triggers + ) _triggers = [ (getattr(t, "__self__", None) if t.has_trigger else None, t.event_name) - for t in triggers + for t in new_triggers ] def wrapper_function(fn):