Skip to content

Commit

Permalink
Allows updating the dataset of a gr.Examples (#8745)
Browse files Browse the repository at this point in the history
* helpers

* add changeset

* changes

* add changeset

* changes

* tweak

* format

* example to docs

* add changeset

* fixes

* add tuple

* add changeset

* print

* format

* clean'

* clean

* format

* format backend

* fix backend tests

* format

* notebooks

* comment

* delete demo

* add changeset

* docstring

* docstring

* changes

* add changeset

* components

* changes

* changes

* format

* add test

* fix python test

* use deep_equal

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Jul 15, 2024
1 parent 2d179f6 commit 4030f28
Show file tree
Hide file tree
Showing 13 changed files with 241 additions and 69 deletions.
7 changes: 7 additions & 0 deletions .changeset/real-grapes-accept.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/dataframe": minor
"gradio": minor
"website": minor
---

feat:Allows updating the dataset of a `gr.Examples`
2 changes: 1 addition & 1 deletion demo/image_mod/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "\n", "demo = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/logo.png\"),\n", " os.path.join(os.path.abspath(''), \"images/tower.jpg\"),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "new_samples = [\n", " [os.path.join(os.path.abspath(''), \"images/logo.png\")],\n", " [os.path.join(os.path.abspath(''), \"images/tower.jpg\")],\n", "]\n", "\n", "with gr.Blocks() as demo:\n", " interface = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " )\n", "\n", " btn = gr.Button(\"Update Examples\")\n", " btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
30 changes: 18 additions & 12 deletions demo/image_mod/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@
def image_mod(image):
return image.rotate(45)

new_samples = [
[os.path.join(os.path.dirname(__file__), "images/logo.png")],
[os.path.join(os.path.dirname(__file__), "images/tower.jpg")],
]

demo = gr.Interface(
image_mod,
gr.Image(type="pil"),
"image",
flagging_options=["blurry", "incorrect", "other"],
examples=[
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
os.path.join(os.path.dirname(__file__), "images/logo.png"),
os.path.join(os.path.dirname(__file__), "images/tower.jpg"),
],
)
with gr.Blocks() as demo:
interface = gr.Interface(
image_mod,
gr.Image(type="pil"),
"image",
flagging_options=["blurry", "incorrect", "other"],
examples=[
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
],
)

btn = gr.Button("Update Examples")
btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)

if __name__ == "__main__":
demo.launch()
1 change: 1 addition & 0 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,7 @@ async def postprocess_data(
kwargs.pop("value", None)
kwargs.pop("__type__")
kwargs["render"] = False

state[block._id] = block.__class__(**kwargs)

prediction_value = postprocess_update_dict(
Expand Down
24 changes: 14 additions & 10 deletions gradio/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
type: Literal["values", "index"] = "values",
type: Literal["values", "index", "tuple"] = "values",
samples_per_page: int = 10,
visible: bool = True,
elem_id: str | None = None,
Expand All @@ -51,7 +51,7 @@ def __init__(
components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video
samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component
headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels
type: 'values' if clicking on a sample should pass the value of the sample, or "index" if it should pass the index of the sample
type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample.
samples_per_page: how many examples to show per page.
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
Expand Down Expand Up @@ -95,18 +95,20 @@ def __init__(
self.proxy_url = proxy_url
for component in self._components:
component.proxy_url = proxy_url
self.samples = [[]] if samples is None else samples
for example in self.samples:
self.raw_samples = [[]] if samples is None else samples
self.samples: list[list] = []
for example in self.raw_samples:
self.samples.append([])
for i, (component, ex) in enumerate(zip(self._components, example)):
# If proxy_url is set, that means it is being loaded from an external Gradio app
# which means that the example has already been processed.
if self.proxy_url is None:
# The `as_example()` method has been renamed to `process_example()` but we
# use the previous name to be backwards-compatible with previously-created
# custom components
example[i] = component.as_example(ex)
example[i] = processing_utils.move_files_to_cache(
example[i], component, keep_in_cache=True
self.samples[-1].append(component.as_example(ex))
self.samples[-1][i] = processing_utils.move_files_to_cache(
self.samples[-1][i], component, keep_in_cache=True
)
self.type = type
self.label = label
Expand Down Expand Up @@ -137,19 +139,21 @@ def get_config(self):

return config

def preprocess(self, payload: int | None) -> int | list | None:
def preprocess(self, payload: int | None) -> int | list | tuple[int, list] | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index"), or as a `tuple` of the index and the data (if `type` is "tuple").
"""
if payload is None:
return None
if self.type == "index":
return payload
elif self.type == "values":
return self.samples[payload]
return self.raw_samples[payload]
elif self.type == "tuple":
return payload, self.raw_samples[payload]

def postprocess(self, sample: int | list | None) -> int | None:
"""
Expand Down
73 changes: 41 additions & 32 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import ast
import copy
import csv
import inspect
import os
Expand All @@ -28,6 +29,7 @@
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger
from gradio.utils import UnhashableKeyDict

if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component
Expand Down Expand Up @@ -239,6 +241,7 @@ def __init__(
self.examples = examples
self.non_none_examples = non_none_examples
self.inputs = inputs
self.input_has_examples = input_has_examples
self.inputs_with_examples = inputs_with_examples
self.outputs = outputs or []
self.fn = fn
Expand All @@ -248,35 +251,15 @@ def __init__(
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels

with utils.set_directory(working_directory):
self.processed_examples = []
for example in examples:
sub = []
for component, sample in zip(inputs, example):
prediction_value = component.postprocess(sample)
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
prediction_value = prediction_value.model_dump()
prediction_value = processing_utils.move_files_to_cache(
prediction_value,
component,
postprocess=True,
)
sub.append(prediction_value)
self.processed_examples.append(sub)

self.non_none_processed_examples = [
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in self.processed_examples
]
self.working_directory = working_directory

from gradio import components

with utils.set_directory(working_directory):
self.dataset = components.Dataset(
components=inputs_with_examples,
samples=non_none_examples,
type="index",
samples=copy.deepcopy(non_none_examples),
type="tuple",
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
Expand All @@ -290,13 +273,38 @@ def __init__(
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None
self.non_none_processed_examples = UnhashableKeyDict()

if self.dataset.samples:
for index, example in enumerate(self.non_none_examples):
self.non_none_processed_examples[self.dataset.samples[index]] = (
self._get_processed_example(example)
)

def _get_processed_example(self, example):
if example in self.non_none_processed_examples:
return self.non_none_processed_examples[example]
with utils.set_directory(self.working_directory):
sub = []
for component, sample in zip(self.inputs, example):
prediction_value = component.postprocess(sample)
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
prediction_value = prediction_value.model_dump()
prediction_value = processing_utils.move_files_to_cache(
prediction_value,
component,
postprocess=True,
)
sub.append(prediction_value)
return [ex for (ex, keep) in zip(sub, self.input_has_examples) if keep]

def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
component to hold the examples"""

async def load_example(example_id):
processed_example = self.non_none_processed_examples[example_id]
async def load_example(example_tuple):
_, example_value = example_tuple
processed_example = self._get_processed_example(example_value)
if len(self.inputs_with_examples) == 1:
return update(
value=processed_example[0],
Expand Down Expand Up @@ -496,9 +504,9 @@ async def get_final_item(*args):

if self.outputs is None:
raise ValueError("self.outputs is missing")
for example_id in range(len(self.examples)):
print(f"Caching example {example_id + 1}/{len(self.examples)}")
processed_input = self.processed_examples[example_id]
for i, example in enumerate(self.examples):
print(f"Caching example {i + 1}/{len(self.examples)}")
processed_input = self._get_processed_example(example)
if self.batch:
processed_input = [[value] for value in processed_input]
with utils.MatplotlibBackendMananger():
Expand All @@ -523,10 +531,11 @@ async def get_final_item(*args):
# method to be called independently of the create() method
blocks_config.fns.pop(self.load_input_event["id"])

def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + self.load_from_cache(example_id)
def load_example(example_tuple):
example_id, example_value = example_tuple
processed_example = self._get_processed_example(
example_value
) + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

self.cache_event = self.load_input_event = self.dataset.click(
Expand Down
40 changes: 40 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import wraps
from io import BytesIO
Expand Down Expand Up @@ -1424,3 +1425,42 @@ def error_payload(
content["duration"] = error.duration
content["visible"] = error.visible
return content


class UnhashableKeyDict(MutableMapping):
"""
Essentially a list of key-value tuples that allows for keys that are not hashable,
but acts like a dictionary for convenience.
"""

def __init__(self):
self.data = []

def __getitem__(self, key):
for k, v in self.data:
if deep_equal(k, key):
return v
raise KeyError(key)

def __setitem__(self, key, value):
for i, (k, _) in enumerate(self.data):
if deep_equal(k, key):
self.data[i] = (key, value)
return
self.data.append((key, value))

def __delitem__(self, key):
for i, (k, _) in enumerate(self.data):
if deep_equal(k, key):
del self.data[i]
return
raise KeyError(key)

def __iter__(self):
return (k for k, _ in self.data)

def __len__(self):
return len(self.data)

def as_list(self):
return [v for _, v in self.data]
25 changes: 25 additions & 0 deletions js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,31 @@ None
<ParamTable parameters={obj.attributes} />


### Examples

**Updating Examples**

In this demo, we show how to update the examples by updating the samples of the underlying dataset. Note that this only works if `cache_examples=False` as updating the underlying dataset does not update the cache.

```py
import gradio as gr

def update_examples(country):
if country == "USA":
return gr.Dataset(samples=[["Chicago"], ["Little Rock"], ["San Francisco"]])
else:
return gr.Dataset(samples=[["Islamabad"], ["Karachi"], ["Lahore"]])

with gr.Blocks() as demo:
dropdown = gr.Dropdown(label="Country", choices=["USA", "Pakistan"], value="USA")
textbox = gr.Textbox()
examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox)
dropdown.change(update_examples, dropdown, examples.dataset)

demo.launch()
```


{#if obj.demos && obj.demos.length > 0}
<!--- Demos -->
### Demos
Expand Down
20 changes: 20 additions & 0 deletions js/app/test/image_mod.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { test, expect } from "@gradio/tootils";

test("examples_get_updated_correctly", async ({ page }) => {
await page.locator(".gallery-item").first().click();
let image = await page.getByTestId("image").locator("img").first();
await expect(await image.getAttribute("src")).toContain("cheetah1.jpg");
await page.getByRole("button", { name: "Update Examples" }).click();

let example_image;
await expect(async () => {
example_image = await page.locator(".gallery-item").locator("img").first();
await expect(await example_image.getAttribute("src")).toContain("logo.png");
}).toPass();

await example_image.click();
await expect(async () => {
image = await page.getByTestId("image").locator("img").first();
await expect(await image.getAttribute("src")).toContain("logo.png");
}).toPass();
});
9 changes: 4 additions & 5 deletions js/dataframe/Example.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
export let index: number;
let hovered = false;
let loaded_value: (string | number)[][] | string = value;
let loaded = Array.isArray(loaded_value);
let loaded = Array.isArray(value);
</script>

{#if loaded}
Expand All @@ -19,11 +18,11 @@
on:mouseenter={() => (hovered = true)}
on:mouseleave={() => (hovered = false)}
>
{#if typeof loaded_value === "string"}
{loaded_value}
{#if typeof value === "string"}
{value}
{:else}
<table class="">
{#each loaded_value.slice(0, 3) as row, i}
{#each value.slice(0, 3) as row, i}
<tr>
{#each row.slice(0, 3) as cell, j}
<td>{cell}</td>
Expand Down
2 changes: 1 addition & 1 deletion test/components/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_preprocessing(self):
row = dataset.preprocess(1)
assert row[0] == 15
assert row[1] == "hi"
assert row[2]["path"].endswith("bus.png")
assert row[2].endswith("bus.png")
assert row[3] == "<i>Italics</i>"
assert row[4] == "*Italics*"

Expand Down
Loading

0 comments on commit 4030f28

Please sign in to comment.