Skip to content

Commit

Permalink
Fixes #95
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Oct 16, 2023
1 parent e990cd7 commit 6d3fb5a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 4 deletions.
28 changes: 28 additions & 0 deletions docs/examples/plot_6_custom_axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
# Custom Axes
A lot of the time you might have your own plots ready to go.
In this case, you can manually invoke ChainConsumer's plotting functions.
Here's an example, noting that there are also `plot_point`, `plot_surface` available
that I haven't explicitly shown.
"""
import matplotlib.pyplot as plt

from chainconsumer import Chain, Truth, make_sample
from chainconsumer.plotting import plot_contour, plot_truths

# %%

df = make_sample(num_dimensions=2, seed=1)

# Custom plotting code
fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
axes[0].hist(df["A"], bins=50, color="skyblue", density=True)

# We can use ChainConsumer to plot a truth value on top of this histogram
truth = Truth(location={"A": 0, "B": 5}, line_style=":")
plot_truths(axes[0], [truth], px="A")
# And also add a contour to the other axis
plot_contour(axes[1], Chain(samples=df, name="Example"), px="A", py="B")
3 changes: 1 addition & 2 deletions src/chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from .chain import Chain, ChainName, ColumnName
from .color_finder import colors
from .helpers import get_bins, get_extents, get_grid_bins, get_smoothed_bins
from .plotting import add_watermark, plot_surface
from .plotting.config import PlotConfig
from .plotting.contours import plot_surface
from .plotting.watermark import add_watermark


class PlottingBase(BetterBase):
Expand Down
12 changes: 12 additions & 0 deletions src/chainconsumer/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .contours import plot_cloud, plot_contour, plot_point, plot_surface
from .truth import plot_truths
from .watermark import add_watermark

__all__ = [
"add_watermark",
"plot_cloud",
"plot_contour",
"plot_point",
"plot_surface",
"plot_truths",
]
8 changes: 6 additions & 2 deletions src/chainconsumer/plotting/contours.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ def plot_surface(
chains: list[Chain],
px: ColumnName,
py: ColumnName,
config: PlotConfig,
config: PlotConfig | None = None,
) -> dict[ColumnName, PathCollection]:
"""Plot the chains onto a 2D surface, using clouds, contours and points.
Returns:
A map from column name to paths to be added as colorbars.
"""
if config is None:
config = PlotConfig()
paths: dict[ColumnName, PathCollection] = {}
for chain in chains:
if px not in chain.plotting_columns or py not in chain.plotting_columns:
Expand Down Expand Up @@ -61,7 +63,7 @@ def plot_cloud(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName) -> dict[C
return {}


def plot_contour(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName, config: PlotConfig) -> None:
def plot_contour(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName, config: PlotConfig | None = None) -> None:
"""A lightweight method to plot contours in an external axis given two specified parameters
Args:
Expand All @@ -70,6 +72,8 @@ def plot_contour(ax: Axes, chain: Chain, px: ColumnName, py: ColumnName, config:
px: The parameter to plot on the x axis
py: The parameter to plot on the y axis
"""
if config is None:
config = PlotConfig()
levels = _get_levels(chain.sigmas, config)
contour_colours = _scale_colours(colors.format(chain.color), len(levels), chain.shade_gradient)
sub = max(0.1, 1 - 0.2 * chain.shade_gradient)
Expand Down

0 comments on commit 6d3fb5a

Please sign in to comment.