diff --git a/plugins/plotly-express/src/deephaven/plot/express/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/__init__.py
index cca071669..6a4694442 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/__init__.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/__init__.py
@@ -41,6 +41,7 @@
density_mapbox,
line_geo,
line_mapbox,
+ density_heatmap,
)
from .data import data_generators
diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py
index ba8e38695..0e019f5e4 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py
@@ -3,5 +3,5 @@
DeephavenFigureNode,
)
from .generate import generate_figure, update_traces
-from .custom_draw import draw_ohlc, draw_candlestick
+from .custom_draw import draw_ohlc, draw_candlestick, draw_density_heatmap
from .RevisionManager import RevisionManager
diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py
index 21330a98b..4d822dd55 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py
@@ -6,6 +6,7 @@
from pandas import DataFrame
import plotly.graph_objects as go
from plotly.graph_objects import Figure
+from plotly.validators.heatmap import ColorscaleValidator
def draw_finance(
@@ -109,3 +110,79 @@ def draw_candlestick(
"""
return draw_finance(data_frame, x_finance, open, high, low, close, go.Candlestick)
+
+
+def draw_density_heatmap(
+ data_frame: DataFrame,
+ x: str,
+ y: str,
+ z: str,
+ labels: dict[str, str] | None = None,
+ range_color: list[float] | None = None,
+ color_continuous_scale: str | list[str] | None = "plasma",
+ color_continuous_midpoint: float | None = None,
+ opacity: float = 1.0,
+ title: str | None = None,
+ template: str | None = None,
+) -> Figure:
+ """Create a density heatmap
+
+ A single go.Heatmap trace is built from the x, y and z columns of the data
+ frame and attached to the "coloraxis1" color axis. That color axis, the
+ title, the template and the axis titles are then set on the layout.
+
+ Args:
+ data_frame: The data frame to draw with
+ x: The name of the column containing x-axis values
+ y: The name of the column containing y-axis values
+ z: The name of the column containing bin values
+ labels: A dictionary of labels mapping columns to new labels.
+ Within this function only the x and y axis titles are relabeled.
+ range_color: A list of two numbers that form the endpoints of the color axis.
+ If None (or empty), both endpoints are passed on as None.
+ color_continuous_scale: A color scale or list of colors for a continuous scale
+ color_continuous_midpoint: A number that is the midpoint of the color axis
+ opacity: Opacity to apply to all markers. 0 is completely transparent
+ and 1 is completely opaque.
+ title: The title of the chart
+ template: The template for the chart.
+
+ Returns:
+ The plotly density heatmap
+
+ """
+
+ # currently, most plots rely on px setting several attributes such as coloraxis, opacity, etc.
+ # so we need to set some things manually
+ # this could be done with handle_custom_args in generate.py in the future if
+ # we need to provide more options, but it's much easier to just set it here
+ # and doesn't risk breaking any other plots
+
+ heatmap = go.Figure(
+ go.Heatmap(
+ x=data_frame[x],
+ y=data_frame[y],
+ z=data_frame[z],
+ coloraxis="coloraxis1",
+ opacity=opacity,
+ )
+ )
+
+ # fall back to a pair of Nones so both endpoints can always be indexed below
+ range_color_list = range_color or [None, None]
+
+ # the validator coerces color_continuous_scale before it is stored in the layout
+ colorscale_validator = ColorscaleValidator("colorscale", "draw_density_heatmap")
+
+ coloraxis_layout = dict(
+ colorscale=colorscale_validator.validate_coerce(color_continuous_scale),
+ cmid=color_continuous_midpoint,
+ cmin=range_color_list[0],
+ cmax=range_color_list[1],
+ )
+
+ # the data columns were already read above, so from here on x and y are
+ # only used as axis titles and can safely be replaced by their labels
+ if labels:
+ x = labels.get(x, x)
+ y = labels.get(y, y)
+
+ heatmap.update_layout(
+ coloraxis1=coloraxis_layout,
+ title=title,
+ template=template,
+ xaxis_title=x,
+ yaxis_title=y,
+ )
+
+ return heatmap
diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py
index 7b174ef9b..63d20ad55 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py
@@ -116,6 +116,7 @@
"current_partition",
"colors",
"unsafe_update_figure",
+ "heatmap_agg_label",
}
# these are columns that are "attached" sequentially to the traces
@@ -669,8 +670,9 @@ def handle_custom_args(
elif arg == "bargap" or arg == "rangemode":
fig.update_layout({arg: val})
- # x_axis_generators.append(key_val_generator("bargap", [val]))
- # y_axis_generators.append(key_val_generator("bargap", [val]))
+
+ elif arg == "heatmap_agg_label":
+ fig.update_coloraxes(colorbar_title_text=val)
trace_generator = combined_generator(trace_generators)
@@ -781,7 +783,6 @@ def get_hover_body(
def hover_text_generator(
hover_mapping: list[dict[str, str]],
- # hover_data - todo, dependent on arrays supported in data mappings
types: set[str] | None = None,
current_partition: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
@@ -824,6 +825,7 @@ def hover_text_generator(
def compute_labels(
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
+ heatmap_agg_label: str | None,
# hover_data - todo, dependent on arrays supported in data mappings
types: set[str],
labels: dict[str, str] | None,
@@ -836,6 +838,7 @@ def compute_labels(
Args:
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
+ heatmap_agg_label: The aggregate density heatmap column title
types: Any types of this chart that require special processing
labels: A dictionary of old column name to new column name mappings
current_partition: The columns that this figure is partitioned by
@@ -846,9 +849,36 @@ def compute_labels(
calculate_hist_labels(hist_val_name, hover_mapping[0])
+ calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels)
+
relabel_columns(labels, hover_mapping, types, current_partition)
+def calculate_density_heatmap_labels(
+ heatmap_agg_label: str | None,
+ hover_mapping: dict[str, str],
+ labels: dict[str, str] | None,
+) -> None:
+ """Calculate the labels for a density heatmap
+ The z column is renamed to the heatmap_agg_label
+
+ The resulting label is written to hover_mapping["z"] in place. Nothing is
+ changed if heatmap_agg_label is None or empty.
+
+ Args:
+ heatmap_agg_label: The name of the heatmap aggregate label
+ hover_mapping: The mapping of variables to columns. Modified in place.
+ labels: A dictionary of labels mapping columns to new labels.
+ It is consulted twice: first with the last space-separated word of
+ heatmap_agg_label, then with the whole label that results from that.
+
+ """
+ labels = labels or {}
+ if heatmap_agg_label:
+ # the last part of the label is the z column, and could be replaced by labels
+ split_label = heatmap_agg_label.split(" ")
+ split_label[-1] = labels.get(split_label[-1], split_label[-1])
+ # it's also possible that someone wants to override the whole label
+ # plotly doesn't seem to do that, but it seems reasonable to allow
+ new_label = " ".join(split_label)
+ hover_mapping["z"] = labels.get(new_label, new_label)
+
+
def calculate_hist_labels(
hist_val_name: str | None, current_mapping: dict[str, str]
) -> None:
@@ -871,6 +901,7 @@ def add_axis_titles(
custom_call_args: dict[str, Any],
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
+ heatmap_agg_label: str | None,
) -> None:
"""Add axis titles. Generally, this only applies when there is a list variable
@@ -879,6 +910,7 @@ def add_axis_titles(
create hover and axis titles
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
+ heatmap_agg_label: The aggregate density heatmap column title
"""
# Although hovertext is handled above for all plot types, plotly still
@@ -892,6 +924,9 @@ def add_axis_titles(
new_xaxis_titles = [hover_mapping[0].get("x", None)]
new_yaxis_titles = [hover_mapping[0].get("y", None)]
+ if heatmap_agg_label:
+ custom_call_args["heatmap_agg_label"] = heatmap_agg_label
+
# a specified axis title update should override this
if new_xaxis_titles:
custom_call_args["xaxis_titles"] = custom_call_args.get(
@@ -928,6 +963,9 @@ def create_hover_and_axis_titles(
(such as the y-axis if the x-axis is specified). Otherwise, there is a
legend or not depending on if there is a list of columns or not.
+ Density heatmaps are also an exception. If "heatmap_agg_label" is specified,
+ the z column is renamed to this label.
+
Args:
custom_call_args: The custom_call_args that are used to
create hover and axis titles
@@ -941,14 +979,26 @@ def create_hover_and_axis_titles(
labels = custom_call_args.get("labels", None)
hist_val_name = custom_call_args.get("hist_val_name", None)
+ heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None)
current_partition = custom_call_args.get("current_partition", {})
- compute_labels(hover_mapping, hist_val_name, types, labels, current_partition)
+ compute_labels(
+ hover_mapping,
+ hist_val_name,
+ heatmap_agg_label,
+ types,
+ labels,
+ current_partition,
+ )
hover_text = hover_text_generator(hover_mapping, types, current_partition)
- add_axis_titles(custom_call_args, hover_mapping, hist_val_name)
+ if heatmap_agg_label:
+ # it's possible that heatmap_agg_label was relabeled, so grab the new label
+ heatmap_agg_label = hover_mapping[0]["z"]
+
+ add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label)
return hover_text
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py
index 2377d0bfa..5d4cd4a99 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py
@@ -581,6 +581,7 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]:
"preprocess_hist" in self.groups
or "preprocess_freq" in self.groups
or "preprocess_time" in self.groups
+ or "preprocess_heatmap" in self.groups
) and self.preprocessor:
# still need to preprocess the base table
table, arg_update = cast(
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py
index 8558d8a14..f87c84086 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py
@@ -9,3 +9,4 @@
from ._layer import layer
from .subplots import make_subplots
from .maps import scatter_geo, scatter_mapbox, density_mapbox, line_geo, line_mapbox
+from .heatmap import density_heatmap
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py
index a3f902977..47108eff3 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py
@@ -409,8 +409,8 @@ def histogram(
range_y: A list of two numbers that specify the range of the y-axis.
range_bins: A list of two numbers that specify the range of data that is used.
histfunc: The function to use when aggregating within bins. One of
- 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', 'sum',
- or 'var'
+ 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std',
+ 'sum', or 'var'
cumulative: If True, values are cumulative.
nbins: The number of bins to use.
text_auto: If True, display the value at each bar.
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py
new file mode 100644
index 000000000..ba40fc27f
--- /dev/null
+++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+from typing import Callable, Literal
+
+from deephaven.plot.express.shared import default_callback
+
+from ._private_utils import process_args
+from ..deephaven_figure import DeephavenFigure, draw_density_heatmap
+from deephaven.table import Table
+
+
+def density_heatmap(
+ table: Table,
+ x: str | None = None,
+ y: str | None = None,
+ z: str | None = None,
+ labels: dict[str, str] | None = None,
+ color_continuous_scale: str | list[str] | None = None,
+ range_color: list[float] | None = None,
+ color_continuous_midpoint: float | None = None,
+ opacity: float = 1.0,
+ log_x: bool = False,
+ log_y: bool = False,
+ range_x: list[float] | None = None,
+ range_y: list[float] | None = None,
+ range_bins_x: list[float | None] | None = None,
+ range_bins_y: list[float | None] | None = None,
+ histfunc: str = "count",
+ nbinsx: int = 10,
+ nbinsy: int = 10,
+ empty_bin_default: float | Literal["NaN"] | None = None,
+ title: str | None = None,
+ template: str | None = None,
+ unsafe_update_figure: Callable = default_callback,
+) -> DeephavenFigure:
+ """
+ A density heatmap creates a grid of colored bins. Each bin represents an aggregation of data points in that region.
+
+ Args:
+ table: A table to pull data from.
+ x: A column that contains x-axis values.
+ y: A column that contains y-axis values.
+ z: A column that contains z-axis values. If not provided, the count of joint occurrences of x and y will be used.
+ labels: A dictionary of labels mapping columns to new labels.
+ color_continuous_scale: A color scale or list of colors for a continuous scale
+ range_color: A list of two numbers that form the endpoints of the color axis
+ color_continuous_midpoint: A number that is the midpoint of the color axis
+ opacity: Opacity to apply to all markers. 0 is completely transparent
+ and 1 is completely opaque.
+ log_x: A boolean that specifies if the corresponding axis is a log axis or not.
+ log_y: A boolean that specifies if the corresponding axis is a log axis or not.
+ range_x: A list of two numbers that specify the range of the x axes.
+ None can be specified for no range
+ range_y: A list of two numbers that specify the range of the y axes.
+ None can be specified for no range
+ range_bins_x: A list of two numbers that specify the range of data that is used for x.
+ None can be specified to use the min and max of the data.
+ None can also be specified for either of the list values to use the min or max of the data, respectively.
+ range_bins_y: A list of two numbers that specify the range of data that is used for y.
+ None can be specified to use the min and max of the data.
+ None can also be specified for either of the list values to use the min or max of the data, respectively.
+ histfunc: The function to use when aggregating within bins. One of
+ 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std',
+ 'sum', or 'var'.
+ The heatmap preprocessor rejects any histfunc other than 'count' when z is not provided.
+ nbinsx: The number of bins to use for the x-axis
+ nbinsy: The number of bins to use for the y-axis
+ empty_bin_default: The value to use for bins that have no data.
+ If None and histfunc is 'count' or 'count_distinct', 0 is used.
+ Otherwise, if None or 'NaN', NaN is used.
+ 'NaN' forces the bin to be NaN if no data is present, even if histfunc is 'count' or 'count_distinct'.
+ Note that if multiple points are required to color a bin, such as the case for a histfunc of 'std' or 'var',
+ the bin will still be NaN if less than the required number of points are present.
+ title: The title of the chart
+ template: The template for the chart.
+ unsafe_update_figure: An update function that takes a plotly figure
+ as an argument and optionally returns a plotly figure. If a figure is
+ not returned, the plotly figure passed will be assumed to be the return
+ value. Used to add any custom changes to the underlying plotly figure.
+ Note that the existing data traces should not be removed. This may lead
+ to unexpected behavior if traces are modified in a way that break data
+ mappings.
+
+ Returns:
+ DeephavenFigure: A DeephavenFigure that contains the density heatmap
+ """
+ # capture every parameter by name; this has to stay the first statement so
+ # that no other local variables end up in the captured dictionary
+ args = locals()
+
+ return process_args(args, {"preprocess_heatmap"}, px_func=draw_density_heatmap)
diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py
new file mode 100644
index 000000000..241dc44cb
--- /dev/null
+++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+from typing import Any, Generator
+
+from deephaven import new_table
+from deephaven.column import long_col
+
+from ..shared import get_unique_names
+from .utilities import (
+ create_range_table,
+ validate_heatmap_histfunc,
+ create_tmp_view,
+ aggregate_heatmap_bins,
+ calculate_bin_locations,
+)
+from deephaven.table import Table
+
+
+class HeatmapPreprocessor:
+ """
+ Preprocessor for heatmaps.
+
+ Bins the x and y columns of the table into an nbinsx by nbinsy grid and
+ aggregates a value per bin using histfunc.
+
+ Attributes:
+ args: dict[str, Any]: The arguments used to create the plot
+ histfunc: str: The histfunc to use
+ nbinsx: int: The number of bins in the x direction
+ nbinsy: int: The number of bins in the y direction
+ range_bins_x: list[float | None]: The range of the x bins
+ range_bins_y: list[float | None]: The range of the y bins
+ empty_bin_default: float | Literal["NaN"] | None: The value to use for bins
+ that have no data. Set to 0 when it is None and the histfunc is
+ 'count' or 'count_distinct'.
+ names: dict[str, str]: A mapping of ideal name to unique names
+ Also contains the names of the x, y, and z columns for ease of use
+ """
+
+ def __init__(self, args: dict[str, Any]):
+ self.args = args
+ # the binning arguments are popped, so they are no longer present in
+ # args (which is the same dictionary object as self.args) afterwards
+ self.histfunc = args.pop("histfunc")
+ self.nbinsx = args.pop("nbinsx")
+ self.nbinsy = args.pop("nbinsy")
+ self.range_bins_x = args.pop("range_bins_x")
+ self.range_bins_y = args.pop("range_bins_y")
+ self.empty_bin_default = args.pop("empty_bin_default")
+ # counts default to 0 for empty bins unless the caller asked for something else
+ if (
+ self.histfunc in {"count", "count_distinct"}
+ and self.empty_bin_default is None
+ ):
+ self.empty_bin_default = 0
+ # create unique names for the columns to ensure no collisions
+ self.names = get_unique_names(
+ self.args["table"],
+ [
+ "range_index_x",
+ "range_index_y",
+ "range_x",
+ "range_y",
+ "bin_min_x",
+ "bin_max_x",
+ "bin_min_y",
+ "bin_max_y",
+ "tmp_x",
+ "tmp_y",
+ "agg_col",
+ self.histfunc,
+ ],
+ )
+
+ # add the column names to names as well for ease of use
+ self.names.update(
+ {
+ "x": self.args["x"],
+ "y": self.args["y"],
+ "z": self.args["z"],
+ }
+ )
+
+ def preprocess_partitioned_tables(
+ self, tables: list[Table], column: str | None = None
+ ) -> Generator[tuple[Table, dict[str, str | None]], None, None]:
+ """
+ Preprocess params into an appropriate table
+
+ Args:
+ tables: a list of tables to preprocess. Only the first table is used.
+ column: the column to aggregate on
+ ignored for this preprocessor because heatmap always gets the joint count
+ distribution of x and y or the histfunc of z depending on if z is provided
+
+ Yields:
+ A single tuple containing (the new table, an update to make to the args)
+ The update should contain the z column name and the heatmap_agg_label
+ which is the histfunc of z if z is not None, otherwise just the histfunc
+
+ """
+
+ range_index_x = self.names["range_index_x"]
+ range_index_y = self.names["range_index_y"]
+ range_x = self.names["range_x"]
+ range_y = self.names["range_y"]
+ histfunc_col = self.names[self.histfunc]
+ x = self.names["x"]
+ y = self.names["y"]
+ z = self.names["z"]
+
+ # NOTE(review): the histfunc is only validated here, after it was already
+ # used as a name in __init__ and as a key into self.names above
+ validate_heatmap_histfunc(z, self.histfunc)
+
+ # there will only be one table, so we can just grab the first one
+ table = tables[0]
+
+ range_table_x = create_range_table(
+ table, x, self.range_bins_x, self.nbinsx, range_x
+ )
+ range_table_y = create_range_table(
+ table, y, self.range_bins_y, self.nbinsy, range_y
+ )
+ range_table = range_table_x.join(range_table_y)
+
+ # ensure that all possible bins are created so that the rendered chart draws spaces for empty bins
+ bin_counts_x = new_table(
+ [long_col(range_index_x, [i for i in range(self.nbinsx)])]
+ )
+ bin_counts_y = new_table(
+ [long_col(range_index_y, [i for i in range(self.nbinsy)])]
+ )
+ bin_counts = bin_counts_x.join(bin_counts_y)
+
+ # note that create_tmp_view may also redirect names["agg_col"] when z is None
+ tmp_view = create_tmp_view(self.names)
+
+ # filter to only the tmp (data) columns, and join the range table to the tmp
+ ranged_tmp_view = table.view(tmp_view).join(range_table)
+
+ agg_table = aggregate_heatmap_bins(ranged_tmp_view, self.names, self.histfunc)
+
+ # join the aggregated values to the already created comprehensive bin table
+ bin_counts = bin_counts.natural_join(
+ agg_table, on=[range_index_x, range_index_y], joins=[self.names["agg_col"]]
+ )
+
+ # join the range table to the bin counts - this is needed because the ranges were dropped in the aggregation
+ ranged_bin_counts = bin_counts.join(range_table)
+
+ bin_counts_with_midpoint = calculate_bin_locations(
+ ranged_bin_counts, self.names, histfunc_col, self.empty_bin_default
+ )
+
+ # the z column name is kept as the last space-separated word of the label
+ heatmap_agg_label = f"{self.histfunc} of {z}" if z else self.histfunc
+
+ yield bin_counts_with_midpoint.view([x, y, histfunc_col]), {
+ "z": histfunc_col,
+ "heatmap_agg_label": heatmap_agg_label,
+ }
diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py
index a285185c8..197b0b73d 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py
@@ -9,19 +9,7 @@
from ..shared import get_unique_names
from deephaven.column import long_col
from deephaven.updateby import cum_sum
-
-# Used to aggregate within histogram bins
-HISTFUNC_MAP = {
- "avg": agg.avg,
- "count": agg.count_,
- "count_distinct": agg.count_distinct,
- "max": agg.max_,
- "median": agg.median,
- "min": agg.min_,
- "std": agg.std,
- "sum": agg.sum_,
- "var": agg.var,
-}
+from .utilities import create_range_table, HISTFUNC_AGGS
def get_aggs(
@@ -83,42 +71,14 @@ def prepare_preprocess(self) -> None:
self.args["table"],
["range_index", "range", "bin_min", "bin_max", self.histfunc, "total"],
)
- self.range_table = self.create_range_table()
-
- def create_range_table(self) -> Table:
- """
- Create a table that contains the bin ranges
-
- Returns:
- A table containing the bin ranges
- """
- # partitioned tables need range calculated on all
- table = (
- self.table.merge()
- if isinstance(self.table, PartitionedTable)
- else self.table
+ self.range_table = create_range_table(
+ self.args["table"],
+ self.cols,
+ self.range_bins,
+ self.nbins,
+ self.names["range"],
)
- if self.range_bins:
- range_min = self.range_bins[0]
- range_max = self.range_bins[1]
- table = empty_table(1)
- else:
- range_min = "RangeMin"
- range_max = "RangeMax"
- # need to find range across all columns
- min_aggs, min_cols = get_aggs("RangeMin", self.cols)
- max_aggs, max_cols = get_aggs("RangeMax", self.cols)
- table = table.agg_by([agg.min_(min_aggs), agg.max_(max_aggs)]).update(
- [f"RangeMin = min({min_cols})", f"RangeMax = max({max_cols})"]
- )
-
- return table.update(
- f"{self.names['range']} = new io.deephaven.plot.datasets.histogram."
- f"DiscretizedRangeEqual({range_min},{range_max}, "
- f"{self.nbins})"
- ).view(self.names["range"])
-
def create_count_tables(
self, tables: list[Table], column: str | None = None
) -> Generator[tuple[Table, str], None, None]:
@@ -134,7 +94,7 @@ def create_count_tables(
"""
range_index, range_ = self.names["range_index"], self.names["range"]
- agg_func = HISTFUNC_MAP[self.histfunc]
+ agg_func = HISTFUNC_AGGS[self.histfunc]
if not self.range_table:
raise ValueError("Range table not created")
for i, table in enumerate(tables):
diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py
index 808c5a75f..1cba930a0 100644
--- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py
+++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py
@@ -9,6 +9,7 @@
from .FreqPreprocessor import FreqPreprocessor
from .HistPreprocessor import HistPreprocessor
from .TimePreprocessor import TimePreprocessor
+from .HeatmapPreprocessor import HeatmapPreprocessor
class Preprocessor:
@@ -54,6 +55,8 @@ def prepare_preprocess(self) -> None:
AttachedPreprocessor(self.args, self.always_attached)
elif "preprocess_time" in self.groups:
self.preprocesser = TimePreprocessor(self.args)
+ elif "preprocess_heatmap" in self.groups:
+ self.preprocesser = HeatmapPreprocessor(self.args)
def preprocess_partitioned_tables(
self, tables: list[Table] | None, column: str | None = None
diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py
new file mode 100644
index 000000000..4c6ac80aa
--- /dev/null
+++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+from typing import Generator, Literal
+
+from deephaven import agg, empty_table
+from deephaven.plot.express.shared import get_unique_names
+from deephaven.table import PartitionedTable, Table
+
+# Used to aggregate within bins
+# Maps each supported histfunc name to the deephaven agg function that is
+# called with the column to aggregate
+HISTFUNC_AGGS = {
+ "abs_sum": agg.abs_sum,
+ "avg": agg.avg,
+ "count": agg.count_,
+ "count_distinct": agg.count_distinct,
+ "max": agg.max_,
+ "median": agg.median,
+ "min": agg.min_,
+ "std": agg.std,
+ "sum": agg.sum_,
+ "var": agg.var,
+}
+
+
+def get_aggs(
+ base: str,
+ columns: list[str],
+) -> tuple[list[str], str]:
+ """Create aggregations over all columns
+
+ For example, a base of "RangeMin" and columns of ["X", "Y"] gives
+ (["RangeMinX=X", "RangeMinY=Y"], "RangeMinX, RangeMinY").
+
+ Args:
+ base:
+ The base of the new columns that store the agg per column
+ columns:
+ All columns joined for the sake of taking min or max over
+ the columns
+
+ Returns:
+ A tuple containing (a list of the new columns,
+ a joined string of "NewCol, NewCol2...")
+
+ """
+ return (
+ [f"{base}{column}={column}" for column in columns],
+ ", ".join([f"{base}{column}" for column in columns]),
+ )
+
+
+def single_table(table: Table | PartitionedTable) -> Table:
+ """
+ Merge a table if it is a partitioned table
+
+ Args:
+ table: The table to merge
+
+ Returns:
+ The table itself if it is not a partitioned table, otherwise the merged table
+ """
+ return table.merge() if isinstance(table, PartitionedTable) else table
+
+
+def discretized_range_view(
+ table: Table,
+ range_min: float | str,
+ range_max: float | str,
+ nbins: int,
+ range_name: str,
+) -> Table:
+ """
+ Create a discretized range view that can be joined with a table to compute indices
+
+ Args:
+ table: The table to create the range view from
+ range_min: The minimum value of the range. Can be a number or a column name
+ range_max: The maximum value of the range. Can be a number or a column name
+ nbins: The number of bins to create
+ range_name: The name of the range object in the resulting table
+
+ Returns:
+ A table that contains the range object for the given table.
+ Only the range_name column is kept.
+ """
+
+ # range_min, range_max and nbins are interpolated into the query string as-is,
+ # which is why a column name works just as well as a number
+ return table.update(
+ f"{range_name} = new io.deephaven.plot.datasets.histogram."
+ f"DiscretizedRangeEqual({range_min},{range_max}, "
+ f"{nbins})"
+ ).view(range_name)
+
+
+def create_range_table(
+ table: Table,
+ cols: str | list[str],
+ range_bins: list[float | None] | None,
+ nbins: int,
+ range_name: str,
+) -> Table:
+ """
+ Create single row tables with range objects that can compute bin membership
+
+ Args:
+ table: The table to create the range table from
+ cols: The columns to create the range table from. The resulting range table will have
+ its range calculated over all of these columns.
+ A single column name is also accepted.
+ range_bins: The range to create the bins over.
+ If None, the range will be calculated over the columns.
+ If a list of two numbers, the range will be set to these numbers.
+ The values within this list can also be None, in which case the range will be calculated over the columns
+ for whichever value is None.
+ nbins: The number of bins to create
+ range_name: The name of the range object in the resulting table
+
+ Returns:
+ A table that contains the range object for the given table and columns
+ """
+
+ # NOTE(review): table is used as-is; single_table is not called here, so a
+ # PartitionedTable is not merged - confirm that callers pass a plain Table
+ cols = [cols] if isinstance(cols, str) else cols
+
+ # "RangeMin" and "RangeMax" act both as markers for "compute this bound from
+ # the data" and as the names of the columns that hold the computed bounds
+ range_min = (
+ range_bins[0] if range_bins and range_bins[0] is not None else "RangeMin"
+ )
+ range_max = (
+ range_bins[1] if range_bins and range_bins[1] is not None else "RangeMax"
+ )
+
+ # placeholders so that the join below still works for a bound that was given explicitly
+ min_table = empty_table(1)
+ max_table = empty_table(1)
+
+ if range_min == "RangeMin":
+ # take the min of every column, then the min across those per-column results
+ min_aggs, min_cols = get_aggs("RangeMin", cols)
+ min_table = table.agg_by([agg.min_(min_aggs)]).update(
+ [f"RangeMin = min({min_cols})"]
+ )
+ if range_max == "RangeMax":
+ # take the max of every column, then the max across those per-column results
+ max_aggs, max_cols = get_aggs("RangeMax", cols)
+ max_table = table.agg_by([agg.max_(max_aggs)]).update(
+ [f"RangeMax = max({max_cols})"]
+ )
+
+ return discretized_range_view(
+ min_table.join(max_table), range_min, range_max, nbins, range_name
+ )
+
+
+def validate_heatmap_histfunc(z: str | None, histfunc: str) -> None:
+ """
+ Check if the histfunc is valid
+
+ The missing-z check runs first, so an unknown histfunc combined with a
+ missing z is reported as a missing z.
+
+ Args:
+ z: The column that contains z-axis values.
+ histfunc: The function to use when aggregating within bins. Should be 'count' if z is None.
+
+ Raises:
+ ValueError: If z is None and the histfunc is not 'count',
+ or if the histfunc is not a key of HISTFUNC_AGGS
+ """
+ if z is None and histfunc != "count":
+ raise ValueError("z must be specified for histfunc other than count")
+ elif histfunc not in HISTFUNC_AGGS:
+ raise ValueError(f"{histfunc} is not a valid histfunc")
+
+
+def create_tmp_view(
+ names: dict[str, str],
+) -> list[str]:
+ """
+ Create the expressions for a temporary view that avoids column name collisions
+
+ The x and y columns are always copied to their tmp names. The z column is
+ copied to the agg_col name if z is not None. Otherwise no extra column is
+ added and names["agg_col"] is overwritten in place to point at the tmp x column.
+
+ Args:
+ names: The names used for columns so that they don't collide.
+ May be modified in place (see above).
+
+ Returns:
+ A list of strings that are used to create a temporary view
+ """
+
+ x = names["x"]
+ y = names["y"]
+ z = names["z"]
+ tmp_x = names["tmp_x"]
+ tmp_y = names["tmp_y"]
+ agg_col = names["agg_col"]
+
+ tmp_view = [f"{tmp_x} = {x}", f"{tmp_y} = {y}"]
+
+ if z is not None:
+ tmp_view.append(f"{agg_col} = {z}")
+ else:
+ # if z is not specified, just count the number of occurrences, so tmp_x or tmp_y can be used
+ names["agg_col"] = tmp_x
+
+ return tmp_view
+
+
+def aggregate_heatmap_bins(
+ table: Table,
+ names: dict[str, str],
+ histfunc: str,
+) -> Table:
+ """
+ Create a count table that aggregates up values into bins
+
+ Args:
+ table: The table to aggregate. Should contain the tmp data columns and the range columns
+ names: The names used for columns so that they don't collide
+ histfunc: The function to use when aggregating within bins. Should be 'count' if z is None.
+ Must be a key of HISTFUNC_AGGS.
+
+ Returns:
+ A table with the values of the agg_col column aggregated by the x and y range index columns
+ """
+
+ range_x = names["range_x"]
+ range_y = names["range_y"]
+ range_index_x = names["range_index_x"]
+ range_index_y = names["range_index_y"]
+
+ tmp_x = names["tmp_x"]
+ tmp_y = names["tmp_y"]
+ agg_col = names["agg_col"]
+
+ # compute the bin index of every row on both axes, drop the rows where either
+ # index is null (presumably values outside of the range - TODO confirm),
+ # then aggregate per (x index, y index) pair
+ count_table = (
+ table.update_view(
+ [
+ f"{range_index_x} = {range_x}.index({tmp_x})",
+ f"{range_index_y} = {range_y}.index({tmp_y})",
+ ]
+ )
+ .where([f"!isNull({range_index_x})", f"!isNull({range_index_y})"])
+ .agg_by([HISTFUNC_AGGS[histfunc](agg_col)], [range_index_x, range_index_y])
+ )
+ return count_table
+
+
+def calculate_bin_locations(
+ ranged_bin_counts: Table,
+ names: dict[str, str],
+ histfunc_col: str,
+ empty_bin_default: float | Literal["NaN"] | None,
+) -> Table:
+ """
+ Compute the center of the bins for the x and y axes
+ plotly requires the center of the bins to plot the heatmap and will calculate the width automatically
+
+ Args:
+ ranged_bin_counts: A table that contains the bin counts and the range columns
+ names: The names used for columns so that they don't collide
+ histfunc_col: The column that contains the aggregated values
+ empty_bin_default: The default value to use for bins that have no data
+
+ Returns:
+ A table that contains the bin counts and the center of the bins
+
+ Raises:
+ ValueError: If empty_bin_default is a string other than 'NaN'
+ """
+ range_index_x = names["range_index_x"]
+ range_index_y = names["range_index_y"]
+ range_x = names["range_x"]
+ range_y = names["range_y"]
+ bin_min_x = names["bin_min_x"]
+ bin_max_x = names["bin_max_x"]
+ bin_min_y = names["bin_min_y"]
+ bin_max_y = names["bin_max_y"]
+ x = names["x"]
+ y = names["y"]
+ agg_col = names["agg_col"]
+
+ # both "NaN" and None require no replacement
+ # it is assumed that empty_bin_default has already been set to a number
+ # if needed, such as in the case of a histfunc of count or count_distinct
+
+ if isinstance(empty_bin_default, str) and empty_bin_default != "NaN":
+ raise ValueError("empty_bin_default must be 'NaN' if it is a string")
+
+ # only a numeric default leads to null aggregated values being replaced
+ if empty_bin_default not in {"NaN", None}:
+ ranged_bin_counts = ranged_bin_counts.update_view(
+ f"{agg_col} = replaceIfNull({agg_col}, {empty_bin_default})"
+ )
+
+ # x and y become the midpoint between the min and max of their bin, and the
+ # aggregated values are copied to the histfunc column
+ return ranged_bin_counts.update_view(
+ [
+ f"{bin_min_x} = {range_x}.binMin({range_index_x})",
+ f"{bin_max_x} = {range_x}.binMax({range_index_x})",
+ f"{x}=0.5*({bin_min_x}+{bin_max_x})",
+ f"{bin_min_y} = {range_y}.binMin({range_index_y})",
+ f"{bin_max_y} = {range_y}.binMax({range_index_y})",
+ f"{y}=0.5*({bin_min_y}+{bin_max_y})",
+ f"{histfunc_col} = {agg_col}",
+ ]
+ )
diff --git a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py
index d043c39c0..6b1a88315 100644
--- a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py
+++ b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py
@@ -1,9 +1,26 @@
import unittest
from unittest.mock import patch
+import pandas as pd
from deephaven_server import Server
+def remap_types(
+ df: pd.DataFrame,
+) -> None:
+ """
+ Remap the types of the columns to the correct types
+
+ Args:
+ df: The dataframe to remap the types of
+ """
+ for col in df.columns:
+ if df[col].dtype == "int64":
+ df[col] = df[col].astype("Int64")
+ elif df[col].dtype == "float64":
+ df[col] = df[col].astype("Float64")
+
+
class BaseTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
diff --git a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py
new file mode 100644
index 000000000..bc5ad0b1b
--- /dev/null
+++ b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py
@@ -0,0 +1,321 @@
+import unittest
+
+from ..BaseTest import BaseTestCase
+
+
+class HeatmapTestCase(BaseTestCase):
+ def setUp(self) -> None:
+ from deephaven import new_table
+ from deephaven.column import int_col
+
+ self.source = new_table(
+ [
+ int_col("X", [1, 2, 3, 4, 5]),
+ int_col("Y", [1, 2, 3, 4, 5]),
+ int_col("Z", [1, 2, 3, 4, 5]),
+ ]
+ )
+
+ def test_basic_heatmap(self):
+ import src.deephaven.plot.express as dx
+ from deephaven.constants import NULL_LONG, NULL_DOUBLE
+
+ chart = dx.density_heatmap(self.source, x="X", y="Y").to_dict(self.exporter)
+ plotly, deephaven = chart["plotly"], chart["deephaven"]
+
+ # pop template as we currently do not modify it
+ plotly["layout"].pop("template")
+
+ expected_data = [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "X=%{x}<br>Y=%{y}<br>count=%{z}",
+ "opacity": 1.0,
+ "x": [NULL_DOUBLE],
+ "y": [NULL_DOUBLE],
+ "z": [NULL_LONG],
+ "type": "heatmap",
+ }
+ ]
+
+ self.assertEqual(plotly["data"], expected_data)
+
+ expected_layout = {
+ "coloraxis": {
+ "colorbar": {"title": {"text": "count"}},
+ "colorscale": [
+ [0.0, "#0d0887"],
+ [0.1111111111111111, "#46039f"],
+ [0.2222222222222222, "#7201a8"],
+ [0.3333333333333333, "#9c179e"],
+ [0.4444444444444444, "#bd3786"],
+ [0.5555555555555556, "#d8576b"],
+ [0.6666666666666666, "#ed7953"],
+ [0.7777777777777778, "#fb9f3a"],
+ [0.8888888888888888, "#fdca26"],
+ [1.0, "#f0f921"],
+ ],
+ },
+ "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "X"}},
+ "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Y"}},
+ }
+
+ self.assertEqual(plotly["layout"], expected_layout)
+
+ expected_mappings = [
+ {
+ "data_columns": {
+ "X": ["/plotly/data/0/x"],
+ "Y": ["/plotly/data/0/y"],
+ "count": ["/plotly/data/0/z"],
+ },
+ "table": 0,
+ }
+ ]
+
+ self.assertEqual(deephaven["mappings"], expected_mappings)
+
+ self.assertEqual(deephaven["is_user_set_template"], False)
+ self.assertEqual(deephaven["is_user_set_color"], False)
+
+ def test_heatmap_relabel_z(self):
+ import src.deephaven.plot.express as dx
+ from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT
+
+ chart = dx.density_heatmap(
+ self.source,
+ x="X",
+ y="Y",
+ z="Z",
+ labels={"X": "Column X", "Y": "Column Y", "Z": "Column Z"},
+ ).to_dict(self.exporter)
+ plotly, deephaven = chart["plotly"], chart["deephaven"]
+
+ # pop template as we currently do not modify it
+ plotly["layout"].pop("template")
+
+ expected_data = [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Column X=%{x}<br>Column Y=%{y}<br>count of Column Z=%{z}",
+ "opacity": 1.0,
+ "x": [NULL_DOUBLE],
+ "y": [NULL_DOUBLE],
+ "z": [NULL_LONG],
+ "type": "heatmap",
+ }
+ ]
+ self.assertEqual(plotly["data"], expected_data)
+
+ expected_layout = {
+ "coloraxis": {
+ "colorbar": {"title": {"text": "count of Column Z"}},
+ "colorscale": [
+ [0.0, "#0d0887"],
+ [0.1111111111111111, "#46039f"],
+ [0.2222222222222222, "#7201a8"],
+ [0.3333333333333333, "#9c179e"],
+ [0.4444444444444444, "#bd3786"],
+ [0.5555555555555556, "#d8576b"],
+ [0.6666666666666666, "#ed7953"],
+ [0.7777777777777778, "#fb9f3a"],
+ [0.8888888888888888, "#fdca26"],
+ [1.0, "#f0f921"],
+ ],
+ },
+ "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}},
+ "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}},
+ }
+
+ self.assertEqual(plotly["layout"], expected_layout)
+
+ expected_mappings = [
+ {
+ "table": 0,
+ "data_columns": {
+ "X": ["/plotly/data/0/x"],
+ "Y": ["/plotly/data/0/y"],
+ "count": ["/plotly/data/0/z"],
+ },
+ }
+ ]
+
+ self.assertEqual(deephaven["mappings"], expected_mappings)
+
+ self.assertEqual(deephaven["is_user_set_template"], False)
+ self.assertEqual(deephaven["is_user_set_color"], False)
+
+ def test_heatmap_relabel_agg_z(self):
+ import src.deephaven.plot.express as dx
+ from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT
+
+ chart = dx.density_heatmap(
+ self.source,
+ x="X",
+ y="Y",
+ z="Z",
+ labels={"X": "Column X", "Y": "Column Y", "count of Z": "count"},
+ ).to_dict(self.exporter)
+
+ plotly, deephaven = chart["plotly"], chart["deephaven"]
+
+ # pop template as we currently do not modify it
+ plotly["layout"].pop("template")
+
+ expected_data = [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Column X=%{x}<br>Column Y=%{y}<br>count=%{z}",
+ "opacity": 1.0,
+ "x": [NULL_DOUBLE],
+ "y": [NULL_DOUBLE],
+ "z": [NULL_LONG],
+ "type": "heatmap",
+ }
+ ]
+ self.assertEqual(plotly["data"], expected_data)
+
+ expected_layout = {
+ "coloraxis": {
+ "colorbar": {"title": {"text": "count"}},
+ "colorscale": [
+ [0.0, "#0d0887"],
+ [0.1111111111111111, "#46039f"],
+ [0.2222222222222222, "#7201a8"],
+ [0.3333333333333333, "#9c179e"],
+ [0.4444444444444444, "#bd3786"],
+ [0.5555555555555556, "#d8576b"],
+ [0.6666666666666666, "#ed7953"],
+ [0.7777777777777778, "#fb9f3a"],
+ [0.8888888888888888, "#fdca26"],
+ [1.0, "#f0f921"],
+ ],
+ },
+ "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}},
+ "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}},
+ }
+
+ self.assertEqual(plotly["layout"], expected_layout)
+
+ expected_mappings = [
+ {
+ "table": 0,
+ "data_columns": {
+ "X": ["/plotly/data/0/x"],
+ "Y": ["/plotly/data/0/y"],
+ "count": ["/plotly/data/0/z"],
+ },
+ }
+ ]
+
+ self.assertEqual(deephaven["mappings"], expected_mappings)
+
+ self.assertEqual(deephaven["is_user_set_template"], False)
+ self.assertEqual(deephaven["is_user_set_color"], False)
+
+ def test_full_heatmap(self):
+ import src.deephaven.plot.express as dx
+ from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT
+
+ chart = dx.density_heatmap(
+ self.source,
+ x="X",
+ y="Y",
+ z="Z",
+ labels={
+ "X": "Column X",
+ "Y": "Column Y",
+ "Z": "Column Z",
+ "sum of Column Z": "sum",
+ },
+ color_continuous_scale="Magma",
+ range_color=[0, 10],
+ color_continuous_midpoint=5,
+ opacity=0.5,
+ log_x=True,
+ log_y=False,
+ range_x=[0, 1],
+ range_y=[0, 10],
+ range_bins_x=[5, 10],
+ range_bins_y=[5, 10],
+ empty_bin_default=0,
+ histfunc="sum",
+ nbinsx=2,
+ nbinsy=2,
+ title="Test Title",
+ ).to_dict(self.exporter)
+ plotly, deephaven = chart["plotly"], chart["deephaven"]
+
+ # pop template as we currently do not modify it
+ plotly["layout"].pop("template")
+
+ expected_data = [
+ {
+ "coloraxis": "coloraxis",
+ "hovertemplate": "Column X=%{x}<br>Column Y=%{y}<br>sum=%{z}",
+ "opacity": 0.5,
+ "type": "heatmap",
+ "x": [NULL_DOUBLE],
+ "y": [NULL_DOUBLE],
+ "z": [NULL_LONG],
+ }
+ ]
+
+ self.assertEqual(plotly["data"], expected_data)
+
+ expected_layout = {
+ "coloraxis": {
+ "cmax": 10,
+ "cmid": 5,
+ "cmin": 0,
+ "colorbar": {"title": {"text": "sum"}},
+ "colorscale": [
+ [0.0, "#000004"],
+ [0.1111111111111111, "#180f3d"],
+ [0.2222222222222222, "#440f76"],
+ [0.3333333333333333, "#721f81"],
+ [0.4444444444444444, "#9e2f7f"],
+ [0.5555555555555556, "#cd4071"],
+ [0.6666666666666666, "#f1605d"],
+ [0.7777777777777778, "#fd9668"],
+ [0.8888888888888888, "#feca8d"],
+ [1.0, "#fcfdbf"],
+ ],
+ },
+ "title": {"text": "Test Title"},
+ "xaxis": {
+ "anchor": "y",
+ "range": [0, 1],
+ "side": "bottom",
+ "title": {"text": "Column X"},
+ "type": "log",
+ },
+ "yaxis": {
+ "anchor": "x",
+ "range": [0, 10],
+ "side": "left",
+ "title": {"text": "Column Y"},
+ },
+ }
+
+ self.assertEqual(plotly["layout"], expected_layout)
+
+ expected_mappings = [
+ {
+ "data_columns": {
+ "X": ["/plotly/data/0/x"],
+ "Y": ["/plotly/data/0/y"],
+ "sum": ["/plotly/data/0/z"],
+ },
+ "table": 0,
+ }
+ ]
+
+ self.assertEqual(deephaven["mappings"], expected_mappings)
+
+ self.assertEqual(deephaven["is_user_set_template"], False)
+ self.assertEqual(deephaven["is_user_set_color"], False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py
new file mode 100644
index 000000000..4e18dd7e4
--- /dev/null
+++ b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py
@@ -0,0 +1,561 @@
+import unittest
+
+import numpy as np
+import pandas as pd
+
+from ..BaseTest import BaseTestCase, remap_types
+
+
+class HeatmapPreprocessorTestCase(BaseTestCase):
+ def setUp(self) -> None:
+ from deephaven import new_table, merge
+ from deephaven.column import int_col
+
+ self.source = new_table(
+ [
+ int_col("X", [0, 4, 0, 4]),
+ int_col("Y", [0, 0, 4, 4]),
+ int_col("Z", [1, 1, 2, 2]),
+ ]
+ )
+
+ # Variance and standard deviation require at least two rows to be non-null, so stack them
+ # for a resulting var and std of 0 in all grid cells
+ self.var_source = merge([self.source, self.source])
+
+ def tables_equal(self, args, expected_df, t=None, post_process=None) -> None:
+ """
+ Compare the expected dataframe to the actual dataframe generated by the preprocessor
+
+ Args:
+ args: The arguments to pass to the preprocessor
+ expected_df: The expected dataframe
+ t: The table to preprocess, defaults to self.source
+ """
+ from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import (
+ HeatmapPreprocessor,
+ )
+ import deephaven.pandas as dhpd
+
+ if t is None:
+ t = self.source
+
+ args_copy = args.copy()
+
+ heatmap_preprocessor = HeatmapPreprocessor(args_copy)
+
+ new_table_gen = heatmap_preprocessor.preprocess_partitioned_tables([t])
+ new_table, _ = next(new_table_gen)
+
+ new_df = dhpd.to_pandas(new_table)
+
+ if post_process is not None:
+ post_process(new_df)
+
+ self.assertTrue(expected_df.equals(new_df))
+
+ def test_basic_preprocessor(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": None,
+ "histfunc": "count",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.0, 1.0, 3.0, 3.0],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+
+ self.tables_equal(args, expected_df)
+
+ def test_basic_preprocessor_z(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": "Z",
+ "histfunc": "sum",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "sum": [1, 2, 1, 2]}
+ )
+ remap_types(expected_df)
+
+ self.tables_equal(args, expected_df)
+
+ def test_partial_range_preprocessor(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": None,
+ "histfunc": "count",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": [None, 6],
+ "range_bins_y": None,
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.5, 1.5, 4.5, 4.5],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+
+ self.tables_equal(args, expected_df)
+
+ def test_full_preprocessor(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": None,
+ "histfunc": "count",
+ "nbinsx": 2,
+ "nbinsy": 3,
+ "range_bins_x": [1, 5],
+ "range_bins_y": [-1, 5],
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [2.0, 2.0, 2.0, 4.0, 4.0, 4.0],
+ "Y": [0.0, 2.0, 4.0, 0.0, 2.0, 4.0],
+ "count": [0, 0, 0, 1, 0, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count"] = expected_df["count"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ def test_full_preprocessor_z(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": "Z",
+ "histfunc": "avg",
+ "nbinsx": 3,
+ "nbinsy": 2,
+ "range_bins_x": [-1, 5],
+ "range_bins_y": [1, 5],
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.0, 0.0, 2.0, 2.0, 4.0, 4.0],
+ "Y": [2.0, 4.0, 2.0, 4.0, 2.0, 4.0],
+ "avg": [pd.NA, 2.0, pd.NA, pd.NA, pd.NA, 2.0],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["avg"] = expected_df["avg"].astype("Float64")
+
+ self.tables_equal(args, expected_df)
+
+ def test_preprocessor_aggs(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": "Z",
+ "histfunc": "abs_sum",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.0, 1.0, 3.0, 3.0],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "abs_sum": [1, 2, 1, 2],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["abs_sum"] = expected_df["abs_sum"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "avg"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "avg": [1, 2, 1, 2]}
+ )
+ remap_types(expected_df)
+ expected_df["avg"] = expected_df["avg"].astype("Float64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "count"
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.0, 1.0, 3.0, 3.0],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count"] = expected_df["count"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "count_distinct"
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.0, 1.0, 3.0, 3.0],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "count_distinct": [1, 1, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "max"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "max": [1, 2, 1, 2]}
+ )
+ remap_types(expected_df)
+ expected_df["max"] = expected_df["max"].astype("Int32")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "median"
+
+ expected_df = pd.DataFrame(
+ {
+ "X": [1.0, 1.0, 3.0, 3.0],
+ "Y": [1.0, 3.0, 1.0, 3.0],
+ "median": [1, 2, 1, 2],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["median"] = expected_df["median"].astype("Float64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "min"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "min": [1, 2, 1, 2]}
+ )
+ remap_types(expected_df)
+ expected_df["min"] = expected_df["min"].astype("Int32")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "std"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "std": [0, 0, 0, 0]}
+ )
+ remap_types(expected_df)
+ expected_df["std"] = expected_df["std"].astype("Float64")
+
+ self.tables_equal(args, expected_df, self.var_source)
+
+ args["histfunc"] = "sum"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "sum": [1, 2, 1, 2]}
+ )
+ remap_types(expected_df)
+ expected_df["sum"] = expected_df["sum"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "var"
+
+ expected_df = pd.DataFrame(
+ {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "var": [0, 0, 0, 0]}
+ )
+ remap_types(expected_df)
+ expected_df["var"] = expected_df["var"].astype("Float64")
+
+ self.tables_equal(args, expected_df, self.var_source)
+
+ def test_histfunc_z_mismatch(self):
+ from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import (
+ HeatmapPreprocessor,
+ )
+
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": None,
+ "histfunc": "sum",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": None,
+ "table": self.source,
+ }
+
+ heatmap_preprocessor = HeatmapPreprocessor(args)
+
+ new_table_gen = heatmap_preprocessor.preprocess_partitioned_tables(
+ [self.source]
+ )
+
+ self.assertRaises(ValueError, lambda: next(new_table_gen))
+
+ def test_empty_bin_default(self):
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": "Z",
+ "histfunc": "sum",
+ "nbinsx": 4,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": 0,
+ "table": self.source,
+ }
+
+ # sum, 0 - default to 0
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "sum": [1, 2, 0, 0, 0, 0, 1, 2],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["sum"] = expected_df["sum"].astype("Int64")
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = "NaN"
+
+ # sum, "NaN" - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["sum"] = expected_df["sum"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = None
+
+ # sum, None - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["sum"] = expected_df["sum"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "var"
+
+ # var, None - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "var": [np.nan, np.nan, pd.NA, pd.NA, pd.NA, pd.NA, np.nan, np.nan],
+ }
+ )
+ remap_types(expected_df)
+
+ expected_df["var"] = expected_df["var"].astype("Float64")
+
+ def na_conversion(df):
+ # convert to pandas.NA for whole df comparison,
+ # but verify that the nans are there, since they come from bins holding only a single data point
+ arr = df["var"].to_numpy()
+ for i in [0, 1, 6, 7]:
+ self.assertTrue(np.isnan(arr[i]))
+ df.at[i, "var"] = pd.NA
+
+ self.tables_equal(args, expected_df, post_process=na_conversion)
+
+ args["empty_bin_default"] = "NaN"
+
+ # var, "NaN" - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "var": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["var"] = expected_df["var"].astype("Float64")
+
+ self.tables_equal(args, expected_df, post_process=na_conversion)
+
+ args["empty_bin_default"] = 3
+
+ # var, 3 - default to 3 for empty bins, bins with one data point stay NaN
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "var": [pd.NA, pd.NA, 3, 3, 3, 3, pd.NA, pd.NA],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["var"] = expected_df["var"].astype("Float64")
+
+ self.tables_equal(args, expected_df, post_process=na_conversion)
+
+ args["histfunc"] = "count"
+
+ # count, 3 - default to 3
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, 3, 3, 3, 3, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count"] = expected_df["count"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = None
+
+ # count, None - default to 0
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, 0, 0, 0, 0, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count"] = expected_df["count"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = "NaN"
+
+ # count, "NaN" - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count"] = expected_df["count"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["histfunc"] = "count_distinct"
+
+ # count_distinct, "NaN" - no default
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count_distinct": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = None
+
+ # count_distinct, None - default to 0
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count_distinct": [1, 1, 0, 0, 0, 0, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ args["empty_bin_default"] = 2
+
+ # count_distinct, 2 - default to 2
+ expected_df = pd.DataFrame(
+ {
+ "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5],
+ "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0],
+ "count_distinct": [1, 1, 2, 2, 2, 2, 1, 1],
+ }
+ )
+ remap_types(expected_df)
+ expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64")
+
+ self.tables_equal(args, expected_df)
+
+ def test_bad_empty_bin_default(self):
+ from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import (
+ HeatmapPreprocessor,
+ )
+
+ args = {
+ "x": "X",
+ "y": "Y",
+ "z": None,
+ "histfunc": "count",
+ "nbinsx": 2,
+ "nbinsy": 2,
+ "range_bins_x": None,
+ "range_bins_y": None,
+ "empty_bin_default": "bad",
+ "table": self.source,
+ }
+
+ process = HeatmapPreprocessor(args)
+
+ self.assertRaises(
+ ValueError,
+ lambda: next(process.preprocess_partitioned_tables([self.source])),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()