From d977b4acc682fb55111e5ee144d8dc078528a431 Mon Sep 17 00:00:00 2001 From: Matthew Huebert Date: Fri, 7 Jun 2024 10:29:49 +0200 Subject: [PATCH] rename gen-studio/gen.studio to genstudio --- README.md | 8 +-- notebooks/anywidget_jax_callback.js | 2 +- notebooks/anywidget_jax_callback.py | 2 +- notebooks/plot_examples.py | 50 ++++++++++++++++--- pyproject.toml | 2 +- src/gen/studio/__init__.py | 0 src/{gen => genstudio}/__init__.py | 0 src/{gen/studio => genstudio}/js_modules.py | 2 +- src/{gen/studio => genstudio}/plot.py | 6 +-- .../scripts/observable_plot_metadata.js | 0 .../scripts/observable_plot_metadata.json | 0 .../studio => genstudio}/scripts/package.json | 0 .../studio => genstudio}/scripts/yarn.lock | 0 src/{gen/studio => genstudio}/util.py | 2 +- src/{gen/studio => genstudio}/widget.js | 43 +++++++--------- src/{gen/studio => genstudio}/widget.py | 2 +- tests/studio/test_plot.py | 6 +-- tests/studio/test_util.py | 2 +- 18 files changed, 76 insertions(+), 51 deletions(-) delete mode 100644 src/gen/studio/__init__.py rename src/{gen => genstudio}/__init__.py (100%) rename src/{gen/studio => genstudio}/js_modules.py (98%) rename src/{gen/studio => genstudio}/plot.py (99%) rename src/{gen/studio => genstudio}/scripts/observable_plot_metadata.js (100%) rename src/{gen/studio => genstudio}/scripts/observable_plot_metadata.json (100%) rename src/{gen/studio => genstudio}/scripts/package.json (100%) rename src/{gen/studio => genstudio}/scripts/yarn.lock (100%) rename src/{gen/studio => genstudio}/util.py (91%) rename src/{gen/studio => genstudio}/widget.js (94%) rename src/{gen/studio => genstudio}/widget.py (92%) diff --git a/README.md b/README.md index 8027aa8d..9ff6cdf4 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ _Visualization tools for GenJAX._ ----- -`gen.studio.plot` provides a composable way to create interactive plots using [Observable Plot](https://observablehq.com/plot/) +`genstudio.plot` provides a composable way to create interactive plots using [Observable Plot](https://observablehq.com/plot/) and [AnyWidget](https://github.com/manzt/anywidget), built on the work of [pyobsplot](https://github.com/juba/pyobsplot). Key features: @@ -17,10 +17,10 @@ Runnable examples are in `notebooks/plot_examples.py`. See [Observable Plot](htt ## Installation -gen-studio is published to the same artifact registry as genjax, so you can follow [these instructions](https://github.com/probcomp/genjax?tab=readme-ov-file#quickstart) but use `gen-studio` for the package name. +genstudio is published to the same artifact registry as genjax, so you can follow [these instructions](https://github.com/probcomp/genjax?tab=readme-ov-file#quickstart) but use `genstudio` for the package name. ``` -gen-studio = {version = "v2024.05.23.085705", source = "gcp"} +genstudio = {version = "v2024.05.23.085705", source = "gcp"} ``` @@ -30,7 +30,7 @@ gen-studio = {version = "v2024.05.23.085705", source = "gcp"} Given the following setup: ```py -import gen.studio.plot as Plot +import genstudio.plot as Plot import numpy as np def normal_100(): diff --git a/notebooks/anywidget_jax_callback.js b/notebooks/anywidget_jax_callback.js index 292807df..549f5f37 100644 --- a/notebooks/anywidget_jax_callback.js +++ b/notebooks/anywidget_jax_callback.js @@ -8,7 +8,7 @@ const html = htm.bind(React.createElement) const useCustomMessages = (model) => { const [messages, setMessages] = useState([]) const handleMessage = (message, info) => { - if (message.kind !== 'gen.studio') { + if (message.kind !== 'genstudio') { return; } setMessages((ms) => [...ms, message.content]); diff --git a/notebooks/anywidget_jax_callback.py b/notebooks/anywidget_jax_callback.py index 1a76cc94..1d34c35a 100644 --- a/notebooks/anywidget_jax_callback.py +++ b/notebooks/anywidget_jax_callback.py @@ -27,7 +27,7 @@ def _receive_message(self, msg, buffers): def effectful_fn(x): print(f"Performing an effect with {x}") - w.send({'kind': 'gen.studio', 'content': x.tolist()}) + w.send({'kind': 'genstudio', 'content': x.tolist()}) return x @jax.jit diff --git a/notebooks/plot_examples.py b/notebooks/plot_examples.py index f25ff0c1..f5bb45b0 100644 --- a/notebooks/plot_examples.py +++ b/notebooks/plot_examples.py @@ -15,10 +15,10 @@ # --- # %% -# %load_ext autoreload -# %autoreload 2 +%load_ext autoreload +%autoreload 2 -import gen.studio.plot as Plot +import genstudio.plot as Plot import numpy as np import genjax as genjax from genjax import gen @@ -93,14 +93,12 @@ def normal_100(): # %% [markdown] # # A GenJAX example -# %% - -key = jrand.PRNGKey(314159) -# %% [markdown] # A regression distribution. # %% +key = jrand.PRNGKey(314159) + @gen def regression(x, coefficients, sigma): basis_value = jnp.array([1.0, x, x**2]) @@ -151,7 +149,25 @@ def full_model(xs): # Data from GenJAX often comes in the form of multi-dimensional (nested) lists. # To prepare data for plotting, we can describe these dimensions using `Plot.dimensions`. # %% + ys = traces.get_choices()["ys", ..., "y", "v"] +data = Plot.dimensions(ys, ["sample", "ys"], leaves="y") + +# => + +data.flatten() +# => [{'sample': 0, 'ys': 0, 'y': Array(0.11651635, dtype=float32)}, +# {'sample': 0, 'ys': 1, 'y': Array(-5.046837, dtype=float32)}, +# {'sample': 0, 'ys': 2, 'y': Array(-0.9120707, dtype=float32)}, +# {'sample': 0, 'ys': 3, 'y': Array(0.4919241, dtype=float32)}, +# {'sample': 0, 'ys': 4, 'y': Array(1.081743, dtype=float32)}, +# {'sample': 0, 'ys': 5, 'y': Array(1.6471565, dtype=float32)}, +# {'sample': 0, 'ys': 6, 'y': Array(3.6472352, dtype=float32)}, +# {'sample': 0, 'ys': 7, 'y': Array(5.080149, dtype=float32)}, +# {'sample': 0, 'ys': 8, 'y': Array(6.961242, dtype=float32)}, +# {'sample': 0, 'ys': 9, 'y': Array(10.374397, dtype=float32)} ...] + +#%% # %% [markdown] # @@ -205,8 +221,23 @@ def full_model(xs): # Using `get_in` we've given names to each level of nesting (and leaf values), which we can see in the metadata # of the Dimensioned object: # %% + data = Plot.get_in(bean_data, [{...: "day"}, {...: "bean"}, {"leaves": "height"}]) -data +# => + +data.flatten() +# => [{'day': 0, 'bean': 0, 'height': 0}, +# {'day': 0, 'bean': 1, 'height': 0}, +# {'day': 0, 'bean': 2, 'height': 0}, +# {'day': 0, 'bean': 3, 'height': 0}, +# {'day': 0, 'bean': 4, 'height': 0}, +# {'day': 0, 'bean': 5, 'height': 0}, +# {'day': 0, 'bean': 6, 'height': 0}, +# {'day': 0, 'bean': 7, 'height': 0}, +# {'day': 1, 'bean': 0, 'height': 0.17486922945122418}, +# {'day': 1, 'bean': 1, 'height': 0.8780341204172442}, +# {'day': 1, 'bean': 2, 'height': 0.6476780304516665}, +# {'day': 1, 'bean': 3, 'height': 0.9339147036777222}, ...] # %%[markdown] # Now that our dimensions and leaf have names, we can pass them as options to `Plot.dot`. @@ -224,3 +255,6 @@ def full_model(xs): ) + Plot.frame() ) + +#%% +Plot.View.domainTest(Plot.dimensions(bean_data, ["day", "bean"], leaves="height")) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 13278557..4ae91c2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "gen-studio" +name = "genstudio" version = "0.1.0" description = "" authors = ["Matthew Huebert "] diff --git a/src/gen/studio/__init__.py b/src/gen/studio/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/gen/__init__.py b/src/genstudio/__init__.py similarity index 100% rename from src/gen/__init__.py rename to src/genstudio/__init__.py diff --git a/src/gen/studio/js_modules.py b/src/genstudio/js_modules.py similarity index 98% rename from src/gen/studio/js_modules.py rename to src/genstudio/js_modules.py index 5b0d522f..3f6c832b 100644 --- a/src/gen/studio/js_modules.py +++ b/src/genstudio/js_modules.py @@ -1,4 +1,4 @@ -from gen.studio.widget import Widget +from genstudio.widget import Widget class JSCall(dict): """Represents a JavaScript function call.""" diff --git a/src/gen/studio/plot.py b/src/genstudio/plot.py similarity index 99% rename from src/gen/studio/plot.py rename to src/genstudio/plot.py index a490e02b..f3c94fed 100644 --- a/src/gen/studio/plot.py +++ b/src/genstudio/plot.py @@ -3,9 +3,9 @@ import json import re -import gen.studio.util as util -from gen.studio.js_modules import JSRef, Hiccup, js, js_call -from gen.studio.widget import Widget +import genstudio.util as util +from genstudio.js_modules import JSRef, Hiccup, js, js_call +from genstudio.widget import Widget # This module provides a composable way to create interactive plots using Observable Plot # and AnyWidget, built on the work of pyobsplot. diff --git a/src/gen/studio/scripts/observable_plot_metadata.js b/src/genstudio/scripts/observable_plot_metadata.js similarity index 100% rename from src/gen/studio/scripts/observable_plot_metadata.js rename to src/genstudio/scripts/observable_plot_metadata.js diff --git a/src/gen/studio/scripts/observable_plot_metadata.json b/src/genstudio/scripts/observable_plot_metadata.json similarity index 100% rename from src/gen/studio/scripts/observable_plot_metadata.json rename to src/genstudio/scripts/observable_plot_metadata.json diff --git a/src/gen/studio/scripts/package.json b/src/genstudio/scripts/package.json similarity index 100% rename from src/gen/studio/scripts/package.json rename to src/genstudio/scripts/package.json diff --git a/src/gen/studio/scripts/yarn.lock b/src/genstudio/scripts/yarn.lock similarity index 100% rename from src/gen/studio/scripts/yarn.lock rename to src/genstudio/scripts/yarn.lock diff --git a/src/gen/studio/util.py b/src/genstudio/util.py similarity index 91% rename from src/gen/studio/util.py rename to src/genstudio/util.py index 3f5e17e5..34347e18 100644 --- a/src/gen/studio/util.py +++ b/src/genstudio/util.py @@ -36,6 +36,6 @@ def __exit__(self, *args): print(("%s : " + self.fmt + " seconds") % (self.msg, t)) self.time = t -PARENT_PATH = pathlib.Path(importlib.util.find_spec("gen.studio.util").origin).parent +PARENT_PATH = pathlib.Path(importlib.util.find_spec("genstudio.util").origin).parent # %% diff --git a/src/gen/studio/widget.js b/src/genstudio/widget.js similarity index 94% rename from src/gen/studio/widget.js rename to src/genstudio/widget.js index 8996a710..9e8b7374 100644 --- a/src/gen/studio/widget.js +++ b/src/genstudio/widget.js @@ -206,6 +206,20 @@ const repeat = (data) => { return (_, i) => data[i % data.length]; } +const getScales = (spec) => { + let plot = Plot.plot(spec); + let scales = {}; + for (const scaleName of ["x", "y", "r", "color", "opacity", "length", "symbol"]) { + scales[scaleName] = plot.scale(scaleName); + } + return scales; +} + +const domainTest = (data) => { + data = flatten(data.value, data.dimensions) + console.log(getScales({marks: [Plot.dot(data, {x: 'bean', y: 'height'})]})) +} + const scope = { d3, Plot, React, ReactDOM, @@ -216,7 +230,8 @@ const scope = { repeat, el, AutoGrid, - flatten + flatten, + domainTest }, } @@ -296,31 +311,7 @@ function PlotView({ spec, splitState }) { } function PlotWrapper({ spec }) { - const dimensionInfo = useMemo(() => { - return spec.marks.flatMap(mark => mark.dimensions).reduce((acc, dimension) => { - if (!dimension) { - acc - } else if (acc[dimension.key]) { - acc[dimension.key] = { ...acc[dimension.key], ...dimension }; - } else { - acc[dimension.key] = dimension; - } - return acc; - }, {}); - }, []); - - const [splitState, setSplitState] = useState( - Object.fromEntries(Object.entries(dimensionInfo).map(([k, d]) => [k, d.initial || 0])) - ); - return html` -
- <${PlotView} spec=${spec} splitState=${splitState}>
- <${SlidersView} - info=${dimensionInfo} - splitState=${splitState} - setSplitState=${setSplitState}/> - - ` + return html`<${PlotView} spec=${spec}>` } function normalizeDomains(PlotSpecs) { diff --git a/src/gen/studio/widget.py b/src/genstudio/widget.py similarity index 92% rename from src/gen/studio/widget.py rename to src/genstudio/widget.py index 3b4d0948..f52cd0df 100644 --- a/src/gen/studio/widget.py +++ b/src/genstudio/widget.py @@ -8,7 +8,7 @@ #%% # necessary for VS Code IPython interactive contexts -PARENT_PATH = Path(importlib.util.find_spec("gen.studio.widget").origin).parent +PARENT_PATH = Path(importlib.util.find_spec("genstudio.widget").origin).parent def to_json(data, _widget): def default(obj): diff --git a/tests/studio/test_plot.py b/tests/studio/test_plot.py index da23d909..148bb039 100644 --- a/tests/studio/test_plot.py +++ b/tests/studio/test_plot.py @@ -1,8 +1,8 @@ # %% -import gen.studio.plot as Plot -import gen.studio.util as util -from gen.studio.widget import Widget +import genstudio.plot as Plot +import genstudio.util as util +from genstudio.widget import Widget # Always reload (for dev) import importlib importlib.reload(Plot) diff --git a/tests/studio/test_util.py b/tests/studio/test_util.py index e7a41673..1c7337b7 100644 --- a/tests/studio/test_util.py +++ b/tests/studio/test_util.py @@ -1,6 +1,6 @@ #%% -from gen.studio.plot import JSRef, d3, Math +from genstudio.plot import JSRef, d3, Math def test_jswrapper_init(): wrapper = JSRef("TestModule", "test_method")