Skip to content

Commit

Permalink
Visualizations post ChiRho (#469)
Browse files Browse the repository at this point in the history
* Removing things from old locations.

* Moving visual components to new structure

* Moving tests and implementing in pytest (first draft)

* Parameterizing histogram test.

* Parameterizing schema-based tests.

* Extending results processing to label output timepoints (optionally).  Visualization tests depend on new version.

* check images and svg trouble-shooting

* trying to set seed on random data gen

* need to install additional files

* add packages to install requires

* correct modified schema saving location

* more informative error

* more informative error messages

* remove difflib

* try new png test

* remove spring graph

* tried recreating the environment

* test with just 100 dim for all images svg

* reference to false

* Only appending param/state/observable tag if not already present

* Trajectories test running from simulation (not stored data).

* Removing example trajectory files.

* add reference images created in docker

* don't create reference files

* Removing dependency on stored data-cube.

Histogram picks up label from series  (if present)

* gray scale

* Jensen-Shannon threshold for PNG comparison

* add doc strings

* test mismatch with test_plots, add white background when testing grayscale

* add JS score to message

* fix typo

* test with 100 bins

* threshold 0.02 for Jensen-Shannon

* add map heatmap isocontour

* add data

* xmldiff-based check for approximate SVG output changes

* Cleaning up reference image.  Refining SVG and PNG comparisons.

* Reverting to append '_state' after observables (but won't do '_state_state' anymore)

* Removing debug helper notebook.

* Removing geomap+heatmap schema not ready for use (yet)

* Fixing list-type (for 3.8 compatibility)

* Enforcing sort-order for better result color-mapping stability

* Addressing mypy linting errors.

* More mypy fixes

* Fixing import and type complaint.

* Type checking passes for histogram.

* Fixing error where multi-line titles would sometimes break the schema.

* Adding 'visual_options' back into the interface outputs.

* Fixing dict-type for 3.8

* Fixing import orders

* Reference image generating utility.

* Improving stability of trajectories schema by sorting input values.

Visual refinements for reference points (larger marker, dotted connector).

Multi-point example trace added.

* Import order fixes (two related to out-dated line-wrapping rules).

* Fixing formatting to black-defaults (as opposed to old VSCode black defaults)

* Changing trajectories example to use output of newer vl-convert-python.

* flake8 compliance

* Fixing typo of output option and making cases consistent

* Ignoring a test-related output directory

---------

Co-authored-by: Oostrom, Marjolein T <marjolein.oostrom@pnnl.gov>
  • Loading branch information
JosephCottam and marjoleinpnnl authored Feb 7, 2024
1 parent f18fafc commit d89690f
Show file tree
Hide file tree
Showing 37 changed files with 4,173 additions and 5 deletions.
35 changes: 30 additions & 5 deletions pyciemss/integration_utils/result_processing.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,45 @@
from typing import Any, Dict
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import pandas as pd
import torch

from pyciemss.visuals import plots


def prepare_interchange_dictionary(
    samples: Dict[str, torch.Tensor],
    time_unit: Optional[str] = None,
    timepoints: Optional[Iterable[float]] = None,
    visual_options: Union[None, bool, Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Bundle raw samples with their tabular form and, optionally, a plot schema.

    Args:
        samples: Mapping of sample-site name to tensor; returned untouched
            under the ``"unprocessed_result"`` key.
        time_unit: Forwarded to ``convert_to_output_format``.
        timepoints: Forwarded to ``convert_to_output_format``.
        visual_options: Falsy values (``None``, ``False``, and also an empty
            dict) skip schema generation. ``True`` requests a schema with
            default options. A non-empty dict is expanded as keyword
            arguments to ``plots.trajectories``.

    Returns:
        A dictionary with ``"data"`` (the converted samples),
        ``"unprocessed_result"`` (the ``samples`` argument itself) and, only
        when visuals were requested, ``"schema"``.
    """
    # Convert exactly once, with the time options applied; an earlier,
    # option-less conversion would only be discarded.
    processed_samples = convert_to_output_format(
        samples, time_unit=time_unit, timepoints=timepoints
    )

    result: Dict[str, Any] = {
        "data": processed_samples,
        "unprocessed_result": samples,
    }

    if visual_options:
        # `True` means "plot with defaults"; anything else truthy is taken
        # to be the keyword arguments for the trajectories plot.
        options: Dict[str, Any] = {} if visual_options is True else visual_options
        result["schema"] = plots.trajectories(processed_samples, **options)

    return result


def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame:
def convert_to_output_format(
samples: Dict[str, torch.Tensor],
*,
time_unit: Optional[str] = None,
timepoints: Optional[Iterable[float]] = None,
) -> pd.DataFrame:
"""
Convert the samples from the Pyro model to a DataFrame in the TA4 requested format.
"""

if time_unit is not None and timepoints is None:
raise ValueError("`timeponts` must be supplied when a `time_unit` is supplied")

pyciemss_results: Dict[str, Dict[str, torch.Tensor]] = {
"parameters": {},
"states": {},
Expand All @@ -29,12 +49,12 @@ def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame:
if sample.ndim == 1:
# Any 1D array is a sample from the distribution over parameters.
# Any 2D array is a sample from the distribution over states, unless it's a model weight.
name = name + "_param"
name = name + "_param" if not name.endswith("_param") else name
pyciemss_results["parameters"][name] = (
sample.data.detach().cpu().numpy().astype(np.float64)
)
else:
name = name + "_state"
name = name + "_state" if not (name.endswith("_state")) else name
pyciemss_results["states"][name] = (
sample.data.detach().cpu().numpy().astype(np.float64)
)
Expand Down Expand Up @@ -64,4 +84,9 @@ def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame:
}

result = pd.DataFrame(output)
if time_unit is not None and timepoints is not None:
timepoints = [*timepoints]
all_timepoints = result["timepoint_id"].map(lambda v: timepoints[v])
result = result.assign(**{f"timepoint_{time_unit}": all_timepoints})

return result
Empty file added pyciemss/visuals/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions pyciemss/visuals/barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Optional

import matplotlib.tri as tri
import numpy as np
import pandas as pd
import torch
from pyro.distributions import Dirichlet

from . import vega


def triangle_weights(samples, concentration=20, subdiv=7):
    """Rasterize Dirichlet densities over an equilateral triangle onto a grid.

    A triangle with corners (0, 0), (1, 0) and (0.5, sqrt(0.75)) is uniformly
    refined ``subdiv`` times. Each mesh vertex is converted to barycentric
    coordinates and scored with the density of
    ``Dirichlet(samples * concentration)``. Scores are scaled by their maximum
    along dim 0, summed along dim 1, and normalized to sum to one.

    Args:
        samples: Values multiplied by ``concentration`` to form the Dirichlet
            concentration parameters.
            NOTE(review): presumably a 2D tensor of shape (n_samples, 3);
            confirm against callers.
        concentration: Multiplier applied to ``samples``.
        subdiv: Refinement depth passed to
            ``UniformTriRefiner.refine_triangulation``.

    Returns:
        A dict with ``"width"`` and ``"height"`` (the number of distinct x and
        y positions kept) and ``"values"`` (one value per kept grid position,
        ordered by descending y, then ascending x). Grid positions that are
        not mesh vertices carry the value 0.
    """
    # Adapted from https://blog.bogatron.net/blog/2014/02/02/visualizing-dirichlet-distributions/
    # TODO: This method works...but it quite the monstrosity! Look into ways to simplify...

    AREA = 0.5 * 1 * 0.75**0.5

    def _tri_area(xy, pair):
        """Area of the triangle formed by point ``xy`` and the two points in ``pair``."""
        # Explicit 2D determinant: np.cross on 2-element vectors is
        # deprecated as of NumPy 2.0.
        (ax, ay), (bx, by) = pair - xy
        return 0.5 * abs(ax * by - ay * bx)

    def _xy2bc(xy, tol=1.0e-4):
        """Converts 2D Cartesian coordinates to barycentric."""
        coords = np.array([_tri_area(xy, p) for p in pairs]) / AREA
        # Clipping keeps every coordinate strictly inside (0, 1).
        return np.clip(coords, tol, 1.0 - tol)

    corners = np.array([[0, 0], [1, 0], [0.5, 0.75**0.5]])
    triangle = tri.Triangulation(corners[:, 0], corners[:, 1])
    refiner = tri.UniformTriRefiner(triangle)
    trimesh = refiner.refine_triangulation(subdiv=subdiv)

    # For each corner of the triangle, the pair of other corners
    pairs = [corners[np.roll(range(3), -i)[1:]] for i in range(3)]

    # Convert every mesh vertex to three barycentric coordinates (rather than
    # two Cartesian ones) and renormalize each row to sum to one after the
    # clipping above, so it can be fed to the Dirichlet density.
    points = torch.tensor(np.array([(_xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]))
    points /= torch.sum(points, dim=1, keepdim=True)

    alpha = samples * concentration
    vals = torch.stack(
        [
            torch.exp(Dirichlet(alpha).log_prob(points[i, :]))
            for i in range(points.shape[0])
        ]
    )
    vals /= torch.max(vals, dim=0, keepdim=True)[0]
    vals = torch.sum(vals, dim=1)
    vals /= torch.sum(vals)

    # Computed once; the grid-filling loop below reuses both.
    unique_x = np.unique(trimesh.x)
    unique_y = np.unique(trimesh.y)

    # Every second distinct y value (2nd, 4th, ...) is dropped, together with
    # every x value that occurs on the first dropped row.
    # NOTE(review): presumably consecutive mesh rows are offset by half a
    # step, so this leaves a regular grid; confirm.
    not_use_trimesh_y = list(unique_y[1::2])
    not_use_trimesh_x = list(
        np.unique(trimesh.x[trimesh.y == not_use_trimesh_y[0]])
    )

    # save all existing coordinates
    coordinates_dict = {
        (x, y): z.item() for x, y, z in zip(trimesh.x, trimesh.y, vals)
    }

    # fill in missing part of square grid
    for x in unique_x:
        for y in unique_y:
            coordinates_dict.setdefault((x, y), 0)

    # convert to dataframe and sort with y first in descending order
    df = pd.DataFrame(coordinates_dict.items(), columns=["x,y", "val"])
    df[["x", "y"]] = pd.DataFrame(df["x,y"].tolist(), index=df.index)
    df = df.sort_values(["y", "x"], ascending=[False, True])

    # remove the alternative values, (every other y and all the values associated with that y)
    df_use = df[(~df.x.isin(not_use_trimesh_x)) & (~df.y.isin(not_use_trimesh_y))]

    return {
        "width": len(np.unique(df_use.x)),
        "height": len(np.unique(df_use.y)),
        "values": df_use["val"].tolist(),
    }


def triangle_contour(
    data: pd.DataFrame, *, title: Optional[str] = None, contour: bool = True
) -> vega.VegaSchema:
    """Build the barycenter-triangle schema for the passed data.

    ``data`` is handed to ``triangle_weights`` and the resulting grid replaces
    the ``values`` of the schema's ``contributions`` dataset.

    data -- Input forwarded unchanged to ``triangle_weights``.
    title -- Optional title; applied only when truthy.
    contour -- When falsy, the stroke of the ``_contours`` mark is bound to
        the ``color`` scale using the ``contour.value`` field.
        NOTE(review): presumably this makes contour outlines blend into the
        fill; confirm against the schema.
    """
    weights_grid = triangle_weights(data)

    schema = vega.load_schema("barycenter_triangle.vg.json")
    schema["data"] = vega.replace_named_with(
        schema["data"], "contributions", ["values"], weights_grid
    )

    if title:
        schema = vega.set_title(schema, title)

    if contour:
        return schema

    contour_mark = vega.find_keyed(schema["marks"], "name", "_contours")
    stroke_encoding = {"scale": "color", "field": "contour.value"}
    contour_mark["encode"]["enter"]["stroke"] = stroke_encoding
    return schema
28 changes: 28 additions & 0 deletions pyciemss/visuals/calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pandas as pd

from . import vega


def calibration(datasource: pd.DataFrame) -> vega.VegaSchema:
    """Build the calibration chart schema from the passed datasource.

    The rows of ``datasource`` are embedded in the schema's ``table`` dataset
    (replacing its ``url``), and the ``Variable`` signal is bound to the
    sorted distinct values of ``column_names``, defaulting to the first one.

    datasource -- A dataframe ready for rendering. Should include:
    - time (int)
    - column_names (str)
    - calibration (bool)
    - y  --- Will be shown as a line
    - y1 --- Upper range of values
    - y0 --- Lower range of values
    """
    schema = vega.load_schema("calibrate_chart.vg.json")

    # Swap the dataset's remote source for inline records.
    table = vega.find_keyed(schema["data"], "name", "table")
    del table["url"]
    records = datasource.to_dict(orient="records")
    table["values"] = records

    # Offer every distinct variable name, in sorted order, in the selector.
    distinct_names = datasource["column_names"].unique().tolist()
    variable_choices = sorted(distinct_names)
    selector = vega.find_keyed(schema["signals"], "name", "Variable")
    selector["bind"]["options"] = variable_choices
    selector["value"] = variable_choices[0]

    return schema
Loading

0 comments on commit d89690f

Please sign in to comment.