Skip to content

Commit

Permalink
rename gen-studio/gen.studio to genstudio
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuebert committed Jun 7, 2024
1 parent 0dadb2d commit d977b4a
Show file tree
Hide file tree
Showing 18 changed files with 76 additions and 51 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ _Visualization tools for GenJAX._

-----

`gen.studio.plot` provides a composable way to create interactive plots using [Observable Plot](https://observablehq.com/plot/)
`genstudio.plot` provides a composable way to create interactive plots using [Observable Plot](https://observablehq.com/plot/)
and [AnyWidget](https://github.com/manzt/anywidget), built on the work of [pyobsplot](https://github.com/juba/pyobsplot).

Key features:
Expand All @@ -17,10 +17,10 @@ Runnable examples are in `notebooks/plot_examples.py`. See [Observable Plot](htt

## Installation

gen-studio is published to the same artifact registry as genjax, so you can follow [these instructions](https://github.com/probcomp/genjax?tab=readme-ov-file#quickstart) but use `gen-studio` for the package name.
genstudio is published to the same artifact registry as genjax, so you can follow [these instructions](https://github.com/probcomp/genjax?tab=readme-ov-file#quickstart) but use `genstudio` for the package name.

```
gen-studio = {version = "v2024.05.23.085705", source = "gcp"}
genstudio = {version = "v2024.05.23.085705", source = "gcp"}
```


Expand All @@ -30,7 +30,7 @@ gen-studio = {version = "v2024.05.23.085705", source = "gcp"}
Given the following setup:

```py
import gen.studio.plot as Plot
import genstudio.plot as Plot
import numpy as np

def normal_100():
Expand Down
2 changes: 1 addition & 1 deletion notebooks/anywidget_jax_callback.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const html = htm.bind(React.createElement)
const useCustomMessages = (model) => {
const [messages, setMessages] = useState([])
const handleMessage = (message, info) => {
if (message.kind !== 'gen.studio') {
if (message.kind !== 'genstudio') {
return;
}
setMessages((ms) => [...ms, message.content]);
Expand Down
2 changes: 1 addition & 1 deletion notebooks/anywidget_jax_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _receive_message(self, msg, buffers):

def effectful_fn(x):
print(f"Performing an effect with {x}")
w.send({'kind': 'gen.studio', 'content': x.tolist()})
w.send({'kind': 'genstudio', 'content': x.tolist()})
return x

@jax.jit
Expand Down
50 changes: 42 additions & 8 deletions notebooks/plot_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
# ---

# %%
# %load_ext autoreload
# %autoreload 2
%load_ext autoreload
%autoreload 2

import gen.studio.plot as Plot
import genstudio.plot as Plot
import numpy as np
import genjax as genjax
from genjax import gen
Expand Down Expand Up @@ -93,14 +93,12 @@ def normal_100():
# %% [markdown]
#
# A GenJAX example
# %%

key = jrand.PRNGKey(314159)

# %% [markdown]
# A regression distribution.
# %%

key = jrand.PRNGKey(314159)

@gen
def regression(x, coefficients, sigma):
basis_value = jnp.array([1.0, x, x**2])
Expand Down Expand Up @@ -151,7 +149,25 @@ def full_model(xs):
# Data from GenJAX often comes in the form of multi-dimensional (nested) lists.
# To prepare data for plotting, we can describe these dimensions using `Plot.dimensions`.
# %%

ys = traces.get_choices()["ys", ..., "y", "v"]
data = Plot.dimensions(ys, ["sample", "ys"], leaves="y")

# => <Dimensioned shape=(9, 20), names=['sample', 'ys', 'y']>

data.flatten()
# => [{'sample': 0, 'ys': 0, 'y': Array(0.11651635, dtype=float32)},
# {'sample': 0, 'ys': 1, 'y': Array(-5.046837, dtype=float32)},
# {'sample': 0, 'ys': 2, 'y': Array(-0.9120707, dtype=float32)},
# {'sample': 0, 'ys': 3, 'y': Array(0.4919241, dtype=float32)},
# {'sample': 0, 'ys': 4, 'y': Array(1.081743, dtype=float32)},
# {'sample': 0, 'ys': 5, 'y': Array(1.6471565, dtype=float32)},
# {'sample': 0, 'ys': 6, 'y': Array(3.6472352, dtype=float32)},
# {'sample': 0, 'ys': 7, 'y': Array(5.080149, dtype=float32)},
# {'sample': 0, 'ys': 8, 'y': Array(6.961242, dtype=float32)},
# {'sample': 0, 'ys': 9, 'y': Array(10.374397, dtype=float32)} ...]

#%%

# %% [markdown]
#
Expand Down Expand Up @@ -205,8 +221,23 @@ def full_model(xs):
# Using `get_in` we've given names to each level of nesting (and leaf values), which we can see in the metadata
# of the Dimensioned object:
# %%

data = Plot.get_in(bean_data, [{...: "day"}, {...: "bean"}, {"leaves": "height"}])
data
# => <Dimensioned shape=(21, 8), names=['day', 'bean', 'height']>

data.flatten()
# => [{'day': 0, 'bean': 0, 'height': 0},
# {'day': 0, 'bean': 1, 'height': 0},
# {'day': 0, 'bean': 2, 'height': 0},
# {'day': 0, 'bean': 3, 'height': 0},
# {'day': 0, 'bean': 4, 'height': 0},
# {'day': 0, 'bean': 5, 'height': 0},
# {'day': 0, 'bean': 6, 'height': 0},
# {'day': 0, 'bean': 7, 'height': 0},
# {'day': 1, 'bean': 0, 'height': 0.17486922945122418},
# {'day': 1, 'bean': 1, 'height': 0.8780341204172442},
# {'day': 1, 'bean': 2, 'height': 0.6476780304516665},
# {'day': 1, 'bean': 3, 'height': 0.9339147036777222}, ...]

# %%[markdown]
# Now that our dimensions and leaf have names, we can pass them as options to `Plot.dot`.
Expand All @@ -224,3 +255,6 @@ def full_model(xs):
)
+ Plot.frame()
)

#%%
Plot.View.domainTest(Plot.dimensions(bean_data, ["day", "bean"], leaves="height"))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.poetry]
name = "gen-studio"
name = "genstudio"
version = "0.1.0"
description = ""
authors = ["Matthew Huebert <me@matt.is>"]
Expand Down
Empty file removed src/gen/studio/__init__.py
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from gen.studio.widget import Widget
from genstudio.widget import Widget

class JSCall(dict):
"""Represents a JavaScript function call."""
Expand Down
6 changes: 3 additions & 3 deletions src/gen/studio/plot.py → src/genstudio/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import json
import re

import gen.studio.util as util
from gen.studio.js_modules import JSRef, Hiccup, js, js_call
from gen.studio.widget import Widget
import genstudio.util as util
from genstudio.js_modules import JSRef, Hiccup, js, js_call
from genstudio.widget import Widget

# This module provides a composable way to create interactive plots using Observable Plot
# and AnyWidget, built on the work of pyobsplot.
Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/gen/studio/util.py → src/genstudio/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def __exit__(self, *args):
print(("%s : " + self.fmt + " seconds") % (self.msg, t))
self.time = t

PARENT_PATH = pathlib.Path(importlib.util.find_spec("gen.studio.util").origin).parent
PARENT_PATH = pathlib.Path(importlib.util.find_spec("genstudio.util").origin).parent

# %%
43 changes: 17 additions & 26 deletions src/gen/studio/widget.js → src/genstudio/widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ const repeat = (data) => {
return (_, i) => data[i % data.length];
}

const getScales = (spec) => {
let plot = Plot.plot(spec);
let scales = {};
for (const scaleName of ["x", "y", "r", "color", "opacity", "length", "symbol"]) {
scales[scaleName] = plot.scale(scaleName);
}
return scales;
}

const domainTest = (data) => {
data = flatten(data.value, data.dimensions)
console.log(getScales({marks: [Plot.dot(data, {x: 'bean', y: 'height'})]}))
}

const scope = {
d3,
Plot, React, ReactDOM,
Expand All @@ -216,7 +230,8 @@ const scope = {
repeat,
el,
AutoGrid,
flatten
flatten,
domainTest
},

}
Expand Down Expand Up @@ -296,31 +311,7 @@ function PlotView({ spec, splitState }) {
}

function PlotWrapper({ spec }) {
const dimensionInfo = useMemo(() => {
return spec.marks.flatMap(mark => mark.dimensions).reduce((acc, dimension) => {
if (!dimension) {
acc
} else if (acc[dimension.key]) {
acc[dimension.key] = { ...acc[dimension.key], ...dimension };
} else {
acc[dimension.key] = dimension;
}
return acc;
}, {});
}, []);

const [splitState, setSplitState] = useState(
Object.fromEntries(Object.entries(dimensionInfo).map(([k, d]) => [k, d.initial || 0]))
);
return html`
<div>
<${PlotView} spec=${spec} splitState=${splitState}></div>
<${SlidersView}
info=${dimensionInfo}
splitState=${splitState}
setSplitState=${setSplitState}/>
</div>
`
return html`<${PlotView} spec=${spec}></div>`
}

function normalizeDomains(PlotSpecs) {
Expand Down
2 changes: 1 addition & 1 deletion src/gen/studio/widget.py → src/genstudio/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#%%
# necessary for VS Code IPython interactive contexts
PARENT_PATH = Path(importlib.util.find_spec("gen.studio.widget").origin).parent
PARENT_PATH = Path(importlib.util.find_spec("genstudio.widget").origin).parent

def to_json(data, _widget):
def default(obj):
Expand Down
6 changes: 3 additions & 3 deletions tests/studio/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# %%

import gen.studio.plot as Plot
import gen.studio.util as util
from gen.studio.widget import Widget
import genstudio.plot as Plot
import genstudio.util as util
from genstudio.widget import Widget
# Always reload (for dev)
import importlib
importlib.reload(Plot)
Expand Down
2 changes: 1 addition & 1 deletion tests/studio/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#%%

from gen.studio.plot import JSRef, d3, Math
from genstudio.plot import JSRef, d3, Math

def test_jswrapper_init():
wrapper = JSRef("TestModule", "test_method")
Expand Down

0 comments on commit d977b4a

Please sign in to comment.