diff --git a/plugins/plotly-express/src/deephaven/plot/express/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/__init__.py index cca071669..6a4694442 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/__init__.py +++ b/plugins/plotly-express/src/deephaven/plot/express/__init__.py @@ -41,6 +41,7 @@ density_mapbox, line_geo, line_mapbox, + density_heatmap, ) from .data import data_generators diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py index ba8e38695..0e019f5e4 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py @@ -3,5 +3,5 @@ DeephavenFigureNode, ) from .generate import generate_figure, update_traces -from .custom_draw import draw_ohlc, draw_candlestick +from .custom_draw import draw_ohlc, draw_candlestick, draw_density_heatmap from .RevisionManager import RevisionManager diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index 21330a98b..4d822dd55 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -6,6 +6,7 @@ from pandas import DataFrame import plotly.graph_objects as go from plotly.graph_objects import Figure +from plotly.validators.heatmap import ColorscaleValidator def draw_finance( @@ -109,3 +110,79 @@ def draw_candlestick( """ return draw_finance(data_frame, x_finance, open, high, low, close, go.Candlestick) + + +def draw_density_heatmap( + data_frame: DataFrame, + x: str, + y: str, + z: str, + labels: dict[str, str] | None = None, + range_color: list[float] | None = None, + color_continuous_scale: str | list[str] | None = 
"plasma", + color_continuous_midpoint: float | None = None, + opacity: float = 1.0, + title: str | None = None, + template: str | None = None, +) -> Figure: + """Create a density heatmap + + Args: + data_frame: The data frame to draw with + x: The name of the column containing x-axis values + y: The name of the column containing y-axis values + z: The name of the column containing bin values + labels: A dictionary of labels mapping columns to new labels + color_continuous_scale: A color scale or list of colors for a continuous scale + range_color: A list of two numbers that form the endpoints of the color axis + color_continuous_midpoint: A number that is the midpoint of the color axis + opacity: Opacity to apply to all markers. 0 is completely transparent + and 1 is completely opaque. + title: The title of the chart + template: The template for the chart. + + Returns: + The plotly density heatmap + + """ + + # currently, most plots rely on px setting several attributes such as coloraxis, opacity, etc. 
+ # so we need to set some things manually + # this could be done with handle_custom_args in generate.py in the future if + # we need to provide more options, but it's much easier to just set it here + # and doesn't risk breaking any other plots + + heatmap = go.Figure( + go.Heatmap( + x=data_frame[x], + y=data_frame[y], + z=data_frame[z], + coloraxis="coloraxis1", + opacity=opacity, + ) + ) + + range_color_list = range_color or [None, None] + + colorscale_validator = ColorscaleValidator("colorscale", "draw_density_heatmap") + + coloraxis_layout = dict( + colorscale=colorscale_validator.validate_coerce(color_continuous_scale), + cmid=color_continuous_midpoint, + cmin=range_color_list[0], + cmax=range_color_list[1], + ) + + if labels: + x = labels.get(x, x) + y = labels.get(y, y) + + heatmap.update_layout( + coloraxis1=coloraxis_layout, + title=title, + template=template, + xaxis_title=x, + yaxis_title=y, + ) + + return heatmap diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 7b174ef9b..63d20ad55 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -116,6 +116,7 @@ "current_partition", "colors", "unsafe_update_figure", + "heatmap_agg_label", } # these are columns that are "attached" sequentially to the traces @@ -669,8 +670,9 @@ def handle_custom_args( elif arg == "bargap" or arg == "rangemode": fig.update_layout({arg: val}) - # x_axis_generators.append(key_val_generator("bargap", [val])) - # y_axis_generators.append(key_val_generator("bargap", [val])) + + elif arg == "heatmap_agg_label": + fig.update_coloraxes(colorbar_title_text=val) trace_generator = combined_generator(trace_generators) @@ -781,7 +783,6 @@ def get_hover_body( def hover_text_generator( hover_mapping: list[dict[str, str]], - # hover_data - todo, 
dependent on arrays supported in data mappings types: set[str] | None = None, current_partition: dict[str, str] | None = None, ) -> Generator[dict[str, Any], None, None]: @@ -824,6 +825,7 @@ def hover_text_generator( def compute_labels( hover_mapping: list[dict[str, str]], hist_val_name: str | None, + heatmap_agg_label: str | None, # hover_data - todo, dependent on arrays supported in data mappings types: set[str], labels: dict[str, str] | None, @@ -836,6 +838,7 @@ def compute_labels( Args: hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc + heatmap_agg_label: The aggregate density heatmap column title types: Any types of this chart that require special processing labels: A dictionary of old column name to new column name mappings current_partition: The columns that this figure is partitioned by @@ -846,9 +849,36 @@ def compute_labels( calculate_hist_labels(hist_val_name, hover_mapping[0]) + calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels) + relabel_columns(labels, hover_mapping, types, current_partition) +def calculate_density_heatmap_labels( + heatmap_agg_label: str | None, + hover_mapping: dict[str, str], + labels: dict[str, str] | None, +) -> None: + """Calculate the labels for a density heatmap + The z column is renamed to the heatmap_agg_label + + Args: + heatmap_agg_label: The name of the heatmap aggregate label + hover_mapping: The mapping of variables to columns + labels: A dictionary of labels mapping columns to new labels. 
+ + """ + labels = labels or {} + if heatmap_agg_label: + # the last part of the label is the z column, and could be replaced by labels + split_label = heatmap_agg_label.split(" ") + split_label[-1] = labels.get(split_label[-1], split_label[-1]) + # it's also possible that someone wants to override the whole label + # plotly doesn't seem to do that, but it seems reasonable to allow + new_label = " ".join(split_label) + hover_mapping["z"] = labels.get(new_label, new_label) + + def calculate_hist_labels( hist_val_name: str | None, current_mapping: dict[str, str] ) -> None: @@ -871,6 +901,7 @@ def add_axis_titles( custom_call_args: dict[str, Any], hover_mapping: list[dict[str, str]], hist_val_name: str | None, + heatmap_agg_label: str | None, ) -> None: """Add axis titles. Generally, this only applies when there is a list variable @@ -879,6 +910,7 @@ def add_axis_titles( create hover and axis titles hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc + heatmap_agg_label: The aggregate density heatmap column title """ # Although hovertext is handled above for all plot types, plotly still @@ -892,6 +924,9 @@ def add_axis_titles( new_xaxis_titles = [hover_mapping[0].get("x", None)] new_yaxis_titles = [hover_mapping[0].get("y", None)] + if heatmap_agg_label: + custom_call_args["heatmap_agg_label"] = heatmap_agg_label + # a specified axis title update should override this if new_xaxis_titles: custom_call_args["xaxis_titles"] = custom_call_args.get( @@ -928,6 +963,9 @@ def create_hover_and_axis_titles( (such as the y-axis if the x-axis is specified). Otherwise, there is a legend or not depending on if there is a list of columns or not. + Density heatmaps are also an exception. If "heatmap_agg_label" is specified, + the z column is renamed to this label. 
+ Args: custom_call_args: The custom_call_args that are used to create hover and axis titles @@ -941,14 +979,26 @@ def create_hover_and_axis_titles( labels = custom_call_args.get("labels", None) hist_val_name = custom_call_args.get("hist_val_name", None) + heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None) current_partition = custom_call_args.get("current_partition", {}) - compute_labels(hover_mapping, hist_val_name, types, labels, current_partition) + compute_labels( + hover_mapping, + hist_val_name, + heatmap_agg_label, + types, + labels, + current_partition, + ) hover_text = hover_text_generator(hover_mapping, types, current_partition) - add_axis_titles(custom_call_args, hover_mapping, hist_val_name) + if heatmap_agg_label: + # it's possible that heatmap_agg_label was relabeled, so grab the new label + heatmap_agg_label = hover_mapping[0]["z"] + + add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label) return hover_text diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py index 2377d0bfa..5d4cd4a99 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py @@ -581,6 +581,7 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]: "preprocess_hist" in self.groups or "preprocess_freq" in self.groups or "preprocess_time" in self.groups + or "preprocess_heatmap" in self.groups ) and self.preprocessor: # still need to preprocess the base table table, arg_update = cast( diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py index 8558d8a14..f87c84086 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py +++ 
b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py @@ -9,3 +9,4 @@ from ._layer import layer from .subplots import make_subplots from .maps import scatter_geo, scatter_mapbox, density_mapbox, line_geo, line_mapbox +from .heatmap import density_heatmap diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py index a3f902977..47108eff3 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py @@ -409,8 +409,8 @@ def histogram( range_y: A list of two numbers that specify the range of the y-axis. range_bins: A list of two numbers that specify the range of data that is used. histfunc: The function to use when aggregating within bins. One of - 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', 'sum', - or 'var' + 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', + 'sum', or 'var' cumulative: If True, values are cumulative. nbins: The number of bins to use. text_auto: If True, display the value at each bar. 
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py new file mode 100644 index 000000000..ba40fc27f --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Callable, Literal + +from deephaven.plot.express.shared import default_callback + +from ._private_utils import process_args +from ..deephaven_figure import DeephavenFigure, draw_density_heatmap +from deephaven.table import Table + + +def density_heatmap( + table: Table, + x: str | None = None, + y: str | None = None, + z: str | None = None, + labels: dict[str, str] | None = None, + color_continuous_scale: str | list[str] | None = None, + range_color: list[float] | None = None, + color_continuous_midpoint: float | None = None, + opacity: float = 1.0, + log_x: bool = False, + log_y: bool = False, + range_x: list[float] | None = None, + range_y: list[float] | None = None, + range_bins_x: list[float | None] | None = None, + range_bins_y: list[float | None] | None = None, + histfunc: str = "count", + nbinsx: int = 10, + nbinsy: int = 10, + empty_bin_default: float | Literal["NaN"] | None = None, + title: str | None = None, + template: str | None = None, + unsafe_update_figure: Callable = default_callback, +) -> DeephavenFigure: + """ + A density heatmap creates a grid of colored bins. Each bin represents an aggregation of data points in that region. + + Args: + table: A table to pull data from. + x: A column that contains x-axis values. + y: A column that contains y-axis values. + z: A column that contains z-axis values. If not provided, the count of joint occurrences of x and y will be used. + labels: A dictionary of labels mapping columns to new labels. 
+ color_continuous_scale: A color scale or list of colors for a continuous scale + range_color: A list of two numbers that form the endpoints of the color axis + color_continuous_midpoint: A number that is the midpoint of the color axis + opacity: Opacity to apply to all markers. 0 is completely transparent + and 1 is completely opaque. + log_x: A boolean that specifies if the corresponding axis is a log axis or not. + log_y: A boolean that specifies if the corresponding axis is a log axis or not. + range_x: A list of two numbers that specify the range of the x axes. + None can be specified for no range + range_y: A list of two numbers that specify the range of the y axes. + None can be specified for no range + range_bins_x: A list of two numbers that specify the range of data that is used for x. + None can be specified to use the min and max of the data. + None can also be specified for either of the list values to use the min or max of the data, respectively. + range_bins_y: A list of two numbers that specify the range of data that is used for y. + None can be specified to use the min and max of the data. + None can also be specified for either of the list values to use the min or max of the data, respectively. + histfunc: The function to use when aggregating within bins. One of + 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', + 'sum', or 'var' + nbinsx: The number of bins to use for the x-axis + nbinsy: The number of bins to use for the y-axis + empty_bin_default: The value to use for bins that have no data. + If None and histfunc is 'count' or 'count_distinct', 0 is used. + Otherwise, if None or 'NaN', NaN is used. + 'NaN' forces the bin to be NaN if no data is present, even if histfunc is 'count' or 'count_distinct'. + Note that if multiple points are required to color a bin, such as the case for a histfunc of 'std' or 'var', + the bin will still be NaN if fewer than the required number of points are present.
+ title: The title of the chart + template: The template for the chart. + unsafe_update_figure: An update function that takes a plotly figure + as an argument and optionally returns a plotly figure. If a figure is + not returned, the plotly figure passed will be assumed to be the return + value. Used to add any custom changes to the underlying plotly figure. + Note that the existing data traces should not be removed. This may lead + to unexpected behavior if traces are modified in a way that break data + mappings. + + + + Returns: + DeephavenFigure: A DeephavenFigure that contains the density heatmap + """ + args = locals() + + return process_args(args, {"preprocess_heatmap"}, px_func=draw_density_heatmap) diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py new file mode 100644 index 000000000..241dc44cb --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import Any, Generator + +from deephaven import new_table +from deephaven.column import long_col + +from ..shared import get_unique_names +from .utilities import ( + create_range_table, + validate_heatmap_histfunc, + create_tmp_view, + aggregate_heatmap_bins, + calculate_bin_locations, +) +from deephaven.table import Table + + +class HeatmapPreprocessor: + """ + Preprocessor for heatmaps. 
+ + Attributes: + args: dict[str, Any]: The arguments used to create the plot + histfunc: str: The histfunc to use + nbinsx: int: The number of bins in the x direction + nbinsy: int: The number of bins in the y direction + range_bins_x: list[float | None]: The range of the x bins + range_bins_y: list[float | None]: The range of the y bins + names: dict[str, str]: A mapping of ideal name to unique names + Also contains the names of the x, y, and z columns for ease of use + """ + + def __init__(self, args: dict[str, Any]): + self.args = args + self.histfunc = args.pop("histfunc") + self.nbinsx = args.pop("nbinsx") + self.nbinsy = args.pop("nbinsy") + self.range_bins_x = args.pop("range_bins_x") + self.range_bins_y = args.pop("range_bins_y") + self.empty_bin_default = args.pop("empty_bin_default") + if ( + self.histfunc in {"count", "count_distinct"} + and self.empty_bin_default is None + ): + self.empty_bin_default = 0 + # create unique names for the columns to ensure no collisions + self.names = get_unique_names( + self.args["table"], + [ + "range_index_x", + "range_index_y", + "range_x", + "range_y", + "bin_min_x", + "bin_max_x", + "bin_min_y", + "bin_max_y", + "tmp_x", + "tmp_y", + "agg_col", + self.histfunc, + ], + ) + + # add the column names to names as well for ease of use + self.names.update( + { + "x": self.args["x"], + "y": self.args["y"], + "z": self.args["z"], + } + ) + + def preprocess_partitioned_tables( + self, tables: list[Table], column: str | None = None + ) -> Generator[tuple[Table, dict[str, str | None]], None, None]: + """ + Preprocess params into an appropriate table + + Args: + tables: a list of tables to preprocess + column: the column to aggregate on + ignored for this preprocessor because heatmap always gets the joint count + distribution of x and y or the histfunc of z depending on if z is provided + + Returns: + A tuple containing (the new table, an update to make to the args) + The update should contain the z column name and the 
heatmap_agg_label + which is the histfunc of z if z is not None, otherwise just the histfunc + + """ + + range_index_x = self.names["range_index_x"] + range_index_y = self.names["range_index_y"] + range_x = self.names["range_x"] + range_y = self.names["range_y"] + histfunc_col = self.names[self.histfunc] + x = self.names["x"] + y = self.names["y"] + z = self.names["z"] + + validate_heatmap_histfunc(z, self.histfunc) + + # there will only be one table, so we can just grab the first one + table = tables[0] + + range_table_x = create_range_table( + table, x, self.range_bins_x, self.nbinsx, range_x + ) + range_table_y = create_range_table( + table, y, self.range_bins_y, self.nbinsy, range_y + ) + range_table = range_table_x.join(range_table_y) + + # ensure that all possible bins are created so that the rendered chart draws spaces for empty bins + bin_counts_x = new_table( + [long_col(range_index_x, [i for i in range(self.nbinsx)])] + ) + bin_counts_y = new_table( + [long_col(range_index_y, [i for i in range(self.nbinsy)])] + ) + bin_counts = bin_counts_x.join(bin_counts_y) + + tmp_view = create_tmp_view(self.names) + + # filter to only the tmp (data) columns, and join the range table to the tmp + ranged_tmp_view = table.view(tmp_view).join(range_table) + + agg_table = aggregate_heatmap_bins(ranged_tmp_view, self.names, self.histfunc) + + # join the aggregated values to the already created comprehensive bin table + bin_counts = bin_counts.natural_join( + agg_table, on=[range_index_x, range_index_y], joins=[self.names["agg_col"]] + ) + + # join the range table to the bin counts - this is needed because the ranges were dropped in the aggregation + ranged_bin_counts = bin_counts.join(range_table) + + bin_counts_with_midpoint = calculate_bin_locations( + ranged_bin_counts, self.names, histfunc_col, self.empty_bin_default + ) + + heatmap_agg_label = f"{self.histfunc} of {z}" if z else self.histfunc + + yield bin_counts_with_midpoint.view([x, y, histfunc_col]), { + "z": 
histfunc_col, + "heatmap_agg_label": heatmap_agg_label, + } diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py index a285185c8..197b0b73d 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py @@ -9,19 +9,7 @@ from ..shared import get_unique_names from deephaven.column import long_col from deephaven.updateby import cum_sum - -# Used to aggregate within histogram bins -HISTFUNC_MAP = { - "avg": agg.avg, - "count": agg.count_, - "count_distinct": agg.count_distinct, - "max": agg.max_, - "median": agg.median, - "min": agg.min_, - "std": agg.std, - "sum": agg.sum_, - "var": agg.var, -} +from .utilities import create_range_table, HISTFUNC_AGGS def get_aggs( @@ -83,42 +71,14 @@ def prepare_preprocess(self) -> None: self.args["table"], ["range_index", "range", "bin_min", "bin_max", self.histfunc, "total"], ) - self.range_table = self.create_range_table() - - def create_range_table(self) -> Table: - """ - Create a table that contains the bin ranges - - Returns: - A table containing the bin ranges - """ - # partitioned tables need range calculated on all - table = ( - self.table.merge() - if isinstance(self.table, PartitionedTable) - else self.table + self.range_table = create_range_table( + self.args["table"], + self.cols, + self.range_bins, + self.nbins, + self.names["range"], ) - if self.range_bins: - range_min = self.range_bins[0] - range_max = self.range_bins[1] - table = empty_table(1) - else: - range_min = "RangeMin" - range_max = "RangeMax" - # need to find range across all columns - min_aggs, min_cols = get_aggs("RangeMin", self.cols) - max_aggs, max_cols = get_aggs("RangeMax", self.cols) - table = table.agg_by([agg.min_(min_aggs), agg.max_(max_aggs)]).update( - [f"RangeMin = min({min_cols})", f"RangeMax = 
max({max_cols})"] - ) - - return table.update( - f"{self.names['range']} = new io.deephaven.plot.datasets.histogram." - f"DiscretizedRangeEqual({range_min},{range_max}, " - f"{self.nbins})" - ).view(self.names["range"]) - def create_count_tables( self, tables: list[Table], column: str | None = None ) -> Generator[tuple[Table, str], None, None]: @@ -134,7 +94,7 @@ def create_count_tables( """ range_index, range_ = self.names["range_index"], self.names["range"] - agg_func = HISTFUNC_MAP[self.histfunc] + agg_func = HISTFUNC_AGGS[self.histfunc] if not self.range_table: raise ValueError("Range table not created") for i, table in enumerate(tables): diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py index 808c5a75f..1cba930a0 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py @@ -9,6 +9,7 @@ from .FreqPreprocessor import FreqPreprocessor from .HistPreprocessor import HistPreprocessor from .TimePreprocessor import TimePreprocessor +from .HeatmapPreprocessor import HeatmapPreprocessor class Preprocessor: @@ -54,6 +55,8 @@ def prepare_preprocess(self) -> None: AttachedPreprocessor(self.args, self.always_attached) elif "preprocess_time" in self.groups: self.preprocesser = TimePreprocessor(self.args) + elif "preprocess_heatmap" in self.groups: + self.preprocesser = HeatmapPreprocessor(self.args) def preprocess_partitioned_tables( self, tables: list[Table] | None, column: str | None = None diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py new file mode 100644 index 000000000..4c6ac80aa --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py @@ -0,0 +1,283 @@ +from __future__ import annotations 
+ +from typing import Generator, Literal + +from deephaven import agg, empty_table +from deephaven.plot.express.shared import get_unique_names +from deephaven.table import PartitionedTable, Table + +# Used to aggregate within bins +HISTFUNC_AGGS = { + "abs_sum": agg.abs_sum, + "avg": agg.avg, + "count": agg.count_, + "count_distinct": agg.count_distinct, + "max": agg.max_, + "median": agg.median, + "min": agg.min_, + "std": agg.std, + "sum": agg.sum_, + "var": agg.var, +} + + +def get_aggs( + base: str, + columns: list[str], +) -> tuple[list[str], str]: + """Create aggregations over all columns + + Args: + base: + The base of the new columns that store the agg per column + columns: + All columns joined for the sake of taking min or max over + the columns + + Returns: + A tuple containing (a list of the new columns, + a joined string of "NewCol, NewCol2...") + + """ + return ( + [f"{base}{column}={column}" for column in columns], + ", ".join([f"{base}{column}" for column in columns]), + ) + + +def single_table(table: Table | PartitionedTable) -> Table: + """ + Merge a table if it is partitioned table + + Args: + table: The table to merge + + Returns: + The table if it is not a partitioned table, otherwise the merged table + """ + return table.merge() if isinstance(table, PartitionedTable) else table + + +def discretized_range_view( + table: Table, + range_min: float | str, + range_max: float | str, + nbins: int, + range_name: str, +) -> Table: + """ + Create a discretized range view that can be joined with a table to compute indices + + Args: + table: The table to create the range view from + range_min: The minimum value of the range. Can be a number or a column name + range_max: The maximum value of the range. 
Can be a number or a column name + nbins: The number of bins to create + range_name: The name of the range object in the resulting table + + Returns: + A table that contains the range object for the given table + """ + + return table.update( + f"{range_name} = new io.deephaven.plot.datasets.histogram." + f"DiscretizedRangeEqual({range_min},{range_max}, " + f"{nbins})" + ).view(range_name) + + +def create_range_table( + table: Table, + cols: str | list[str], + range_bins: list[float | None] | None, + nbins: int, + range_name: str, +) -> Table: + """ + Create single row tables with range objects that can compute bin membership + + Args: + table: The table to create the range table from + cols: The columns to create the range table from. The resulting range table will have + its range calculated over all of these columns. + range_bins: The range to create the bins over. + If None, the range will be calculated over the columns. + If a list of two numbers, the range will be set to these numbers. + The values within this list can also be None, in which case the range will be calculated over the columns + for whichever value is None. 
+ nbins: The number of bins to create + range_name: The name of the range object in the resulting table + + Returns: + A table that contains the range object for the given table + """ + + cols = [cols] if isinstance(cols, str) else cols + + range_min = ( + range_bins[0] if range_bins and range_bins[0] is not None else "RangeMin" + ) + range_max = ( + range_bins[1] if range_bins and range_bins[1] is not None else "RangeMax" + ) + + min_table = empty_table(1) + max_table = empty_table(1) + + if range_min == "RangeMin": + min_aggs, min_cols = get_aggs("RangeMin", cols) + min_table = table.agg_by([agg.min_(min_aggs)]).update( + [f"RangeMin = min({min_cols})"] + ) + if range_max == "RangeMax": + max_aggs, max_cols = get_aggs("RangeMax", cols) + max_table = table.agg_by([agg.max_(max_aggs)]).update( + [f"RangeMax = max({max_cols})"] + ) + + return discretized_range_view( + min_table.join(max_table), range_min, range_max, nbins, range_name + ) + + +def validate_heatmap_histfunc(z: str | None, histfunc: str) -> None: + """ + Check if the histfunc is valid + + Args: + z: The column that contains z-axis values. + histfunc: The function to use when aggregating within bins. Should be 'count' if z is None.
+ + Raises: + ValueError: If the histfunc is not valid + """ + if z is None and histfunc != "count": + raise ValueError("z must be specified for histfunc other than count") + elif histfunc not in HISTFUNC_AGGS: + raise ValueError(f"{histfunc} is not a valid histfunc") + + +def create_tmp_view( + names: dict[str, str], +) -> list[str]: + """ + Create a temporary view that avoids column name collisions + + Args: + names: The names used for columns so that they don't collide + + Returns: + A list of strings that are used to create a temporary view + """ + + x = names["x"] + y = names["y"] + z = names["z"] + tmp_x = names["tmp_x"] + tmp_y = names["tmp_y"] + agg_col = names["agg_col"] + + tmp_view = [f"{tmp_x} = {x}", f"{tmp_y} = {y}"] + + if z is not None: + tmp_view.append(f"{agg_col} = {z}") + else: + # if z is not specified, just count the number of occurrences, so tmp_x or tmp_y can be used + names["agg_col"] = tmp_x + + return tmp_view + + +def aggregate_heatmap_bins( + table: Table, + names: dict[str, str], + histfunc: str, +) -> Table: + """ + Create count tables that aggregate up values into bins + + Args: + table: The table to aggregate. Should contain the tmp data columns and the range columns + names: The names used for columns so that they don't collide + histfunc: The function to use when aggregating within bins. Should be 'count' if z is None. 
+ + Returns: + A table that contains the aggregated values for each pair of x and y bin indices + """ + + range_x = names["range_x"] + range_y = names["range_y"] + range_index_x = names["range_index_x"] + range_index_y = names["range_index_y"] + + tmp_x = names["tmp_x"] + tmp_y = names["tmp_y"] + agg_col = names["agg_col"] + + count_table = ( + table.update_view( + [ + f"{range_index_x} = {range_x}.index({tmp_x})", + f"{range_index_y} = {range_y}.index({tmp_y})", + ] + ) + .where([f"!isNull({range_index_x})", f"!isNull({range_index_y})"]) + .agg_by([HISTFUNC_AGGS[histfunc](agg_col)], [range_index_x, range_index_y]) + ) + return count_table + + +def calculate_bin_locations( + ranged_bin_counts: Table, + names: dict[str, str], + histfunc_col: str, + empty_bin_default: float | Literal["NaN"] | None, +) -> Table: + """ + Compute the center of the bins for the x and y axes + plotly requires the center of the bins to plot the heatmap and will calculate the width automatically + + Args: + ranged_bin_counts: A table that contains the bin counts and the range columns + names: The names used for columns so that they don't collide + histfunc_col: The column that contains the aggregated values + empty_bin_default: The default value to use for bins that have no data + + Returns: + A table that contains the bin counts and the center of the bins + """ + range_index_x = names["range_index_x"] + range_index_y = names["range_index_y"] + range_x = names["range_x"] + range_y = names["range_y"] + bin_min_x = names["bin_min_x"] + bin_max_x = names["bin_max_x"] + bin_min_y = names["bin_min_y"] + bin_max_y = names["bin_max_y"] + x = names["x"] + y = names["y"] + agg_col = names["agg_col"] + + # both "NaN" and None require no replacement + # it is assumed that empty_bin_default has already been set to a number + # if needed, such as in the case of a histfunc of count or count_distinct + + if isinstance(empty_bin_default, str) and empty_bin_default != "NaN": + raise
ValueError("empty_bin_default must be 'NaN' if it is a string") + + if empty_bin_default not in {"NaN", None}: + ranged_bin_counts = ranged_bin_counts.update_view( + f"{agg_col} = replaceIfNull({agg_col}, {empty_bin_default})" + ) + + return ranged_bin_counts.update_view( + [ + f"{bin_min_x} = {range_x}.binMin({range_index_x})", + f"{bin_max_x} = {range_x}.binMax({range_index_x})", + f"{x}=0.5*({bin_min_x}+{bin_max_x})", + f"{bin_min_y} = {range_y}.binMin({range_index_y})", + f"{bin_max_y} = {range_y}.binMax({range_index_y})", + f"{y}=0.5*({bin_min_y}+{bin_max_y})", + f"{histfunc_col} = {agg_col}", + ] + ) diff --git a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py index d043c39c0..6b1a88315 100644 --- a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py +++ b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py @@ -1,9 +1,26 @@ import unittest from unittest.mock import patch +import pandas as pd from deephaven_server import Server +def remap_types( + df: pd.DataFrame, +) -> None: + """ + Remap the types of the columns to the correct types + + Args: + df: The dataframe to remap the types of + """ + for col in df.columns: + if df[col].dtype == "int64": + df[col] = df[col].astype("Int64") + elif df[col].dtype == "float64": + df[col] = df[col].astype("Float64") + + class BaseTestCase(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py new file mode 100644 index 000000000..bc5ad0b1b --- /dev/null +++ b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py @@ -0,0 +1,321 @@ +import unittest + +from ..BaseTest import BaseTestCase + + +class HeatmapTestCase(BaseTestCase): + def setUp(self) -> None: + from deephaven import new_table + from deephaven.column import int_col + + self.source = 
new_table( + [ + int_col("X", [1, 2, 3, 4, 5]), + int_col("Y", [1, 2, 3, 4, 5]), + int_col("Z", [1, 2, 3, 4, 5]), + ] + ) + + def test_basic_heatmap(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE + + chart = dx.density_heatmap(self.source, x="X", y="Y").to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "X=%{x}
Y=%{y}
count=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count"}}, + "colorscale": [ + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + "table": 0, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_heatmap_relabel_z(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={"X": "Column X", "Y": "Column Y", "Z": "Column Z"}, + ).to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Column X=%{x}
Column Y=%{y}
count of Column Z=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count of Column Z"}}, + "colorscale": [ + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "table": 0, + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_heatmap_relabel_agg_z(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={"X": "Column X", "Y": "Column Y", "count of Z": "count"}, + ).to_dict(self.exporter) + + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Column X=%{x}
Column Y=%{y}
count=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count"}}, + "colorscale": [ + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "table": 0, + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_full_heatmap(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={ + "X": "Column X", + "Y": "Column Y", + "Z": "Column Z", + "sum of Column Z": "sum", + }, + color_continuous_scale="Magma", + range_color=[0, 10], + color_continuous_midpoint=5, + opacity=0.5, + log_x=True, + log_y=False, + range_x=[0, 1], + range_y=[0, 10], + range_bins_x=[5, 10], + range_bins_y=[5, 10], + empty_bin_default=0, + histfunc="sum", + nbinsx=2, + nbinsy=2, + title="Test Title", + ).to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + 
"coloraxis": "coloraxis", + "hovertemplate": "Column X=%{x}
Column Y=%{y}
sum=%{z}", + "opacity": 0.5, + "type": "heatmap", + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + } + ] + + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "cmax": 10, + "cmid": 5, + "cmin": 0, + "colorbar": {"title": {"text": "sum"}}, + "colorscale": [ + [0.0, "#000004"], + [0.1111111111111111, "#180f3d"], + [0.2222222222222222, "#440f76"], + [0.3333333333333333, "#721f81"], + [0.4444444444444444, "#9e2f7f"], + [0.5555555555555556, "#cd4071"], + [0.6666666666666666, "#f1605d"], + [0.7777777777777778, "#fd9668"], + [0.8888888888888888, "#feca8d"], + [1.0, "#fcfdbf"], + ], + }, + "title": {"text": "Test Title"}, + "xaxis": { + "anchor": "y", + "range": [0, 1], + "side": "bottom", + "title": {"text": "Column X"}, + "type": "log", + }, + "yaxis": { + "anchor": "x", + "range": [0, 10], + "side": "left", + "title": {"text": "Column Y"}, + }, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "sum": ["/plotly/data/0/z"], + }, + "table": 0, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + +if __name__ == "__main__": + unittest.main() diff --git a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py new file mode 100644 index 000000000..4e18dd7e4 --- /dev/null +++ b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py @@ -0,0 +1,561 @@ +import unittest + +import numpy as np +import pandas as pd + +from ..BaseTest import BaseTestCase, remap_types + + +class HeatmapPreprocessorTestCase(BaseTestCase): + def setUp(self) -> None: + from deephaven import new_table, merge + from deephaven.column 
def tables_equal(self, args, expected_df, t=None, post_process=None) -> None:
    """
    Run the heatmap preprocessor and assert its output equals the expected dataframe

    Args:
        args: The arguments to pass to the preprocessor. A shallow copy is
          handed to the preprocessor rather than the dictionary itself.
        expected_df: The expected dataframe
        t: The table to preprocess, defaults to self.source
        post_process: Optional callable invoked with the resulting dataframe
          before the comparison. Its return value is ignored, so it has to
          modify the dataframe in place.
    """
    from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import (
        HeatmapPreprocessor,
    )
    import deephaven.pandas as dhpd

    table = self.source if t is None else t

    preprocessor = HeatmapPreprocessor(args.copy())

    # only the first pair yielded by the preprocessor is inspected
    processed, _ = next(preprocessor.preprocess_partitioned_tables([table]))
    actual_df = dhpd.to_pandas(processed)

    if post_process is not None:
        post_process(actual_df)

    self.assertTrue(expected_df.equals(actual_df))
def test_preprocessor_aggs(self):
    """
    Check the aggregated column produced by the preprocessor for each histfunc

    Every case bins the source table into a 2x2 grid and compares the
    resulting X, Y and aggregated columns against an expected dataframe.
    Each histfunc runs in its own subTest so that a failure reports which
    histfunc broke and the remaining histfuncs are still exercised.
    """
    args = {
        "x": "X",
        "y": "Y",
        "z": "Z",
        "histfunc": "abs_sum",
        "nbinsx": 2,
        "nbinsy": 2,
        "range_bins_x": None,
        "range_bins_y": None,
        "empty_bin_default": None,
        "table": self.source,
    }

    # (histfunc, expected aggregated values, expected pandas dtype, table)
    # a table of None lets tables_equal fall back to self.source
    cases = [
        ("abs_sum", [1, 2, 1, 2], "Int64", None),
        ("avg", [1, 2, 1, 2], "Float64", None),
        ("count", [1, 1, 1, 1], "Int64", None),
        ("count_distinct", [1, 1, 1, 1], "Int64", None),
        ("max", [1, 2, 1, 2], "Int32", None),
        ("median", [1, 2, 1, 2], "Float64", None),
        ("min", [1, 2, 1, 2], "Int32", None),
        # std and var are checked against the stacked var_source table
        ("std", [0, 0, 0, 0], "Float64", self.var_source),
        ("sum", [1, 2, 1, 2], "Int64", None),
        ("var", [0, 0, 0, 0], "Float64", self.var_source),
    ]

    for histfunc, values, dtype, table in cases:
        with self.subTest(histfunc=histfunc):
            args["histfunc"] = histfunc

            # the aggregated column is named after the histfunc
            expected_df = pd.DataFrame(
                {
                    "X": [1.0, 1.0, 3.0, 3.0],
                    "Y": [1.0, 3.0, 1.0, 3.0],
                    histfunc: values,
                }
            )
            remap_types(expected_df)
            expected_df[histfunc] = expected_df[histfunc].astype(dtype)

            self.tables_equal(args, expected_df, table)
"sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2], + } + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # sum, None - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2], + } + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "var" + + # var, None - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [np.nan, np.nan, pd.NA, pd.NA, pd.NA, pd.NA, np.nan, np.nan], + } + ) + remap_types(expected_df) + + expected_df["var"] = expected_df["var"].astype("Float64") + + def na_conversion(df): + # convert to pandas.NA for whole df comparison, + # but verify that the nans are there because they come from there being no data in the bin + arr = df["var"].to_numpy() + for i in [0, 1, 6, 7]: + self.assertTrue(np.isnan(arr[i])) + df.at[i, "var"] = pd.NA + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["empty_bin_default"] = "NaN" + + # var, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA], + } + ) + remap_types(expected_df) + expected_df["var"] = expected_df["var"].astype("Float64") + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["empty_bin_default"] = 3 + + # var, 1 - default to 1 if at least two data points + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [pd.NA, pd.NA, 3, 3, 3, 3, pd.NA, pd.NA], + } + ) + 
remap_types(expected_df) + expected_df["var"] = expected_df["var"].astype("Float64") + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["histfunc"] = "count" + + # count, 3 - default to 3 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 3, 3, 3, 3, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # count, None - default to 0 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 0, 0, 0, 0, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = "NaN" + + # count, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "count_distinct" + + # count_distinct, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # count_distinct, None - default to 0 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, 0, 0, 0, 0, 1, 1], + } + ) + 
def test_bad_empty_bin_default(self):
    """
    An empty_bin_default of "bad" must raise a ValueError

    The error is expected when the first preprocessed table is requested
    from the generator, so both the call and the next() are guarded.
    """
    from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import (
        HeatmapPreprocessor,
    )

    args = dict(
        x="X",
        y="Y",
        z=None,
        histfunc="count",
        nbinsx=2,
        nbinsy=2,
        range_bins_x=None,
        range_bins_y=None,
        empty_bin_default="bad",
        table=self.source,
    )

    process = HeatmapPreprocessor(args)

    with self.assertRaises(ValueError):
        next(process.preprocess_partitioned_tables([self.source]))