Skip to content

Commit

Permalink
Fix type hints for render and on (#8429)
Browse files Browse the repository at this point in the history
* type hint

* add changeset

* Use union

* type check

* lint

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot authored Jun 3, 2024
1 parent 341844f commit d393a4a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 17 deletions.
5 changes: 5 additions & 0 deletions .changeset/breezy-bottles-hide.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Fix type hints for render and on
60 changes: 49 additions & 11 deletions gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@

import dataclasses
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Sequence,
Union,
cast,
)

from gradio_client.documentation import document
from jinja2 import Template
Expand Down Expand Up @@ -145,6 +155,32 @@ class EventListenerMethod:
event_name: str


if TYPE_CHECKING:
EventListenerCallable = Callable[
[
Union[Callable, None],
Union[Component, Sequence[Component], None],
Union[Block, Sequence[Block], Sequence[Component], Component, None],
Union[str, None, Literal[False]],
bool,
Literal["full", "minimal", "hidden"],
Union[bool, None],
bool,
int,
bool,
bool,
Union[Dict[str, Any], List[Dict[str, Any]], None],
Union[float, None],
Union[Literal["once", "multiple", "always_last"], None],
Union[str, None],
Union[int, None, Literal["default"]],
Union[str, None],
bool,
],
Dependency,
]


class EventListener(str):
def __new__(cls, event_name, *_args, **_kwargs):
return super().__new__(cls, event_name)
Expand Down Expand Up @@ -331,7 +367,7 @@ def inner(*args, **kwargs):


def on(
triggers: Sequence[Any] | Any | None = None,
triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Block | list[Block] | list[Component] | None = None,
Expand Down Expand Up @@ -376,8 +412,10 @@ def on(
"""
from gradio.components.base import Component

if isinstance(triggers, EventListener):
triggers = [triggers]
triggers_typed = cast(EventListener, triggers)

if isinstance(triggers_typed, EventListener):
triggers_typed = [triggers_typed]
if isinstance(inputs, Component):
inputs = [inputs]

Expand Down Expand Up @@ -418,18 +456,18 @@ def inner(*args, **kwargs):
if root_block is None:
raise Exception("Cannot call on() outside of a gradio.Blocks context.")
if triggers is None:
triggers = (
methods = (
[EventListenerMethod(input, "change") for input in inputs]
if inputs is not None
else []
) # type: ignore
else:
triggers = [
EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name)
for t in triggers
] # type: ignore
methods = [
EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) # type: ignore
for t in triggers_typed
]
dep, dep_index = root_block.set_event_trigger(
triggers,
methods,
fn,
inputs,
outputs,
Expand All @@ -448,7 +486,7 @@ def inner(*args, **kwargs):
show_api=show_api,
trigger_mode=trigger_mode,
)
set_cancel_events(triggers, cancels)
set_cancel_events(methods, cancels)
return Dependency(None, dep.get_config(), dep_index, fn)


Expand Down
19 changes: 13 additions & 6 deletions gradio/renderable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, Literal
from typing import TYPE_CHECKING, Callable, List, Literal, Sequence, Union, cast

from gradio_client.documentation import document

Expand All @@ -10,6 +10,9 @@
from gradio.events import EventListener, EventListenerMethod
from gradio.layouts import Column, Row

if TYPE_CHECKING:
from gradio.events import EventListenerCallable


class Renderable:
def __init__(
Expand Down Expand Up @@ -76,8 +79,8 @@ def apply(self, *args, **kwargs):

@document()
def render(
inputs: list[Component] | None = None,
triggers: list[EventListener] | EventListener | None = None,
inputs: list[Component] | Component | None = None,
triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
*,
queue: bool = True,
trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last",
Expand Down Expand Up @@ -116,23 +119,27 @@ def show_split(text):
btn = gr.Button("Clear")
btn.click(lambda: gr.Textbox(value=""), None, text)
"""
new_triggers = cast(Union[List[EventListener], EventListener, None], triggers)

if Context.root_block is None:
raise ValueError("Reactive render must be inside a Blocks context.")

inputs = (
[inputs] if isinstance(inputs, Component) else [] if inputs is None else inputs
)
_triggers: list[tuple[Block | None, str]] = []
if triggers is None:
if new_triggers is None:
_triggers = [(Context.root_block, "load")]
for input in inputs:
if hasattr(input, "change"):
_triggers.append((input, "change"))
else:
triggers = [triggers] if isinstance(triggers, EventListener) else triggers
new_triggers = (
[new_triggers] if isinstance(new_triggers, EventListener) else new_triggers
)
_triggers = [
(getattr(t, "__self__", None) if t.has_trigger else None, t.event_name)
for t in triggers
for t in new_triggers
]

def wrapper_function(fn):
Expand Down

0 comments on commit d393a4a

Please sign in to comment.