Skip to content

Commit

Permalink
WIP: Plot.get_choice, style comparisons
Browse files Browse the repository at this point in the history
Doing a syntax review/comparison after implementing Plot.get_choice/Plot.get_in and facet-based gridding.

It's starting to look like idiomatic Observable.Plot & custom "dimensional" choicemap accessors may not result in the nicest api surface. List comprehensions + custom views are hard to beat.
  • Loading branch information
mhuebert committed Jun 5, 2024
1 parent 17e02ae commit 79234d9
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 108 deletions.
54 changes: 43 additions & 11 deletions notebooks/plot_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
%autoreload 2

import gen.studio.plot as Plot
from gen.studio.js_modules import Hiccup
import numpy as np
import genjax as genjax
from genjax import gen
Expand Down Expand Up @@ -119,27 +118,60 @@ def full_model(xs):
key, *sub_keys = jrand.split(key, 10)
traces = jax.vmap(lambda k: full_model.simulate(k, (data,)))(jnp.array(sub_keys))

# Plot.get_choice(traces, "ys", {...: 'sample'}, "y", "v", {'as': 'y'})
Plot.dot(traces.get_choices()["ys", ..., "y", "v"],
dimensions=["sample", "ys", {'as': 'y'}],
grid=["sample", {"columns": 3}],
# Compare the following two plots.

# using Plot.get_choice & Observable faceting:

Plot.dot(Plot.get_choice(traces, ["ys", {...: 'sample'}, "y", "v", {...: "ys"}, {'leaves': 'y'}]),
facetGrid="sample",
x=Plot.repeat(data),
y='y') + {'height': 600} + Plot.frame()

# using list comprehensions & hiccup composition:

Plot.small_multiples([
Plot.dot({"x": data, "y": ys}) + Plot.frame() for ys in traces.get_sample()["ys", ..., "y", "v"]
])

#%%

ch = traces.get_choices()
Plot.get_choice(traces, ["ys", {...: 'sample'}, "y", "v", {...: "ys"}, {'leaves': 'y'}]).flatten()
# ch("ys")(...)("y")("v")


# %% [markdown]

#### Flattening dimensions with `get_choice` and `get_in`
#### Handling multi-dimensional data

# When working with Gen
# When working with JAX or Numpy we often receive nested arrays.
# `Plot.get_in` and `Plot.get_choice` retrieve values from nested structures,
# while specifying names for the dimensions and leaf nodes encountered.

import random
bean_data = [[0 for _ in range(20)]]
for day in range(1, 21):
bean_data.append([height + random.uniform(0.5, 5) for height in bean_data[-1]])
bean_data

Plot.dot(Plot.get_in(bean_data, {...: 'day'}, {...: 'bean'}, {'as': 'height'}),
# Bean data:

# using Plot.get_in & Observable faceting:

Plot.dot(Plot.get_in(bean_data, {...: 'day'}, {...: 'bean'}, {'leaves': 'height'}),
{'x': 'day',
'y': 'height',
'grid': ["bean", {'columns': 3}],
'r': 2,
}) + Plot.frame() + {'height': 800}
'facetGrid': "bean"
}) + Plot.frame() + {'height': 800}

# using list comprehensions & hiccup composition:

Plot.small_multiples([
Plot.dot({'x': list(range(1, 21)),
'y': [bean_data[day][bean] for day in range(1, 21)]}) for bean in range(20)
])

#### Grid views

#%%

96 changes: 56 additions & 40 deletions src/gen/studio/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,25 @@ def repeat(data):
class Dimensioned:
def __init__(self, value, path):
self.value = value
self.dimensions = [rename_key(segment, ..., 'key') for segment in path if isinstance(segment, dict)]
self.dimensions = [rename_key(segment, ..., 'key') for segment in path if isinstance(segment, dict)]
def shape(self):
shape = ()
current_value = self.value
for dimension in self.dimensions:
if 'leaves' not in dimension:
shape += (len(current_value),)
current_value = current_value[0]
return shape
def names(self):
return [dimension.get('key', dimension.get('leaves')) for dimension in self.dimensions]
def __repr__(self):
return f"<Dimensioned shape={self.shape()}, names={self.names()}>"

def flatten(self):
# flattens the data in python, rather than js.
# currently we are not using/recommending this
# but it may be useful later or for debugging.
leaf = self.dimensions[-1]['as'] if isinstance(self.dimensions[-1], dict) and 'as' in self.dimensions[-1] else None
leaf = self.dimensions[-1]['leaves'] if isinstance(self.dimensions[-1], dict) and 'leaves' in self.dimensions[-1] else None
dimensions = self.dimensions[:-1] if leaf else self.dimensions

