Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Name Endpoints if api_name is None #5782

Merged
merged 5 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/tame-chairs-tan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Name Endpoints if api_name is None
25 changes: 24 additions & 1 deletion gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import string
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Literal

Expand Down Expand Up @@ -168,7 +169,7 @@ def event_trigger(
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be given the name of the python function fn. If no fn is passed in, it will be given the name 'unnamed'. If set to a string, the endpoint will be exposed in the api docs with the given name.
status_tracker: Deprecated and has no effect.
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
Expand Down Expand Up @@ -224,6 +225,28 @@ def inner(*args, **kwargs):
if isinstance(show_progress, bool):
show_progress = "full" if show_progress else "hidden"

if api_name is None:
if fn is not None:
Copy link
Member

@abidlabs abidlabs Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few edge cases where a user could supply a function, but it doesn't have a __name__ (e.g. if a user supplies a functools.partial). We should also set the name to "unnamed" in those cases

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea or a class with a __call__ method (reason why CI is failing)

if not hasattr(fn, "__name__"):
if hasattr(fn, "__class__") and hasattr(
fn.__class__, "__name__"
):
name = fn.__class__.__name__
else:
name = "unnamed"
else:
name = fn.__name__
api_name = "".join(
[
s
for s in name
if s not in set(string.punctuation) - {"-", "_"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could potentially use utils.strip_invalid_filename_characters for this

]
)
else:
# Don't document _js only events
api_name = False

dep, dep_index = block.set_event_trigger(
_event_name,
fn,
Expand Down
14 changes: 8 additions & 6 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,12 @@ def cleanup():
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
api_name=False,
).then(
self.fn,
self.input_components,
self.output_components,
api_name=self.api_name if i == 0 else None,
api_name=self.api_name if i == 0 else False,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
Expand All @@ -647,6 +648,7 @@ def cleanup():
inputs=None,
outputs=extra_output, # type: ignore
queue=False,
api_name=False,
)

stop_btn.click(
Expand All @@ -655,6 +657,7 @@ def cleanup():
outputs=[submit_btn, stop_btn],
cancels=predict_events,
queue=False,
api_name=False,
)
else:
for i, trigger in enumerate(triggers):
Expand All @@ -663,7 +666,7 @@ def cleanup():
fn,
self.input_components,
self.output_components,
api_name=self.api_name if i == 0 else None,
api_name=self.api_name if i == 0 else False,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
Expand Down Expand Up @@ -740,19 +743,18 @@ def attach_flagging_events(
None,
flag_btn,
queue=False,
api_name=False,
)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=flag_btn,
preprocess=False,
queue=False,
api_name=False,
)
clear_btn.click(
flag_method.reset,
None,
flag_btn,
queue=False,
flag_method.reset, None, flag_btn, queue=False, api_name=False
)

def render_examples(self):
Expand Down
28 changes: 14 additions & 14 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self):
lambda x: f"Hello, {x}", inputs=input_, outputs=output, queue=True
)

with pytest.raises(ValueError, match="The queue is enabled for event 0"):
with pytest.raises(ValueError, match="The queue is enabled for event lambda"):
demo.launch(prevent_thread_lock=True)

demo.close()
Expand Down Expand Up @@ -463,8 +463,8 @@ def create_images(n_images):
outputs=gallery,
)
with connect(demo) as client:
client.predict(3)
_ = client.predict(3)
client.predict(3, api_name="/predict")
_ = client.predict(3, api_name="/predict")
# only three files created and in temp directory
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 3

Expand All @@ -478,9 +478,9 @@ def test_no_empty_image_files(self, gradio_temp_dir, connect):
outputs=gr.Image(),
)
with connect(demo) as client:
_ = client.predict(image)
_ = client.predict(image)
_ = client.predict(image)
_ = client.predict(image, api_name="/predict")
_ = client.predict(image, api_name="/predict")
_ = client.predict(image, api_name="/predict")
# Upload creates a file. image preprocessing creates another one.
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 2

