From f60053d85ffb7f0fd4bb648906914370b7aa4598 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 9 Dec 2022 12:14:07 -0300 Subject: [PATCH] gr.ScatterPlot component (#2764) * Try clean install * Resolve peer dependencies? * CHANGELOG * Add outbreak_forcast notebook * generate again * CHANGELOG * Add image to changelog * Color palette * Fix colors + legend * Tooltip * Add axis titles * Clean up code a bit + quant scales * Add code * Add size, shape + rename legend title * Fix demo * Add update + demo * Handle darkmode better * Try new font * Use sans-serif * Add caption * Changelog + tests * More tests * Address comments * Make caption fontsize smaller and enable interactivity * Add docstrings + add height + width * Use normal font weight * Make last values keyword only Co-authored-by: Abubakar Abid * Fix typo * Accept value as fn * reword changelog a bit Co-authored-by: Abubakar Abid --- CHANGELOG.md | 36 ++- demo/native_plots/requirements.txt | 1 + demo/native_plots/run.ipynb | 1 + demo/native_plots/run.py | 12 + demo/native_plots/scatter_plot_demo.py | 47 +++ demo/scatterplot_component/requirements.txt | 1 + demo/scatterplot_component/run.ipynb | 1 + demo/scatterplot_component/run.py | 18 ++ gradio/__init__.py | 1 + gradio/components.py | 295 +++++++++++++++++- scripts/copy_demos.py | 1 + test/test_blocks.py | 8 +- test/test_components.py | 134 ++++++++ ui/packages/app/src/Blocks.svelte | 2 + ui/packages/app/src/Render.svelte | 3 + .../app/src/components/Plot/Plot.svelte | 13 +- ui/packages/app/src/main.ts | 29 +- ui/packages/plot/package.json | 1 + ui/packages/plot/src/Plot.svelte | 54 +++- ui/packages/plot/src/utils.ts | 30 ++ ui/pnpm-lock.yaml | 2 + 21 files changed, 666 insertions(+), 24 deletions(-) create mode 100644 demo/native_plots/requirements.txt create mode 100644 demo/native_plots/run.ipynb create mode 100644 demo/native_plots/run.py create mode 100644 demo/native_plots/scatter_plot_demo.py create mode 100644 demo/scatterplot_component/requirements.txt create mode 100644 demo/scatterplot_component/run.ipynb create mode 100644 demo/scatterplot_component/run.py create mode 100644 ui/packages/plot/src/utils.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 1de7d3f810652..b6084c7db141b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,40 @@ ## New Features: +### Scatter plot component + +It is now possible to create a scatter plot natively in Gradio! + +The `gr.ScatterPlot` component accepts a pandas dataframe and some optional configuration parameters +and will automatically create a plot for you! + +This is the first of many native plotting components in Gradio! + +For an example of how to use `gr.ScatterPlot` see below: + +```python +import gradio as gr +from vega_datasets import data + +cars = data.cars() + +with gr.Blocks() as demo: + gr.ScatterPlot(show_label=False, + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car").style(container=False) + +demo.launch() +``` + +By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2764](https://github.com/gradio-app/gradio/pull/2764) + + ### Support for altair plots The `Plot` component can now accept altair plots as values! @@ -33,7 +67,7 @@ demo.launch() By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2741](https://github.com/gradio-app/gradio/pull/2741) -### Set the background color of a Label component +### Set the background color of a Label component The `Label` component now accepts a `color` argument by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2736](https://github.com/gradio-app/gradio/pull/2736). The `color` argument should either be a valid css color name or hexadecimal string. diff --git a/demo/native_plots/requirements.txt b/demo/native_plots/requirements.txt new file mode 100644 index 0000000000000..d1c8a7ae0396d --- /dev/null +++ b/demo/native_plots/requirements.txt @@ -0,0 +1 @@ +vega_datasets \ No newline at end of file diff --git a/demo/native_plots/run.ipynb b/demo/native_plots/run.ipynb new file mode 100644 index 0000000000000..4d3070de8e8e1 --- /dev/null +++ b/demo/native_plots/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/native_plots/run.py b/demo/native_plots/run.py new file mode 100644 index 0000000000000..a0194d2917a4e --- /dev/null +++ b/demo/native_plots/run.py @@ -0,0 +1,12 @@ +import gradio as gr + +from scatter_plot_demo import scatter_plot + + +with gr.Blocks() as demo: + with gr.Tabs(): + with gr.TabItem("Scatter Plot"): + scatter_plot.render() + +if __name__ == "__main__": + demo.launch() diff --git a/demo/native_plots/scatter_plot_demo.py b/demo/native_plots/scatter_plot_demo.py new file mode 100644 index 0000000000000..223a0110df4fe --- /dev/null +++ b/demo/native_plots/scatter_plot_demo.py @@ -0,0 +1,47 @@ +import gradio as gr + +from vega_datasets import data + +cars = data.cars() +iris = data.iris() + + +def scatter_plot_fn(dataset): + if dataset == "iris": + return gr.ScatterPlot.update( + value=iris, + x="petalWidth", + y="petalLength", + color="species", + title="Iris Dataset", + color_legend_title="Species", + x_title="Petal Width", + y_title="Petal Length", + tooltip=["petalWidth", "petalLength", "species"], + caption="", + ) + else: + return gr.ScatterPlot.update( + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car", + caption="MPG vs Horsepower of various cars" + ) + + +with gr.Blocks() as scatter_plot: + with gr.Row(): + with gr.Column(): + dataset = gr.Dropdown(choices=["cars", "iris"], value="cars") + with gr.Column(): + plot = gr.ScatterPlot(show_label=False).style(container=True) + dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot) + scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot) + +if __name__ == "__main__": + scatter_plot.launch() diff --git a/demo/scatterplot_component/requirements.txt b/demo/scatterplot_component/requirements.txt new file mode 100644 index 0000000000000..d1c8a7ae0396d --- /dev/null +++ b/demo/scatterplot_component/requirements.txt @@ -0,0 +1 @@ +vega_datasets \ No newline at end of file diff --git a/demo/scatterplot_component/run.ipynb b/demo/scatterplot_component/run.ipynb new file mode 100644 index 0000000000000..75d7a60528abc --- /dev/null +++ b/demo/scatterplot_component/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: scatterplot_component"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from vega_datasets import data\n", "\n", "cars = data.cars()\n", "\n", "with gr.Blocks() as demo:\n", " gr.ScatterPlot(show_label=False,\n", " value=cars,\n", " x=\"Horsepower\",\n", " y=\"Miles_per_Gallon\",\n", " color=\"Origin\",\n", " tooltip=\"Name\",\n", " title=\"Car Data\",\n", " y_title=\"Miles per Gallon\",\n", " color_legend_title=\"Origin of Car\").style(container=False)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/scatterplot_component/run.py b/demo/scatterplot_component/run.py new file mode 100644 index 0000000000000..4b005c6390215 --- /dev/null +++ b/demo/scatterplot_component/run.py @@ -0,0 +1,18 @@ +import gradio as gr +from vega_datasets import data + +cars = data.cars() + +with gr.Blocks() as demo: + gr.ScatterPlot(show_label=False, + value=cars, + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + tooltip="Name", + title="Car Data", + y_title="Miles per Gallon", + color_legend_title="Origin of Car").style(container=False) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/gradio/__init__.py b/gradio/__init__.py index 3c76ebbc9ec5b..426c51b74fe99 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -35,6 +35,7 @@ Number, Plot, Radio, + ScatterPlot, Slider, State, StatusTracker, diff --git a/gradio/components.py b/gradio/components.py index bf629f396f854..452a5a5dac436 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -20,6 +20,7 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +import altair as alt import matplotlib.figure import numpy as np import pandas as pd @@ -28,6 +29,7 @@ from ffmpy import FFmpeg from markdown_it import MarkdownIt from mdit_py_plugins.dollarmath import dollarmath_plugin +from pandas.api.types import is_numeric_dtype from gradio import media_data, processing_utils, utils from gradio.blocks import Block @@ -3965,8 +3967,297 @@ def postprocess(self, y: str | None) -> Dict[str, str] | None: out_y = y.to_json() return {"type": dtype, "plot": out_y} - def style(self): - return self + def style(self, container: Optional[bool] = None): + return IOComponent.style( + self, + container=container, + ) + + +@document("change", "clear") +class ScatterPlot(Plot): + """ + Create a scatter plot. + + Preprocessing: this component does *not* accept input. + Postprocessing: expects a pandas dataframe with the data to plot. + + Demos: native_plots + """ + + def __init__( + self, + value: Optional[pd.DataFrame | Callable] = None, + x: Optional[str] = None, + y: Optional[str] = None, + *, + color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, + title: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, + x_title: Optional[str] = None, + y_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + caption: Optional[str] = None, + interactive: Optional[bool] = True, + label: Optional[str] = None, + show_label: bool = True, + visible: bool = True, + elem_id: Optional[str] = None, + ): + """ + Parameters: + value: The pandas dataframe containing the data to display in a scatter plot. + x: Column corresponding to the x axis. + y: Column corresponding to the y axis. + color: The column to determine the point color. If the column contains numeric data, gradio will interpolate the column data so that small values correspond to light colors and large values correspond to dark values. + size: The column used to determine the point size. Should contain numeric data so that gradio can map the data to the point size. + shape: The column used to determine the point shape. Should contain categorical data. Gradio will map each unique value to a different shape. + title: The title to display on top of the chart. + tooltip: The column (or list of columns) to display on the tooltip when a user hovers a point on the plot. + x_title: The title given to the x axis. By default, uses the value of the x parameter. + y_title: The title given to the y axis. By default, uses the value of the y parameter. + color_legend_title: The title given to the color legend. By default, uses the value of color parameter. + size_legend_title: The title given to the size legend. By default, uses the value of the size parameter. + shape_legend_title: The title given to the shape legend. By default, uses the value of the shape parameter. + height: The height of the plot in pixels. + width: The width of the plot in pixels. + caption: The (optional) caption to display below the plot. + interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. + label: The (optional) label to display on the top left corner of the plot. + show_label: Whether the label should be displayed. + visible: Whether the plot should be visible. + elem_id: Unique id used for custom css targetting. + """ + self.x = x + self.y = y + self.color = color + self.size = size + self.shape = shape + self.tooltip = tooltip + self.title = title + self.x_title = x_title + self.y_title = y_title + self.color_legend_title = color_legend_title + self.size_legend_title = size_legend_title + self.shape_legend_title = shape_legend_title + self.caption = caption + self.interactive_chart = interactive + self.width = width + self.height = height + # self.value = None + # if value is not None: + # self.value = self.postprocess(value) + super().__init__( + value=value, + label=label, + show_label=show_label, + visible=visible, + elem_id=elem_id, + ) + + def get_config(self): + config = super().get_config() + config["caption"] = self.caption + return config + + def get_block_name(self) -> str: + return "plot" + + @staticmethod + def update( + value: Optional[Any] = _Keywords.NO_VALUE, + x: Optional[str] = None, + y: Optional[str] = None, + color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, + title: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, + x_title: Optional[str] = None, + y_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + interactive: Optional[bool] = None, + caption: Optional[str] = None, + label: Optional[str] = None, + show_label: Optional[bool] = None, + visible: Optional[bool] = None, + ): + """Update an existing plot component. + + If updating any of the plot properties (color, size, etc) the value, x, and y parameters must be specified. + + Parameters: + value: The pandas dataframe containing the data to display in a scatter plot. + x: Column corresponding to the x axis. + y: Column corresponding to the y axis. + color: The column to determine the point color. If the column contains numeric data, gradio will interpolate the column data so that small values correspond to light colors and large values correspond to dark values. + size: The column used to determine the point size. Should contain numeric data so that gradio can map the data to the point size. + shape: The column used to determine the point shape. Should contain categorical data. Gradio will map each unique value to a different shape. + title: The title to display on top of the chart. + tooltip: The column (or list of columns) to display on the tooltip when a user hovers a point on the plot. + x_title: The title given to the x axis. By default, uses the value of the x parameter. + y_title: The title given to the y axis. By default, uses the value of the y parameter. + color_legend_title: The title given to the color legend. By default, uses the value of color parameter. + size_legend_title: The title given to the size legend. By default, uses the value of the size parameter. + shape_legend_title: The title given to the shape legend. By default, uses the value of the shape parameter. + height: The height of the plot in pixels. + width: The width of the plot in pixels. + caption: The (optional) caption to display below the plot. + interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad. + label: The (optional) label to display in the top left corner of the plot. + show_label: Whether the label should be displayed. + visible: Whether the plot should be visible. + """ + properties = [ + x, + y, + color, + size, + shape, + title, + tooltip, + x_title, + y_title, + color_legend_title, + size_legend_title, + shape_legend_title, + interactive, + height, + width, + ] + if any(properties): + if value is _Keywords.NO_VALUE: + raise ValueError( + "In order to update plot properties the value parameter " + "must be provided. Please pass a value parameter to " + "gr.ScatterPlot.update." + ) + if x is None or y is None: + raise ValueError( + "In order to update plot properties, the x and y axis data " + "must be specified. Please pass valid values for x an y to " + "gr.ScatterPlot.update." + ) + chart = ScatterPlot.create_plot(value, *properties) + value = {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} + + updated_config = { + "label": label, + "show_label": show_label, + "visible": visible, + "value": value, + "caption": caption, + "__type__": "update", + } + return updated_config + + @staticmethod + def create_plot( + value: pd.DataFrame, + x: str, + y: str, + color: Optional[str] = None, + size: Optional[str] = None, + shape: Optional[str] = None, + title: Optional[str] = None, + tooltip: Optional[List[str] | str] = None, + x_title: Optional[str] = None, + y_title: Optional[str] = None, + color_legend_title: Optional[str] = None, + size_legend_title: Optional[str] = None, + shape_legend_title: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + interactive: Optional[bool] = True, + ): + """Helper for creating the scatter plot.""" + interactive = True if interactive is None else interactive + encodings = dict( + x=alt.X(x, title=x_title or x), + y=alt.Y(y, title=y_title or y), + ) + properties = {} + if title: + properties["title"] = title + if height: + properties["height"] = height + if width: + properties["width"] = width + if color: + if is_numeric_dtype(value[color]): + domain = [value[color].min(), value[color].max()] + range_ = [0, 1] + type_ = "quantitative" + else: + domain = value[color].unique().tolist() + range_ = list(range(len(domain))) + type_ = "nominal" + + encodings["color"] = { + "field": color, + "type": type_, + "legend": {"title": color_legend_title or color}, + "scale": {"domain": domain, "range": range_}, + } + if tooltip: + encodings["tooltip"] = tooltip + if size: + encodings["size"] = { + "field": size, + "type": "quantitative" if is_numeric_dtype(value[size]) else "nominal", + "legend": {"title": size_legend_title or size}, + } + if shape: + encodings["shape"] = { + "field": shape, + "type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal", + "legend": {"title": shape_legend_title or shape}, + } + chart = ( + alt.Chart(value) + .mark_point() + .encode(**encodings) + .properties(background="transparent", **properties) + ) + if interactive: + chart = chart.interactive() + + return chart + + def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: + # if None or update + if y is None or isinstance(y, Dict): + return y + chart = self.create_plot( + value=y, + x=self.x, + y=self.y, + color=self.color, + size=self.size, + shape=self.shape, + title=self.title, + tooltip=self.tooltip, + x_title=self.x_title, + y_title=self.y_title, + color_legend_title=self.color_legend_title, + size_legend_title=self.size_legend_title, + shape_legend_title=self.size_legend_title, + interactive=self.interactive_chart, + height=self.height, + width=self.width, + ) + + return {"type": "altair", "plot": chart.to_json(), "chart": "scatter"} @document("change") diff --git a/scripts/copy_demos.py b/scripts/copy_demos.py index efd59cd494c91..a33e88f50a77f 100644 --- a/scripts/copy_demos.py +++ b/scripts/copy_demos.py @@ -28,6 +28,7 @@ def copy_all_demos(source_dir: str, dest_dir: str): "kitchen_sink_random", "matrix_transpose", "model3D", + "native_plots", "reset_components", "reverse_audio", "stt_or_tts", diff --git a/test/test_blocks.py b/test/test_blocks.py index 818496ad65b84..6d5225d6b2c0d 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -294,7 +294,9 @@ def test_slider_random_value_config(self): assert not any([dep["queue"] for dep in demo.config["dependencies"]]) def test_io_components_attach_load_events_when_value_is_fn(self, io_components): - io_components = [comp for comp in io_components if not (comp == gr.State)] + io_components = [ + comp for comp in io_components if comp not in [gr.State, gr.ScatterPlot] + ] interface = gr.Interface( lambda *args: None, inputs=[comp(value=lambda: None) for comp in io_components], @@ -307,7 +309,9 @@ def test_io_components_attach_load_events_when_value_is_fn(self, io_components): assert len(dependencies_on_load) == len(io_components) def test_blocks_do_not_filter_none_values_from_updates(self, io_components): - io_components = [c() for c in io_components if c not in [gr.State, gr.Button]] + io_components = [ + c() for c in io_components if c not in [gr.State, gr.Button, gr.ScatterPlot] + ] with gr.Blocks() as demo: for component in io_components: component.render() diff --git a/test/test_components.py b/test/test_components.py index 74a9d781cf6a7..447e26b720082 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -20,6 +20,7 @@ import pandas as pd import PIL import pytest +import vega_datasets from scipy.io import wavfile import gradio as gr @@ -1892,3 +1893,136 @@ def test_dataset_calls_as_example(*mocks): ], ) assert all([m.called for m in mocks]) + + +cars = vega_datasets.data.cars() + + +class TestScatterPlot: + def test_get_config(self): + assert gr.ScatterPlot().get_config() == { + "caption": None, + "elem_id": None, + "interactive": None, + "label": None, + "name": "plot", + "root_url": None, + "show_label": True, + "style": {}, + "value": None, + "visible": True, + } + + def test_no_color(self): + plot = gr.ScatterPlot( + x="Horsepower", + y="Miles_per_Gallon", + tooltip="Name", + title="Car Data", + x_title="Horse", + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert config["encoding"]["x"]["field"] == "Horsepower" + assert config["encoding"]["x"]["title"] == "Horse" + assert config["encoding"]["y"]["field"] == "Miles_per_Gallon" + assert config["selection"] == { + "selector001": { + "bind": "scales", + "encodings": ["x", "y"], + "type": "interval", + } + } + assert config["title"] == "Car Data" + assert "height" not in config + assert "width" not in config + + def test_no_interactive(self): + plot = gr.ScatterPlot( + x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert "selection" not in config + + def test_height_width(self): + plot = gr.ScatterPlot( + x="Horsepower", y="Miles_per_Gallon", height=100, width=200 + ) + output = plot.postprocess(cars) + assert sorted(list(output.keys())) == ["chart", "plot", "type"] + config = json.loads(output["plot"]) + assert config["height"] == 100 + assert config["width"] == 200 + + def test_color_encoding(self): + plot = gr.ScatterPlot( + x="Horsepower", + y="Miles_per_Gallon", + tooltip="Name", + title="Car Data", + color="Origin", + ) + output = plot.postprocess(cars) + config = json.loads(output["plot"]) + assert config["encoding"]["color"]["field"] == "Origin" + assert config["encoding"]["color"]["scale"] == { + "domain": ["USA", "Europe", "Japan"], + "range": [0, 1, 2], + } + assert config["encoding"]["color"]["type"] == "nominal" + + def test_two_encodings(self): + plot = gr.ScatterPlot( + show_label=False, + title="Two encodings", + x="Horsepower", + y="Miles_per_Gallon", + color="Acceleration", + shape="Origin", + ) + output = plot.postprocess(cars) + config = json.loads(output["plot"]) + assert config["encoding"]["color"]["field"] == "Acceleration" + assert config["encoding"]["color"]["scale"] == { + "domain": [cars.Acceleration.min(), cars.Acceleration.max()], + "range": [0, 1], + } + assert config["encoding"]["color"]["type"] == "quantitative" + + assert config["encoding"]["shape"]["field"] == "Origin" + assert config["encoding"]["shape"]["type"] == "nominal" + + def test_update(self): + output = gr.ScatterPlot.update(value=cars, x="Horsepower", y="Miles_per_Gallon") + postprocessed = gr.ScatterPlot().postprocess(output["value"]) + assert postprocessed == output["value"] + + def test_update_visibility(self): + output = gr.ScatterPlot.update(visible=False) + assert not output["visible"] + assert output["value"] is gr.components._Keywords.NO_VALUE + + def test_update_errors(self): + with pytest.raises( + ValueError, match="In order to update plot properties the value parameter" + ): + gr.ScatterPlot.update(x="foo", y="bar") + + with pytest.raises( + ValueError, + match="In order to update plot properties, the x and y axis data", + ): + gr.ScatterPlot.update(value=cars, x="foo") + + def test_scatterplot_accepts_fn_as_value(self): + plot = gr.ScatterPlot( + value=lambda: cars.sample(frac=0.1, replace=False), + x="Horsepower", + y="Miles_per_Gallon", + color="Origin", + ) + assert isinstance(plot.value, dict) + assert isinstance(plot.value["plot"], str) diff --git a/ui/packages/app/src/Blocks.svelte b/ui/packages/app/src/Blocks.svelte index 57f1de65a7834..902b9acadc87f 100644 --- a/ui/packages/app/src/Blocks.svelte +++ b/ui/packages/app/src/Blocks.svelte @@ -40,6 +40,7 @@ export let show_api: boolean = true; export let control_page_title = false; export let app_mode: boolean; + export let theme: string; let loading_status = create_loading_status_store(); @@ -419,6 +420,7 @@ {instance_map} {root} {target} + {theme} on:mount={handle_mount} on:destroy={({ detail }) => handle_destroy(detail)} /> diff --git a/ui/packages/app/src/Render.svelte b/ui/packages/app/src/Render.svelte index 87adf5693d5b5..2e070e24e8b3c 100644 --- a/ui/packages/app/src/Render.svelte +++ b/ui/packages/app/src/Render.svelte @@ -14,6 +14,7 @@ export let has_modes: boolean | undefined; export let parent: string | null = null; export let target: HTMLElement; + export let theme: string; const dispatch = createEventDispatcher<{ mount: number; destroy: number }>(); @@ -56,6 +57,7 @@ on:prop_change={handle_prop_change} {target} {...props} + {theme} {root} > {#if children && children.length} @@ -70,6 +72,7 @@ children={_children} {dynamic_ids} {has_modes} + {theme} on:destroy on:mount /> diff --git a/ui/packages/app/src/components/Plot/Plot.svelte b/ui/packages/app/src/components/Plot/Plot.svelte index d2eabf4eb20a2..1672d62783ce3 100644 --- a/ui/packages/app/src/components/Plot/Plot.svelte +++ b/ui/packages/app/src/components/Plot/Plot.svelte @@ -7,6 +7,7 @@ import StatusTracker from "../StatusTracker/StatusTracker.svelte"; import type { LoadingStatus } from "../StatusTracker/types"; import { _ } from "svelte-i18n"; + import type { Styles } from "@gradio/utils"; export let value: null | string = null; export let elem_id: string = ""; @@ -16,12 +17,20 @@ export let label: string; export let show_label: boolean; export let target: HTMLElement; + export let style: Styles = {}; + export let theme: string; + export let caption: string; - + - + diff --git a/ui/packages/app/src/main.ts b/ui/packages/app/src/main.ts index 1d498f1cb82e3..f6ce362edd132 100644 --- a/ui/packages/app/src/main.ts +++ b/ui/packages/app/src/main.ts @@ -192,6 +192,7 @@ function create_custom_element() { root: ShadowRoot; wrapper: HTMLDivElement; _id: number; + theme: string; constructor() { super(); @@ -210,6 +211,7 @@ function create_custom_element() { this.wrapper.style.position = "relative"; this.wrapper.style.width = "100%"; this.wrapper.style.minHeight = "100vh"; + this.theme = "light"; window.__gradio_loader__[this._id] = new Loader({ target: this.wrapper, @@ -223,7 +225,7 @@ function create_custom_element() { this.root.append(this.wrapper); if (window.__gradio_mode__ !== "website") { - handle_darkmode(this.wrapper); + this.theme = handle_darkmode(this.wrapper); } } @@ -268,6 +270,7 @@ function create_custom_element() { mount_app( { ...config, + theme: this.theme, control_page_title: control_page_title && control_page_title === "true" ? true : false }, @@ -305,8 +308,9 @@ async function unscoped_mount() { mount_app({ ...config, control_page_title: true }, false, target, 0); } -function handle_darkmode(target: HTMLDivElement) { +function handle_darkmode(target: HTMLDivElement): string { let url = new URL(window.location.toString()); + let theme = "light"; const color_mode: "light" | "dark" | "system" | null = url.searchParams.get( "__theme" @@ -314,39 +318,44 @@ function handle_darkmode(target: HTMLDivElement) { if (color_mode !== null) { if (color_mode === "dark") { - darkmode(target); + theme = darkmode(target); } else if (color_mode === "system") { - use_system_theme(target); + theme = use_system_theme(target); } // light is default, so we don't need to do anything else } else if (url.searchParams.get("__dark-theme") === "true") { - darkmode(target); + theme = darkmode(target); } else { - use_system_theme(target); + theme = use_system_theme(target); } + return theme; } -function use_system_theme(target: HTMLDivElement) { - update_scheme(); +function use_system_theme(target: HTMLDivElement): string { + const theme = update_scheme(); window ?.matchMedia("(prefers-color-scheme: dark)") ?.addEventListener("change", update_scheme); function update_scheme() { + let theme = "light"; const is_dark = window?.matchMedia?.("(prefers-color-scheme: dark)").matches ?? null; if (is_dark) { - darkmode(target); + theme = darkmode(target); } + return theme; } + return theme; } -function darkmode(target: HTMLDivElement) { +function darkmode(target: HTMLDivElement): string { target.classList.add("dark"); if (app_mode) { document.body.style.backgroundColor = "rgb(11, 15, 25)"; // bg-gray-950 for scrolling outside the body } + return "dark"; } // dev mode or if inside an iframe diff --git a/ui/packages/plot/package.json b/ui/packages/plot/package.json index 2c9d2fa155de7..524489a0edfdd 100644 --- a/ui/packages/plot/package.json +++ b/ui/packages/plot/package.json @@ -10,6 +10,7 @@ "dependencies": { "@gradio/icons": "workspace:^0.0.1", "@gradio/utils": "workspace:^0.0.1", + "@gradio/theme": "workspace:^0.0.1", "@rollup/plugin-json": "^5.0.2", "plotly.js-dist-min": "^2.10.1", "svelte-vega": "^1.2.0", diff --git a/ui/packages/plot/src/Plot.svelte b/ui/packages/plot/src/Plot.svelte index 774cc50648c7f..e6cc174e8775f 100644 --- a/ui/packages/plot/src/Plot.svelte +++ b/ui/packages/plot/src/Plot.svelte @@ -1,20 +1,55 @@ -