def _flatten(value, dims, prefix=None):
Expand All @@ -59,63 +72,66 @@ def _flatten(value, dims, prefix=None):
results.extend(_flatten(v, dims[1:], new_prefix))
return results
return _flatten(self.value, dimensions)

def to_json(self):
return {'value': self.value, 'dimensions': self.dimensions}

def rename_key(d, prev_k, new_k):
return {k if k != prev_k else new_k: v for k, v in d.items()}

# Probably will be deprecated - in favour of separate flattening functions.
def get_choice(tr, *path):
"""
Retrieve a choice value from a trace using a list of keys.
Dimension instances are treated like ... but also return dimension info.
"""
choices = tr.get_choices()
dimensions = [rename_key(segment, ..., 'key') for segment in path if isinstance(segment, dict) and ... in segment]
path = [... if (isinstance(segment, dict) and ... in segment) else segment for segment in path]
lastSegment = path and path[-1]
leaf = lastSegment if isinstance(lastSegment, dict) and 'as' in lastSegment else None
path = path[:-1] if leaf else path
if isinstance(path[-1], set):
# If the last entry in address is a set, we want to retrieve multiple values
# from the choices and return them as a dictionary with keys from the set.
keys = path[-1]
value = choices[*path[:-1]]
value = {key: value[key] for key in keys}
else:
# If the last entry is not a set, proceed as normal
value = choices[*path]
def get_choice(ch, path):

ch = ch.get_sample() if getattr(ch, 'get_sample', None) else ch

if dimensions:
return Dimensioned(dimensions, leaf=leaf, value=value)
def _get(value, path):
if not path:
return value
segment = path[0]
if not isinstance(segment, dict):
return _get(value(segment), path[1:])
elif ... in segment:
v = value.get_value()
if hasattr(value, 'get_submap') and v is None:
v = value.get_submap(...)
return _get(v, path[1:])
elif 'leaves' in segment:
return value
else:
raise TypeError(f"Invalid path segment, expected ... or 'leaves' key, got {segment}")

value = _get(ch, path)
value = value.get_value() if hasattr(value, 'get_value') else value

if any(isinstance(elem, dict) for elem in path):
return Dimensioned(value, path)
else:
return value


# useful for simulating multi-dimensional GenJAX choicemap lookups.
def get_in(data, *path, toplevel=True):
value = data
for i, part in enumerate(path):
if isinstance(part, dict) and ... in part:
part = ...
if part == ...:
def _get(value, path):
if not path:
return value
segment = path[0]
if not isinstance(segment, dict):
return _get(value[segment], path[1:])
elif ... in segment:
if isinstance(value, list):
value = [get_in(sub_result, *path[i+1:], toplevel=False) for sub_result in value]
break
p = path[1:]
return [get_in(v, *p, toplevel=False) for v in value]
else:
raise TypeError(f"Expected list at path index {i}, got {type(value).__name__}")
elif isinstance(part, dict) and 'as' in part:
break
elif 'leaves' in segment:
return value
else:
value = value[part]

raise TypeError(f"Invalid path segment, expected ... or 'leaves' key, got {segment}")

value = _get(data, path)

if toplevel and any(isinstance(elem, dict) for elem in path):
return Dimensioned(value, path)
else:
return value