Expand All @@ -489,8 +489,8 @@ def test_file_component_uploads(self, component, connect, gradio_temp_dir):
code_file = str(pathlib.Path(__file__))
demo = gr.Interface(lambda x: x.name, component(), gr.File())
with connect(demo) as client:
_ = client.predict(code_file)
_ = client.predict(code_file)
_ = client.predict(code_file, api_name="/predict")
_ = client.predict(code_file, api_name="/predict")
# the upload route hashees the files so we get 1 from there
# We create two tempfiles (empty) because API says we return
# preprocess/postprocess will create the same file as the upload route
Expand All @@ -502,8 +502,8 @@ def test_no_empty_video_files(self, gradio_temp_dir, connect):
video = str(file_dir / "video_sample.mp4")
demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
with connect(demo) as client:
_ = client.predict({"video": video})
_ = client.predict({"video": video})
_ = client.predict({"video": video}, api_name="/predict")
_ = client.predict({"video": video}, api_name="/predict")
# Upload route and postprocessing return the same file
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1

Expand All @@ -517,8 +517,8 @@ def reverse_audio(audio):

demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
with connect(demo) as client:
_ = client.predict(audio)
_ = client.predict(audio)
_ = client.predict(audio, api_name="/predict")
_ = client.predict(audio, api_name="/predict")
# One for upload and one for reversal
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 2

Expand Down Expand Up @@ -1495,12 +1495,12 @@ def test_many_endpoints(self):
t5 = gr.Textbox()
t1.change(lambda x: x, t1, t2, api_name="change1")
t2.change(lambda x: x, t2, t3, api_name="change2")
t3.change(lambda x: x, t3, t4)
t3.change(lambda x: x, t3, t4, api_name=False)
t4.change(lambda x: x, t4, t5, api_name=False)

api_info = demo.get_api_info()
assert len(api_info["named_endpoints"]) == 2
assert len(api_info["unnamed_endpoints"]) == 1
assert len(api_info["unnamed_endpoints"]) == 0

def test_no_endpoints(self):
with gr.Blocks() as demo:
Expand Down
96 changes: 96 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains tests for networking.py and app.py"""
import functools
import json
import os
import tempfile
Expand All @@ -23,6 +24,7 @@
close_all,
routes,
)
from gradio.route_utils import FnIndexInferError


@pytest.fixture()
Expand Down Expand Up @@ -698,3 +700,97 @@ def test_file_route_does_not_allow_dot_paths(tmp_path):
assert client.get("/file=.env").status_code == 403
assert client.get("/file=subdir/.env").status_code == 403
assert client.get("/file=.versioncontrol/settings").status_code == 403


def test_api_name_set_for_all_events(connect):
with gr.Blocks() as demo:
i = Textbox()
o = Textbox()
btn = Button()
btn1 = Button()
btn2 = Button()
btn3 = Button()
btn4 = Button()
btn5 = Button()
btn6 = Button()
btn7 = Button()
btn8 = Button()

def greet(i):
return "Hello " + i

def goodbye(i):
return "Goodbye " + i

def greet_me(i):
return "Hello"

def say_goodbye(i):
return "Goodbye"

say_goodbye.__name__ = "Say_$$_goodbye"

# Otherwise changed by ruff
foo = lambda s: s # noqa

def foo2(s):
return s + " foo"

foo2.__name__ = "foo-2"

class Callable:
def __call__(self, a) -> str:
return "From __call__"

def from_partial(a, b):
return b + a

part = functools.partial(from_partial, b="From partial: ")

btn.click(greet, i, o)
btn1.click(goodbye, i, o)
btn2.click(greet_me, i, o)
btn3.click(say_goodbye, i, o)
btn4.click(None, i, o)
btn5.click(foo, i, o)
btn6.click(foo2, i, o)
btn7.click(Callable(), i, o)
btn8.click(part, i, o)

with closing(demo) as io:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
assert client.post(
"/api/greet", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Hello freddy"]
assert client.post(
"/api/goodbye", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Goodbye freddy"]
assert client.post(
"/api/greet_me", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Hello"]
assert client.post(
"/api/Say__goodbye", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Goodbye"]
assert client.post(
"/api/lambda", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["freddy"]
assert client.post(
"/api/foo-2", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["freddy foo"]
assert client.post(
"/api/Callable", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["From __call__"]
assert client.post(
"/api/partial", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["From partial: freddy"]
with pytest.raises(FnIndexInferError):
client.post(
"/api/Say_goodbye", json={"data": ["freddy"], "session_hash": "foo"}
)

with connect(demo) as client:
assert client.predict("freddy", api_name="/greet") == "Hello freddy"
assert client.predict("freddy", api_name="/goodbye") == "Goodbye freddy"
assert client.predict("freddy", api_name="/greet_me") == "Hello"
assert client.predict("freddy", api_name="/Say__goodbye") == "Goodbye"
Loading