# Test case to verify traversal of more than one dimension
def test_get_in():
data = {
Expand All @@ -132,7 +148,7 @@ def test_get_in():
assert len(result.dimensions) == 2, f"Expected 2 dimensions, got {len(result.dimensions)}"
assert [d['key'] for d in result.dimensions] == ['first', 'second'], f"Expected dimension keys to be ['first', 'second'], got {[d['key'] for d in result.dimensions]}"

flattened = get_in(data, 'a', {...: 'first'}, 'b', {...: 'second'}, 'c', {'as': 'c'}).flatten()
flattened = get_in(data, 'a', {...: 'first'}, 'b', {...: 'second'}, 'c', {'leaves': 'c'}).flatten()
assert flattened == [
{'first': 0, 'second': 0, 'c': 1},
{'first': 0, 'second': 1, 'c': 2},
Expand Down
92 changes: 35 additions & 57 deletions src/gen/studio/widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ const el = (tag, props, ...children) => {
};

const flat = (data, dimensions) => {
let leaf;
if (typeof dimensions[dimensions.length - 1] === 'object' && 'as' in dimensions[dimensions.length - 1]) {
leaf = dimensions[dimensions.length - 1].as;
let leaves;
if (typeof dimensions[dimensions.length - 1] === 'object' && 'leaves' in dimensions[dimensions.length - 1]) {
leaves = dimensions[dimensions.length - 1]['leaves'];
dimensions = dimensions.slice(0, -1);
}

const _flat = (data, dim, prefix = null) => {
if (!dim.length) {
data = leaf ? { [leaf]: data } : data
data = leaves ? { [leaves]: data } : data
return prefix ? [{ ...prefix, ...data }] : [data];
}

Expand All @@ -104,33 +104,54 @@ class MarkSpec {
throw new Error(`Plot function "${name}" not found.`);
}

// unwrap data which includes dimension metadata
// Below is where we add functionality to Observable.Plot by preprocessing
// the data & options that are passed in.

// handle dimensional data passed in the 1st position
if (data.dimensions) {
options.dimensions = data.dimensions
data = data.value
}

// unwrap dimensional data
// flatten dimensional data
if (options.dimensions) {
options.dimensions = options.dimensions.map(dim => typeof dim === 'string' ? { 'key': dim } : dim);
data = flat(data, options.dimensions)
}

// handle columnar data in the 1st position
if (!Array.isArray(data) && !('length' in data)) {
let length = null
for (let [key, value] of Object.entries(data)) {
options[key] = value;
if (Array.isArray(value)) {
length = value.length
}
if (length === null) {
throw new Error("Invalid columnar data: at least one column must be an array.");
}
data = {length: value.length}
}

}

// handle facetWrap option (grid)
// see https://github.com/observablehq/plot/pull/892/files
if (options.grid) {
const [key, gridOpts] = options.grid
const { columns } = gridOpts
if (options.facetGrid) {
const facetGrid = (typeof options.facetGrid === 'string') ? [options.facetGrid, {}] : options.facetGrid
const [key, gridOpts] = facetGrid
const keys = Array.from(d3.union(data.map((d) => d[key])))
const index = new Map(keys.map((key, i) => [key, i]))
const columns = gridOpts.columns || Math.floor(Math.sqrt(keys.length))

const fx = (key) => index.get(key) % columns
const fy = (key) => Math.floor(index.get(key) / columns)
options.fx = (d) => fx(d[key])
options.fy = (d) => fy(d[key])
this.plotOptions = { fx: { axis: null }, fy: { axis: null } }
this.extraMarks.push(Plot.text(keys, {
fx, fy,
frameAnchor: "top",
frameAnchor: "top",
dy: 4
}))
}
Expand All @@ -139,60 +160,16 @@ class MarkSpec {
this.data = data;
this.options = options;
this.plotOptions = this.plotOptions || {};
this.format = Array.isArray(data) || 'length' in data ? 'array' : 'columnar';
}
}

function readMark(mark, dimensionState) {

if (!(mark instanceof MarkSpec)) {
return mark;
}
let { fn, data, options, format } = mark
let out
switch (format) {
case 'columnar':
// format columnar data for Observable.Plot;
// values go into the options map.
const formattedData = {};
for (let [key, value] of Object.entries(data)) {
if (value.dimensions) {
const dimensions = value.dimensions;
formattedData[key] = value.value;
for (const dimension of dimensions) {
if (!dimensionState.hasOwnProperty(dimension.key)) {
throw new Error(`Dimension state for ${dimension.key} is missing.`);
}
const i = dimensionState[dimension.key]
formattedData[key] = formattedData[key][i]
}
} else {
formattedData[key] = value;
}
}
out = fn({ length: Object.values(formattedData)[0].length }, { ...formattedData, ...options })
default:
out = fn(data, options)
}
return [out, ...mark.extraMarks]
}

function calculateDomain(values) {
let min = Infinity;
let max = -Infinity;

function findMinMax(value) {
if (Array.isArray(value)) {
for (const subValue of value) {
findMinMax(subValue); // Recursively find min and max in sub-arrays
}
} else {
if (value < min) min = value;
if (value > max) max = value;
}
}

findMinMax(values); // Start the recursive search with the initial array
return [min, max];
let { fn, data, options} = mark
return [fn(data, options), ...mark.extraMarks]
}

const repeat = (data) => {
Expand Down Expand Up @@ -275,6 +252,7 @@ function PlotView({ spec, splitState }) {
if (parent) {
const marks = spec.marks.flatMap((m) => readMark(m, splitState))
const plotOptions = spec.marks.reduce((acc, mark) => ({ ...acc, ...mark.plotOptions }), {});
console.log({ ...spec, ...plotOptions, marks: marks })
const plot = Plot.plot({ ...spec, ...plotOptions, marks: marks })
parent.appendChild(plot)
return () => parent.removeChild(plot)
Expand Down

0 comments on commit 79234d9

Please sign in to comment.