From 0ffec96fbf0b5b0c1f49f8bbaafad3f94c18755a Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 09:21:55 -0500 Subject: [PATCH 01/11] wip --- .../src/deephaven/plot/express/__init__.py | 1 + .../plot/express/deephaven_figure/__init__.py | 2 +- .../express/deephaven_figure/custom_draw.py | 113 +++++-- .../plot/express/deephaven_figure/generate.py | 36 ++- .../plot/express/plots/PartitionManager.py | 1 + .../deephaven/plot/express/plots/__init__.py | 1 + .../plot/express/plots/distribution.py | 4 +- .../deephaven/plot/express/plots/heatmap.py | 83 ++++++ .../express/preprocess/HeatmapPreprocessor.py | 130 +++++++++ .../express/preprocess/HistPreprocessor.py | 56 +--- .../plot/express/preprocess/Preprocessor.py | 3 + .../plot/express/preprocess/utilities.py | 276 ++++++++++++++++++ 12 files changed, 630 insertions(+), 76 deletions(-) create mode 100644 plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py create mode 100644 plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py create mode 100644 plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py diff --git a/plugins/plotly-express/src/deephaven/plot/express/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/__init__.py index cca071669..6a4694442 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/__init__.py +++ b/plugins/plotly-express/src/deephaven/plot/express/__init__.py @@ -41,6 +41,7 @@ density_mapbox, line_geo, line_mapbox, + density_heatmap, ) from .data import data_generators diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py index ba8e38695..0e019f5e4 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/__init__.py @@ -3,5 +3,5 @@ DeephavenFigureNode, ) from .generate 
import generate_figure, update_traces -from .custom_draw import draw_ohlc, draw_candlestick +from .custom_draw import draw_ohlc, draw_candlestick, draw_density_heatmap from .RevisionManager import RevisionManager diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index 21330a98b..e4e562f6e 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -6,16 +6,17 @@ from pandas import DataFrame import plotly.graph_objects as go from plotly.graph_objects import Figure +from plotly.validators.heatmap import ColorscaleValidator def draw_finance( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | list[str], - go_func: Callable, + data_frame: DataFrame, + x_finance: str | list[str], + open: str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], + go_func: Callable, ) -> Figure: """Draws a finance (OHLC or candlestick) chart @@ -33,7 +34,7 @@ def draw_finance( """ if not all(len(open) == len(ls) for ls in [high, low, close]) and ( - len(open) == len(x_finance) or len(x_finance) == 1 + len(open) == len(x_finance) or len(x_finance) == 1 ): raise ValueError( "open, high, low, close must have same length and x " @@ -43,7 +44,7 @@ def draw_finance( data = [] for x_f, o, h, l, c in zip_longest( - x_finance, open, high, low, close, fillvalue=x_finance[0] + x_finance, open, high, low, close, fillvalue=x_finance[0] ): data.append( go_func( @@ -59,12 +60,12 @@ def draw_finance( def draw_ohlc( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | list[str], + data_frame: DataFrame, + x_finance: str | list[str], + open: 
str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], ) -> Figure: """Create a plotly OHLC chart. @@ -85,12 +86,12 @@ def draw_ohlc( def draw_candlestick( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | list[str], + data_frame: DataFrame, + x_finance: str | list[str], + open: str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], ) -> Figure: """Create a plotly candlestick chart. @@ -109,3 +110,73 @@ def draw_candlestick( """ return draw_finance(data_frame, x_finance, open, high, low, close, go.Candlestick) + + +def draw_density_heatmap( + data_frame: DataFrame, + x: str, + y: str, + z: str, + range_color: list[float] | None = None, + color_continuous_scale: str = "Viridis", + color_continuous_midpoint=None, + opacity=1.0, + title=None, + template=None, +) -> Figure: + """Create a density heatmap + + Args: + data_frame: The data frame to draw with + x: The name of the column containing x-axis values + y: The name of the column containing y-axis values + z: The name of the column containing bin values + color_continuous_scale: A list of colors for a continuous scale + range_color: A list of two numbers that form the endpoints of the color axis + color_continuous_midpoint: A number that is the midpoint of the color axis + opacity: Opacity to apply to all markers. 0 is completely transparent + and 1 is completely opaque. + title: The title of the chart + template: The template for the chart. + + Returns: + The plotly density heatmap + + """ + + # currently, most plots rely on px setting several attributes such as coloraxis, opacity, etc. 
+ # so we need to set some things manually + # this could be done with handle_custom_args in generate.py in the future if + # we need to provide more options, but it's much easier to just set it here + # and doesn't risk breaking any other plots + + heatmap = go.Figure( + go.Heatmap( + x=data_frame[x], + y=data_frame[y], + z=data_frame[z], + coloraxis="coloraxis1", + opacity=opacity, + ) + ) + + range_color = range_color or [None, None] + + colorscale_validator = ColorscaleValidator("colorscale", "make_figure") + + coloraxis_layout = dict( + colorscale=colorscale_validator.validate_coerce(color_continuous_scale), + cmid=color_continuous_midpoint, + cmin=range_color[0], + cmax=range_color[1], + ) + + heatmap.update_layout( + coloraxis1=coloraxis_layout, + title=title, + template=template, + xaxis_title=x, + yaxis_title=y, + ) + + return heatmap diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 7b174ef9b..97eaa3b6e 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -116,6 +116,7 @@ "current_partition", "colors", "unsafe_update_figure", + "heatmap_title", } # these are columns that are "attached" sequentially to the traces @@ -669,8 +670,9 @@ def handle_custom_args( elif arg == "bargap" or arg == "rangemode": fig.update_layout({arg: val}) - # x_axis_generators.append(key_val_generator("bargap", [val])) - # y_axis_generators.append(key_val_generator("bargap", [val])) + + elif arg == "heatmap_title": + fig.update_coloraxes(colorbar_title_text=val) trace_generator = combined_generator(trace_generators) @@ -824,6 +826,7 @@ def hover_text_generator( def compute_labels( hover_mapping: list[dict[str, str]], hist_val_name: str | None, + heatmap_title: str | None, # hover_data - todo, dependent on arrays supported in data 
mappings types: set[str], labels: dict[str, str] | None, @@ -836,6 +839,7 @@ def compute_labels( Args: hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc + heatmap_title: The aggregate density heatmap column title types: Any types of this chart that require special processing labels: A dictionary of old column name to new column name mappings current_partition: The columns that this figure is partitioned by @@ -846,9 +850,27 @@ def compute_labels( calculate_hist_labels(hist_val_name, hover_mapping[0]) + calculate_density_heatmap_labels(heatmap_title, hover_mapping[0]) + relabel_columns(labels, hover_mapping, types, current_partition) +def calculate_density_heatmap_labels( + heatmap_title: str | None, + hover_mapping: dict[str, str], +) -> None: + """Calculate the labels for a density heatmap + The z column is renamed to the colorbar title + + Args: + heatmap_title: The title of the colorbar + hover_mapping: The mapping of variables to columns + + """ + if heatmap_title: + hover_mapping["z"] = heatmap_title + + def calculate_hist_labels( hist_val_name: str | None, current_mapping: dict[str, str] ) -> None: @@ -871,6 +893,7 @@ def add_axis_titles( custom_call_args: dict[str, Any], hover_mapping: list[dict[str, str]], hist_val_name: str | None, + heatmap_title: str | None, ) -> None: """Add axis titles. 
Generally, this only applies when there is a list variable @@ -879,6 +902,7 @@ def add_axis_titles( create hover and axis titles hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc + heatmap_title: The aggregate density heatmap column title """ # Although hovertext is handled above for all plot types, plotly still @@ -892,6 +916,9 @@ def add_axis_titles( new_xaxis_titles = [hover_mapping[0].get("x", None)] new_yaxis_titles = [hover_mapping[0].get("y", None)] + if heatmap_title: + custom_call_args["heatmap_title"] = heatmap_title + # a specified axis title update should override this if new_xaxis_titles: custom_call_args["xaxis_titles"] = custom_call_args.get( @@ -941,14 +968,15 @@ def create_hover_and_axis_titles( labels = custom_call_args.get("labels", None) hist_val_name = custom_call_args.get("hist_val_name", None) + heatmap_title = custom_call_args.get("heatmap_title", None) current_partition = custom_call_args.get("current_partition", {}) - compute_labels(hover_mapping, hist_val_name, types, labels, current_partition) + compute_labels(hover_mapping, hist_val_name, heatmap_title, types, labels, current_partition) hover_text = hover_text_generator(hover_mapping, types, current_partition) - add_axis_titles(custom_call_args, hover_mapping, hist_val_name) + add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_title) return hover_text diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py index 2377d0bfa..5d4cd4a99 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/PartitionManager.py @@ -581,6 +581,7 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]: "preprocess_hist" in self.groups or "preprocess_freq" in self.groups or "preprocess_time" in 
self.groups + or "preprocess_heatmap" in self.groups ) and self.preprocessor: # still need to preprocess the base table table, arg_update = cast( diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py index 8558d8a14..f87c84086 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/__init__.py @@ -9,3 +9,4 @@ from ._layer import layer from .subplots import make_subplots from .maps import scatter_geo, scatter_mapbox, density_mapbox, line_geo, line_mapbox +from .heatmap import density_heatmap diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py index a3f902977..47108eff3 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/distribution.py @@ -409,8 +409,8 @@ def histogram( range_y: A list of two numbers that specify the range of the y-axis. range_bins: A list of two numbers that specify the range of data that is used. histfunc: The function to use when aggregating within bins. One of - 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', 'sum', - or 'var' + 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', + 'sum', or 'var' cumulative: If True, values are cumulative. nbins: The number of bins to use. text_auto: If True, display the value at each bar. 
diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py new file mode 100644 index 000000000..9bee2ce14 --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Callable + +from deephaven.plot.express.shared import default_callback + +from ._private_utils import process_args +from ..deephaven_figure import DeephavenFigure, draw_density_heatmap +from deephaven.table import Table + + +def density_heatmap( + table: Table, + x: str | None = None, + y: str | None = None, + z: str | None = None, + labels: dict[str, str] = None, + color_continuous_scale: str = "Viridis", + range_color: list[float] = None, + color_continuous_midpoint: float = None, + opacity: float = 1.0, + log_x: bool = False, + log_y: bool = False, + range_x: list[float] | None = None, + range_y: list[float] | None = None, + range_bins_x: list[float | None] = None, + range_bins_y: list[float | None] = None, + histfunc: str = "count", + nbinsx: int = 10, + nbinsy: int = 10, + title: str = None, + template: str = None, + unsafe_update_figure: Callable = default_callback, +) -> DeephavenFigure: + """ + Create a density heatmap + + Args: + table: A table to pull data from. + x: A column that contains x-axis values. + y: A column that contains y-axis values. + z: A column that contains z-axis values. If not provided, the count of joint occurrences of x and y will be used. + labels: A dictionary of labels mapping columns to new labels. + color_continuous_scale: A list of colors for a continuous scale + range_color: A list of two numbers that form the endpoints of the color axis + color_continuous_midpoint: A number that is the midpoint of the color axis + opacity: Opacity to apply to all markers. 0 is completely transparent + and 1 is completely opaque. 
+ log_x: A boolean or list of booleans that specify if + the corresponding axis is a log axis or not. The booleans loop, so if there + are more series than booleans, booleans will be reused. + log_y: A boolean or list of booleans that specify if + the corresponding axis is a log axis or not. The booleans loop, so if there + are more series than booleans, booleans will be reused. + range_x: A list of two numbers that specify the range of the x axes. + None can be specified for no range + range_y: A list of two numbers that specify the range of the y axes. + None can be specified for no range + range_bins_x: A list of two numbers that specify the range of data that is used for x. + range_bins_y: A list of two numbers that specify the range of data that is used for y. + histfunc: The function to use when aggregating within bins. One of + 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', + 'sum', or 'var' + nbinsx: The number of bins to use for the x-axis + nbinsy: The number of bins to use for the y-axis + title: The title of the chart + template: The template for the chart. + unsafe_update_figure: An update function that takes a plotly figure + as an argument and optionally returns a plotly figure. If a figure is + not returned, the plotly figure passed will be assumed to be the return + value. Used to add any custom changes to the underlying plotly figure. + Note that the existing data traces should not be removed. This may lead + to unexpected behavior if traces are modified in a way that break data + mappings. 
+ + + + Returns: + DeephavenFigure: A DeephavenFigure that contains the density heatmap + """ + args = locals() + + return process_args(args, {"preprocess_heatmap"}, px_func=draw_density_heatmap) diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py new file mode 100644 index 000000000..78c66c932 --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Any, Generator + +from deephaven import new_table +from deephaven.column import long_col + +from ..shared import get_unique_names +from .utilities import create_range_table, validate_heatmap_histfunc, create_tmp_view, \ + aggregate_heatmap_bins, calculate_bin_locations +from deephaven.table import Table + + +class HeatmapPreprocessor: + """ + Preprocessor for heatmaps. + + Attributes: + args: dict[str, Any]: The arguments used to create the plot + range_table: The range table, calculated over the whole original table + + """ + + def __init__(self, args: dict[str, Any]): + self.args = args + self.histfunc = args.pop("histfunc") + self.nbinsx = args.pop("nbinsx") + self.nbinsy = args.pop("nbinsy") + self.range_bins_x = args.pop("range_bins_x") + self.range_bins_y = args.pop("range_bins_y") + # create unique names for the columns to ensure no collisions + self.names = get_unique_names( + self.args["table"], + [ + "range_index_x", + "range_index_y", + "range_x", + "range_y", + "bin_min_x", + "bin_max_x", + "bin_min_y", + "bin_max_y", + "tmp_x", + "tmp_y", + "agg_col", + self.histfunc, + ], + ) + + # add the column names to names as well for ease of use + self.names.update( + { + "x": self.args["x"], + "y": self.args["y"], + "z": self.args["z"], + } + ) + + def preprocess_partitioned_tables( + self, tables: list[Table], column: str | None = None + ) -> Generator[tuple[Table, 
dict[str, str | None]], None, None]: + """ + Preprocess params into an appropriate table + + Args: + tables: a list of tables to preprocess + column: the column to aggregate on + + Returns: + A tuple containing (the new table, an update to make to the args) + + """ + + range_index_x = self.names["range_index_x"] + range_index_y = self.names["range_index_y"] + range_x = self.names["range_x"] + range_y = self.names["range_y"] + histfunc_col = self.names[self.histfunc] + x = self.names["x"] + y = self.names["y"] + z = self.names["z"] + + validate_heatmap_histfunc(z, self.histfunc) + + # there will only be one table, so we can just grab the first one + table = tables[0] + + range_table_x = create_range_table( + table, x, self.range_bins_x, self.nbinsx, range_name=range_x + ) + range_table_y = create_range_table( + table, y, self.range_bins_y, self.nbinsy, range_name=range_y + ) + range_table = range_table_x.join(range_table_y) + + # ensure that all possible bins are created so that the rendered chart draws spaces for empty bins + bin_counts_x = new_table( + [long_col(range_index_x, [i for i in range(self.nbinsx)])] + ) + bin_counts_y = new_table( + [long_col(range_index_y, [i for i in range(self.nbinsy)])] + ) + bin_counts = bin_counts_x.join(bin_counts_y) + + tmp_view = create_tmp_view(self.names) + + # filter to only the tmp (data) columns, and join the range table to the tmp + ranged_tmp_view = table.view(tmp_view).join(range_table) + + agg_table = aggregate_heatmap_bins( + ranged_tmp_view, self.names, self.histfunc + ) + + # join the aggregated values to the already created comprehensive bin table + bin_counts = bin_counts.natural_join( + agg_table, on=[range_index_x, range_index_y], joins=[self.names["agg_col"]] + ) + + # join the range table to the bin counts - this is needed because the ranges were dropped in the aggregation + ranged_bin_counts = bin_counts.join(range_table) + + bin_counts_with_midpoint = calculate_bin_locations( + ranged_bin_counts, self.names, 
histfunc_col + ) + + heatmap_title = f"{self.histfunc} of {z}" if z else self.histfunc + + yield bin_counts_with_midpoint.view([x, y, histfunc_col]), { + "z": histfunc_col, "heatmap_title": heatmap_title + } diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py index a285185c8..46971b33b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py @@ -9,19 +9,7 @@ from ..shared import get_unique_names from deephaven.column import long_col from deephaven.updateby import cum_sum - -# Used to aggregate within histogram bins -HISTFUNC_MAP = { - "avg": agg.avg, - "count": agg.count_, - "count_distinct": agg.count_distinct, - "max": agg.max_, - "median": agg.median, - "min": agg.min_, - "std": agg.std, - "sum": agg.sum_, - "var": agg.var, -} +from .utilities import create_range_table, HISTFUNC_AGGS def get_aggs( @@ -83,42 +71,14 @@ def prepare_preprocess(self) -> None: self.args["table"], ["range_index", "range", "bin_min", "bin_max", self.histfunc, "total"], ) - self.range_table = self.create_range_table() - - def create_range_table(self) -> Table: - """ - Create a table that contains the bin ranges - - Returns: - A table containing the bin ranges - """ - # partitioned tables need range calculated on all - table = ( - self.table.merge() - if isinstance(self.table, PartitionedTable) - else self.table + self.range_table = create_range_table( + self.args["table"], + self.range_bins, + [self.var], + self.nbins, + self.names["range"], ) - if self.range_bins: - range_min = self.range_bins[0] - range_max = self.range_bins[1] - table = empty_table(1) - else: - range_min = "RangeMin" - range_max = "RangeMax" - # need to find range across all columns - min_aggs, min_cols = get_aggs("RangeMin", self.cols) - max_aggs, max_cols = 
get_aggs("RangeMax", self.cols) - table = table.agg_by([agg.min_(min_aggs), agg.max_(max_aggs)]).update( - [f"RangeMin = min({min_cols})", f"RangeMax = max({max_cols})"] - ) - - return table.update( - f"{self.names['range']} = new io.deephaven.plot.datasets.histogram." - f"DiscretizedRangeEqual({range_min},{range_max}, " - f"{self.nbins})" - ).view(self.names["range"]) - def create_count_tables( self, tables: list[Table], column: str | None = None ) -> Generator[tuple[Table, str], None, None]: @@ -134,7 +94,7 @@ def create_count_tables( """ range_index, range_ = self.names["range_index"], self.names["range"] - agg_func = HISTFUNC_MAP[self.histfunc] + agg_func = HISTFUNC_AGGS[self.histfunc] if not self.range_table: raise ValueError("Range table not created") for i, table in enumerate(tables): diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py index 808c5a75f..1cba930a0 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/Preprocessor.py @@ -9,6 +9,7 @@ from .FreqPreprocessor import FreqPreprocessor from .HistPreprocessor import HistPreprocessor from .TimePreprocessor import TimePreprocessor +from .HeatmapPreprocessor import HeatmapPreprocessor class Preprocessor: @@ -54,6 +55,8 @@ def prepare_preprocess(self) -> None: AttachedPreprocessor(self.args, self.always_attached) elif "preprocess_time" in self.groups: self.preprocesser = TimePreprocessor(self.args) + elif "preprocess_heatmap" in self.groups: + self.preprocesser = HeatmapPreprocessor(self.args) def preprocess_partitioned_tables( self, tables: list[Table] | None, column: str | None = None diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py new file mode 100644 index 
000000000..5b3424772 --- /dev/null +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from typing import Generator + +from deephaven import agg, empty_table +from deephaven.plot.express.shared import get_unique_names +from deephaven.table import PartitionedTable, Table + +# Used to aggregate within bins +HISTFUNC_AGGS = { + "abs_sum": agg.abs_sum, + "avg": agg.avg, + "count": agg.count_, + "count_distinct": agg.count_distinct, + "max": agg.max_, + "median": agg.median, + "min": agg.min_, + "std": agg.std, + "sum": agg.sum_, + "var": agg.var, +} + + +def get_aggs( + base: str, + columns: list[str], +) -> tuple[list[str], str]: + """Create aggregations over all columns + + Args: + base: + The base of the new columns that store the agg per column + columns: + All columns joined for the sake of taking min or max over + the columns + + Returns: + A tuple containing (a list of the new columns, + a joined string of "NewCol, NewCol2...") + + """ + return ( + [f"{base}{column}={column}" for column in columns], + ", ".join([f"{base}{column}" for column in columns]), + ) + + +def single_table(table: Table | PartitionedTable) -> Table: + """ + Merge a table if it is partitioned table + + Args: + table: The table to merge + + Returns: + The table if it is not a partitioned table, otherwise the merged table + """ + return table.merge() if isinstance(table, PartitionedTable) else table + + +def discretized_range_view( + table: Table, + range_min: float | str, + range_max: float | str, + nbins: int, + range_name: str, +) -> Table: + """ + Create a discretized range view that can be joined with a table to compute indices + + Args: + table: The table to create the range view from + range_min: The minimum value of the range. Can be a number or a column name + range_max: The maximum value of the range. 
Can be a number or a column name + nbins: The number of bins to create + range_name: The name of the range object in the resulting table + + Returns: + A table that contains the range object for the given table + """ + + return table.update( + f"{range_name} = new io.deephaven.plot.datasets.histogram." + f"DiscretizedRangeEqual({range_min},{range_max}, " + f"{nbins})" + ).view(range_name) + + +def create_range_table( + table: Table, + cols: str | list[str], + range_bins: list[float | None] | None, + nbins: int, + range_name: str, +) -> Table: + """ + Create single row tables with range objects that can compute bin membership + + Args: + table: The table to create the range table from + cols: The columns to create the range table from. The resulting range table will have + its range calculated over all of these columns. + range_bins: The range to create the bins over. + If None, the range will be calculated over the columns. + If a list of two numbers, the range will be set to these numbers. + The values within this list can also be None, in which case the range will be calculated over the columns + for whichever value is None. 
+ nbins: The number of bins to create + range_name: The name of the range object in the resulting table + + Returns: + A table that contains the range object for the given + """ + + cols = [cols] if isinstance(cols, str) else cols + + range_min = ( + range_bins[0] if range_bins and range_bins[0] is not None else "RangeMin" + ) + range_max = ( + range_bins[1] if range_bins and range_bins[1] is not None else "RangeMax" + ) + + min_table = empty_table(1) + max_table = empty_table(1) + + if range_min == "RangeMin": + min_aggs, min_cols = get_aggs("RangeMin", cols) + min_table = table.agg_by([agg.min_(min_aggs)]).update( + [f"RangeMin = min({min_cols})"] + ) + if range_max == "RangeMax": + max_aggs, max_cols = get_aggs("RangeMax", cols) + max_table = table.agg_by([agg.max_(max_aggs)]).update( + [f"RangeMax = max({max_cols})"] + ) + + return discretized_range_view( + min_table.join(max_table), range_min, range_max, nbins, range_name + ) + + +def validate_heatmap_histfunc( + z: str | None, + histfunc: str +) -> None: + """ + Check if the histfunc is valid + + Args: + z: The column that contains z-axis values. + histfunc: The function to use when aggregating within bins. Should be 'count' if z is None. 
+ + Raises: + ValueError: If the histfunc is not valid + """ + if z is None and histfunc != "count": + raise ValueError("z must be specified for histfunc other than count") + elif histfunc not in HISTFUNC_AGGS: + raise ValueError(f"{histfunc} is not a valid histfunc") + + +def create_tmp_view( + names: dict[str, str], +) -> list[str]: + """ + Create a temporary view that avoids column name collisions + + Args: + names: The names used for columns so that they don't collide + + Returns: + A list of strings that are used to create a temporary view + """ + + x = names["x"] + y = names["y"] + z = names["z"] + tmp_x = names["tmp_x"] + tmp_y = names["tmp_y"] + agg_col = names["agg_col"] + + tmp_view = [f"{tmp_x} = {x}", f"{tmp_y} = {y}"] + + if z is not None: + tmp_view.append(f"{agg_col} = {z}") + else: + # if z is not specified, just count the number of occurrences, so tmp_x or tmp_y can be used + names["agg_col"] = tmp_x + + return tmp_view + + +def aggregate_heatmap_bins( + table: Table, + names: dict[str, str], + histfunc: str, +) -> Table: + """ + Create count tables that aggregate up values into bins + + Args: + table: The table to aggregate. Should contain the tmp data columns and the range columns + names: The names used for columns so that they don't collide + histfunc: The function to use when aggregating within bins. Should be 'count' if z is None. 
+ + Yields: + A tuple containing the table and a temporary column that contains the aggregated values + """ + + range_x = names["range_x"] + range_y = names["range_y"] + range_index_x = names["range_index_x"] + range_index_y = names["range_index_y"] + + tmp_x = names["tmp_x"] + tmp_y = names["tmp_y"] + agg_col = names["agg_col"] + + count_table = ( + table + .update_view( + [ + f"{range_index_x} = {range_x}.index({tmp_x})", + f"{range_index_y} = {range_y}.index({tmp_y})", + ] + ) + .where([f"!isNull({range_index_x})", f"!isNull({range_index_y})"]) + .agg_by([HISTFUNC_AGGS[histfunc](agg_col)], [range_index_x, range_index_y]) + ) + return count_table + + +def calculate_bin_locations( + ranged_bin_counts: Table, + names: dict[str, str], + histfunc_col: str, +) -> Table: + """ + Compute the center of the bins for the x and y axes + plotly requires the center of the bins to plot the heatmap and will calculate the width automatically + + Args + bin_counts_ranged: A table that contains the bin counts and the range columns + names: The names used for columns so that they don't collide + histfunc_col: The column that contains the aggregated values + + Returns: + A table that contains the bin counts and the center of the bins + """ + range_index_x = names["range_index_x"] + range_index_y = names["range_index_y"] + range_x = names["range_x"] + range_y = names["range_y"] + bin_min_x = names["bin_min_x"] + bin_max_x = names["bin_max_x"] + bin_min_y = names["bin_min_y"] + bin_max_y = names["bin_max_y"] + x = names["x"] + y = names["y"] + agg_col = names["agg_col"] + + return ( + ranged_bin_counts + .update_view( + [ + f"{bin_min_x} = {range_x}.binMin({range_index_x})", + f"{bin_max_x} = {range_x}.binMax({range_index_x})", + f"{x}=0.5*({bin_min_x}+{bin_max_x})", + f"{bin_min_y} = {range_y}.binMin({range_index_y})", + f"{bin_max_y} = {range_y}.binMax({range_index_y})", + f"{y}=0.5*({bin_min_y}+{bin_max_y})", + f"{histfunc_col} = {agg_col}", + ] + ) + ) From 
aa65d9a9c298fcd2705486f206bf1bd2ec1bf55a Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 09:22:47 -0500 Subject: [PATCH 02/11] wip --- .../express/deephaven_figure/custom_draw.py | 42 +++++------ .../plot/express/deephaven_figure/generate.py | 4 +- .../express/preprocess/HeatmapPreprocessor.py | 16 +++-- .../plot/express/preprocess/utilities.py | 71 +++++++++---------- 4 files changed, 66 insertions(+), 67 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index e4e562f6e..9c19cc06f 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -10,13 +10,13 @@ def draw_finance( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | list[str], - go_func: Callable, + data_frame: DataFrame, + x_finance: str | list[str], + open: str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], + go_func: Callable, ) -> Figure: """Draws a finance (OHLC or candlestick) chart @@ -34,7 +34,7 @@ def draw_finance( """ if not all(len(open) == len(ls) for ls in [high, low, close]) and ( - len(open) == len(x_finance) or len(x_finance) == 1 + len(open) == len(x_finance) or len(x_finance) == 1 ): raise ValueError( "open, high, low, close must have same length and x " @@ -44,7 +44,7 @@ def draw_finance( data = [] for x_f, o, h, l, c in zip_longest( - x_finance, open, high, low, close, fillvalue=x_finance[0] + x_finance, open, high, low, close, fillvalue=x_finance[0] ): data.append( go_func( @@ -60,12 +60,12 @@ def draw_finance( def draw_ohlc( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | 
list[str], + data_frame: DataFrame, + x_finance: str | list[str], + open: str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], ) -> Figure: """Create a plotly OHLC chart. @@ -86,12 +86,12 @@ def draw_ohlc( def draw_candlestick( - data_frame: DataFrame, - x_finance: str | list[str], - open: str | list[str], - high: str | list[str], - low: str | list[str], - close: str | list[str], + data_frame: DataFrame, + x_finance: str | list[str], + open: str | list[str], + high: str | list[str], + low: str | list[str], + close: str | list[str], ) -> Figure: """Create a plotly candlestick chart. diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 97eaa3b6e..6abef3384 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -972,7 +972,9 @@ def create_hover_and_axis_titles( current_partition = custom_call_args.get("current_partition", {}) - compute_labels(hover_mapping, hist_val_name, heatmap_title, types, labels, current_partition) + compute_labels( + hover_mapping, hist_val_name, heatmap_title, types, labels, current_partition + ) hover_text = hover_text_generator(hover_mapping, types, current_partition) diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py index 78c66c932..e99029c12 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -6,8 +6,13 @@ from deephaven.column import long_col from ..shared import get_unique_names -from .utilities import create_range_table, validate_heatmap_histfunc, create_tmp_view, \ - 
aggregate_heatmap_bins, calculate_bin_locations +from .utilities import ( + create_range_table, + validate_heatmap_histfunc, + create_tmp_view, + aggregate_heatmap_bins, + calculate_bin_locations, +) from deephaven.table import Table @@ -107,9 +112,7 @@ def preprocess_partitioned_tables( # filter to only the tmp (data) columns, and join the range table to the tmp ranged_tmp_view = table.view(tmp_view).join(range_table) - agg_table = aggregate_heatmap_bins( - ranged_tmp_view, self.names, self.histfunc - ) + agg_table = aggregate_heatmap_bins(ranged_tmp_view, self.names, self.histfunc) # join the aggregated values to the already created comprehensive bin table bin_counts = bin_counts.natural_join( @@ -126,5 +129,6 @@ def preprocess_partitioned_tables( heatmap_title = f"{self.histfunc} of {z}" if z else self.histfunc yield bin_counts_with_midpoint.view([x, y, histfunc_col]), { - "z": histfunc_col, "heatmap_title": heatmap_title + "z": histfunc_col, + "heatmap_title": heatmap_title, } diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py index 5b3424772..6d20eb8f1 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py @@ -22,8 +22,8 @@ def get_aggs( - base: str, - columns: list[str], + base: str, + columns: list[str], ) -> tuple[list[str], str]: """Create aggregations over all columns @@ -59,11 +59,11 @@ def single_table(table: Table | PartitionedTable) -> Table: def discretized_range_view( - table: Table, - range_min: float | str, - range_max: float | str, - nbins: int, - range_name: str, + table: Table, + range_min: float | str, + range_max: float | str, + nbins: int, + range_name: str, ) -> Table: """ Create a discretized range view that can be joined with a table to compute indices @@ -87,11 +87,11 @@ def discretized_range_view( def 
create_range_table( - table: Table, - cols: str | list[str], - range_bins: list[float | None] | None, - nbins: int, - range_name: str, + table: Table, + cols: str | list[str], + range_bins: list[float | None] | None, + nbins: int, + range_name: str, ) -> Table: """ Create single row tables with range objects that can compute bin membership @@ -140,10 +140,7 @@ def create_range_table( ) -def validate_heatmap_histfunc( - z: str | None, - histfunc: str -) -> None: +def validate_heatmap_histfunc(z: str | None, histfunc: str) -> None: """ Check if the histfunc is valid @@ -161,11 +158,11 @@ def validate_heatmap_histfunc( def create_tmp_view( - names: dict[str, str], + names: dict[str, str], ) -> list[str]: """ Create a temporary view that avoids column name collisions - + Args: names: The names used for columns so that they don't collide @@ -192,9 +189,9 @@ def create_tmp_view( def aggregate_heatmap_bins( - table: Table, - names: dict[str, str], - histfunc: str, + table: Table, + names: dict[str, str], + histfunc: str, ) -> Table: """ Create count tables that aggregate up values into bins @@ -218,8 +215,7 @@ def aggregate_heatmap_bins( agg_col = names["agg_col"] count_table = ( - table - .update_view( + table.update_view( [ f"{range_index_x} = {range_x}.index({tmp_x})", f"{range_index_y} = {range_y}.index({tmp_y})", @@ -232,9 +228,9 @@ def aggregate_heatmap_bins( def calculate_bin_locations( - ranged_bin_counts: Table, - names: dict[str, str], - histfunc_col: str, + ranged_bin_counts: Table, + names: dict[str, str], + histfunc_col: str, ) -> Table: """ Compute the center of the bins for the x and y axes @@ -260,17 +256,14 @@ def calculate_bin_locations( y = names["y"] agg_col = names["agg_col"] - return ( - ranged_bin_counts - .update_view( - [ - f"{bin_min_x} = {range_x}.binMin({range_index_x})", - f"{bin_max_x} = {range_x}.binMax({range_index_x})", - f"{x}=0.5*({bin_min_x}+{bin_max_x})", - f"{bin_min_y} = {range_y}.binMin({range_index_y})", - f"{bin_max_y} = 
{range_y}.binMax({range_index_y})", - f"{y}=0.5*({bin_min_y}+{bin_max_y})", - f"{histfunc_col} = {agg_col}", - ] - ) + return ranged_bin_counts.update_view( + [ + f"{bin_min_x} = {range_x}.binMin({range_index_x})", + f"{bin_max_x} = {range_x}.binMax({range_index_x})", + f"{x}=0.5*({bin_min_x}+{bin_max_x})", + f"{bin_min_y} = {range_y}.binMin({range_index_y})", + f"{bin_max_y} = {range_y}.binMax({range_index_y})", + f"{y}=0.5*({bin_min_y}+{bin_max_y})", + f"{histfunc_col} = {agg_col}", + ] ) From 6fd881b34e7cefd55b787c41e59d709eec851323 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 13:47:57 -0500 Subject: [PATCH 03/11] wip --- .../communication/DeephavenFigureListener.py | 2 ++ .../plot/express/deephaven_figure/custom_draw.py | 16 ++++++++-------- .../src/deephaven/plot/express/plots/heatmap.py | 14 +++++++------- .../express/preprocess/HeatmapPreprocessor.py | 4 ++-- .../plot/express/preprocess/HistPreprocessor.py | 2 +- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py index 288d8dfa1..cab5def10 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py +++ b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py @@ -67,8 +67,10 @@ def _setup_listeners(self) -> None: Setup listeners for the partitioned tables """ for table, node in self._partitioned_tables.values(): + print(table, node) listen_func = partial(self._on_update, node) handle = listen(table, listen_func) + print(handle) self._handles.append(handle) self._liveness_scope.manage(handle) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index 9c19cc06f..431eaaea3 100644 --- 
a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -118,11 +118,11 @@ def draw_density_heatmap( y: str, z: str, range_color: list[float] | None = None, - color_continuous_scale: str = "Viridis", - color_continuous_midpoint=None, - opacity=1.0, - title=None, - template=None, + color_continuous_scale: str | None = "Viridis", + color_continuous_midpoint: list[float] | None = None, + opacity: float = 1.0, + title: str | None = None, + template: str | None = None, ) -> Figure: """Create a density heatmap @@ -160,15 +160,15 @@ def draw_density_heatmap( ) ) - range_color = range_color or [None, None] + range_color_list = range_color or [None, None] colorscale_validator = ColorscaleValidator("colorscale", "make_figure") coloraxis_layout = dict( colorscale=colorscale_validator.validate_coerce(color_continuous_scale), cmid=color_continuous_midpoint, - cmin=range_color[0], - cmax=range_color[1], + cmin=range_color_list[0], + cmax=range_color_list[1], ) heatmap.update_layout( diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py index 9bee2ce14..613b514c4 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -14,22 +14,22 @@ def density_heatmap( x: str | None = None, y: str | None = None, z: str | None = None, - labels: dict[str, str] = None, + labels: dict[str, str] | None = None, color_continuous_scale: str = "Viridis", - range_color: list[float] = None, - color_continuous_midpoint: float = None, + range_color: list[float] | None = None, + color_continuous_midpoint: float | None = None, opacity: float = 1.0, log_x: bool = False, log_y: bool = False, range_x: list[float] | None = None, range_y: list[float] | None = None, - range_bins_x: list[float | None] = 
None, - range_bins_y: list[float | None] = None, + range_bins_x: list[float | None] | None = None, + range_bins_y: list[float | None] | None = None, histfunc: str = "count", nbinsx: int = 10, nbinsy: int = 10, - title: str = None, - template: str = None, + title: str | None = None, + template: str | None = None, unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """ diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py index e99029c12..71a4346d2 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -91,10 +91,10 @@ def preprocess_partitioned_tables( table = tables[0] range_table_x = create_range_table( - table, x, self.range_bins_x, self.nbinsx, range_name=range_x + table, x, self.range_bins_x, self.nbinsx, range_x ) range_table_y = create_range_table( - table, y, self.range_bins_y, self.nbinsy, range_name=range_y + table, y, self.range_bins_y, self.nbinsy, range_y ) range_table = range_table_x.join(range_table_y) diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py index 46971b33b..197b0b73d 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HistPreprocessor.py @@ -73,8 +73,8 @@ def prepare_preprocess(self) -> None: ) self.range_table = create_range_table( self.args["table"], + self.cols, self.range_bins, - [self.var], self.nbins, self.names["range"], ) From fd5191e658d957afc33e9f150066a200f91bb620 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 13:57:19 -0500 Subject: [PATCH 04/11] wip --- 
.../plot/express/deephaven_figure/generate.py | 37 +++++++++++-------- .../express/preprocess/HeatmapPreprocessor.py | 17 +++++++-- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 6abef3384..9dd0de015 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -116,7 +116,7 @@ "current_partition", "colors", "unsafe_update_figure", - "heatmap_title", + "heatmap_agg_label", } # these are columns that are "attached" sequentially to the traces @@ -671,7 +671,7 @@ def handle_custom_args( elif arg == "bargap" or arg == "rangemode": fig.update_layout({arg: val}) - elif arg == "heatmap_title": + elif arg == "heatmap_agg_label": fig.update_coloraxes(colorbar_title_text=val) trace_generator = combined_generator(trace_generators) @@ -826,7 +826,7 @@ def hover_text_generator( def compute_labels( hover_mapping: list[dict[str, str]], hist_val_name: str | None, - heatmap_title: str | None, + heatmap_agg_label: str | None, # hover_data - todo, dependent on arrays supported in data mappings types: set[str], labels: dict[str, str] | None, @@ -839,7 +839,7 @@ def compute_labels( Args: hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc - heatmap_title: The aggregate density heatmap column title + heatmap_agg_label: The aggregate density heatmap column title types: Any types of this chart that require special processing labels: A dictionary of old column name to new column name mappings current_partition: The columns that this figure is partitioned by @@ -850,25 +850,25 @@ def compute_labels( calculate_hist_labels(hist_val_name, hover_mapping[0]) - calculate_density_heatmap_labels(heatmap_title, hover_mapping[0]) + 
calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0]) relabel_columns(labels, hover_mapping, types, current_partition) def calculate_density_heatmap_labels( - heatmap_title: str | None, + heatmap_agg_label: str | None, hover_mapping: dict[str, str], ) -> None: """Calculate the labels for a density heatmap The z column is renamed to the colorbar title Args: - heatmap_title: The title of the colorbar + heatmap_agg_label: The title of the colorbar hover_mapping: The mapping of variables to columns """ - if heatmap_title: - hover_mapping["z"] = heatmap_title + if heatmap_agg_label: + hover_mapping["z"] = heatmap_agg_label def calculate_hist_labels( @@ -893,7 +893,7 @@ def add_axis_titles( custom_call_args: dict[str, Any], hover_mapping: list[dict[str, str]], hist_val_name: str | None, - heatmap_title: str | None, + heatmap_agg_label: str | None, ) -> None: """Add axis titles. Generally, this only applies when there is a list variable @@ -902,7 +902,7 @@ def add_axis_titles( create hover and axis titles hover_mapping: The mapping of variables to columns hist_val_name: The histogram name for the value axis, generally histfunc - heatmap_title: The aggregate density heatmap column title + heatmap_agg_label: The aggregate density heatmap column title """ # Although hovertext is handled above for all plot types, plotly still @@ -916,8 +916,8 @@ def add_axis_titles( new_xaxis_titles = [hover_mapping[0].get("x", None)] new_yaxis_titles = [hover_mapping[0].get("y", None)] - if heatmap_title: - custom_call_args["heatmap_title"] = heatmap_title + if heatmap_agg_label: + custom_call_args["heatmap_agg_label"] = heatmap_agg_label # a specified axis title update should override this if new_xaxis_titles: @@ -968,17 +968,22 @@ def create_hover_and_axis_titles( labels = custom_call_args.get("labels", None) hist_val_name = custom_call_args.get("hist_val_name", None) - heatmap_title = custom_call_args.get("heatmap_title", None) + heatmap_agg_label = 
custom_call_args.get("heatmap_agg_label", None) current_partition = custom_call_args.get("current_partition", {}) compute_labels( - hover_mapping, hist_val_name, heatmap_title, types, labels, current_partition + hover_mapping, + hist_val_name, + heatmap_agg_label, + types, + labels, + current_partition, ) hover_text = hover_text_generator(hover_mapping, types, current_partition) - add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_title) + add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label) return hover_text diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py index 71a4346d2..9a0196ac8 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -22,8 +22,13 @@ class HeatmapPreprocessor: Attributes: args: dict[str, Any]: The arguments used to create the plot - range_table: The range table, calculated over the whole original table - + histfunc: str: The histfunc to use + nbinsx: int: The number of bins in the x direction + nbinsy: int: The number of bins in the y direction + range_bins_x: list[float | None]: The range of the x bins + range_bins_y: list[float | None]: The range of the y bins + names: dict[str, str]: A mapping of ideal name to unique names + Also contains the names of the x, y, and z columns for ease of use """ def __init__(self, args: dict[str, Any]): @@ -70,9 +75,13 @@ def preprocess_partitioned_tables( Args: tables: a list of tables to preprocess column: the column to aggregate on + ignored for this preprocessor because heatmap always gets the joint count + distribution of x and y or the histfunc of z depending on if z is provided Returns: A tuple containing (the new table, an update to make to the args) + The update should contain the z 
column name and the heatmap_agg_label + which is the histfunc of z if z is not None, otherwise just the histfunc """ @@ -126,9 +135,9 @@ def preprocess_partitioned_tables( ranged_bin_counts, self.names, histfunc_col ) - heatmap_title = f"{self.histfunc} of {z}" if z else self.histfunc + heatmap_agg_label = f"{self.histfunc} of {z}" if z else self.histfunc yield bin_counts_with_midpoint.view([x, y, histfunc_col]), { "z": histfunc_col, - "heatmap_title": heatmap_title, + "heatmap_agg_label": heatmap_agg_label, } From a3e3ebf3c660019d238b294e9805eb87ab2c23b6 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 13:58:34 -0500 Subject: [PATCH 05/11] wip --- .../src/deephaven/plot/express/deephaven_figure/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 9dd0de015..a8a81229b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -860,7 +860,7 @@ def calculate_density_heatmap_labels( hover_mapping: dict[str, str], ) -> None: """Calculate the labels for a density heatmap - The z column is renamed to the colorbar title + The z column is renamed to the heatmap_agg_label Args: heatmap_agg_label: The title of the colorbar From d7cf130c698966e33b4173e8cb665ddd17fd642f Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 13:58:57 -0500 Subject: [PATCH 06/11] wip --- .../src/deephaven/plot/express/deephaven_figure/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index a8a81229b..764b35d28 100644 --- 
a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -863,7 +863,7 @@ def calculate_density_heatmap_labels( The z column is renamed to the heatmap_agg_label Args: - heatmap_agg_label: The title of the colorbar + heatmap_agg_label: The name of the heatmap aggregate label hover_mapping: The mapping of variables to columns """ From 9197227e0261649f3fa53e088bd17c5ddbe9f22b Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 3 Jul 2024 15:39:52 -0500 Subject: [PATCH 07/11] wip --- .../src/deephaven/plot/express/deephaven_figure/custom_draw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index 431eaaea3..ac903192b 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -162,7 +162,7 @@ def draw_density_heatmap( range_color_list = range_color or [None, None] - colorscale_validator = ColorscaleValidator("colorscale", "make_figure") + colorscale_validator = ColorscaleValidator("colorscale", "draw_density_heatmap") coloraxis_layout = dict( colorscale=colorscale_validator.validate_coerce(color_continuous_scale), From 92beb93cffe73087f20d37add0a73a87ec631931 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 10 Jul 2024 12:42:12 -0500 Subject: [PATCH 08/11] wip --- .../express/deephaven_figure/custom_draw.py | 6 + .../plot/express/deephaven_figure/generate.py | 20 +- .../deephaven/plot/express/plots/heatmap.py | 4 +- .../test/deephaven/plot/express/BaseTest.py | 17 + .../plot/express/plots/test_heatmap.py | 320 ++++++++++++++++++ .../preprocess/test_HeatmapPreprocessor.py | 317 +++++++++++++++++ 6 files changed, 680 insertions(+), 4 
deletions(-) create mode 100644 plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py create mode 100644 plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index ac903192b..e07e1bc4a 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -117,6 +117,7 @@ def draw_density_heatmap( x: str, y: str, z: str, + labels: dict[str, str] | None = None, range_color: list[float] | None = None, color_continuous_scale: str | None = "Viridis", color_continuous_midpoint: list[float] | None = None, @@ -131,6 +132,7 @@ def draw_density_heatmap( x: The name of the column containing x-axis values y: The name of the column containing y-axis values z: The name of the column containing bin values + labels: A dictionary of labels mapping columns to new labels color_continuous_scale: A list of colors for a continuous scale range_color: A list of two numbers that form the endpoints of the color axis color_continuous_midpoint: A number that is the midpoint of the color axis @@ -171,6 +173,10 @@ def draw_density_heatmap( cmax=range_color_list[1], ) + if labels: + x = labels.get(x, x) + y = labels.get(y, y) + heatmap.update_layout( coloraxis1=coloraxis_layout, title=title, diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 764b35d28..9d84aea6a 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -850,7 +850,7 @@ def compute_labels( calculate_hist_labels(hist_val_name, 
hover_mapping[0]) - calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0]) + calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels) relabel_columns(labels, hover_mapping, types, current_partition) @@ -858,6 +858,7 @@ def compute_labels( def calculate_density_heatmap_labels( heatmap_agg_label: str | None, hover_mapping: dict[str, str], + labels: dict[str, str] | None, ) -> None: """Calculate the labels for a density heatmap The z column is renamed to the heatmap_agg_label @@ -865,10 +866,18 @@ def calculate_density_heatmap_labels( Args: heatmap_agg_label: The name of the heatmap aggregate label hover_mapping: The mapping of variables to columns + labels: A dictionary of labels mapping columns to new labels. """ + labels = labels or {} if heatmap_agg_label: - hover_mapping["z"] = heatmap_agg_label + # the last part of the label is the z column, and could be replaced by labels + split_label = heatmap_agg_label.split(" ") + split_label[-1] = labels.get(split_label[-1], split_label[-1]) + # it's also possible that someone wants to override the whole label + # plotly doesn't seem to do that, but it seems reasonable to allow + new_label = " ".join(split_label) + hover_mapping["z"] = labels.get(new_label, new_label) def calculate_hist_labels( @@ -955,6 +964,9 @@ def create_hover_and_axis_titles( (such as the y-axis if the x-axis is specified). Otherwise, there is a legend or not depending on if there is a list of columns or not. + Density heatmaps are also an exception. If "heatmap_agg_label" is specified, + the z column is renamed to this label. 
+ Args: custom_call_args: The custom_call_args that are used to create hover and axis titles @@ -983,6 +995,10 @@ def create_hover_and_axis_titles( hover_text = hover_text_generator(hover_mapping, types, current_partition) + if heatmap_agg_label: + # it's possible that heatmap_agg_label was relabeled, so grab the new label + heatmap_agg_label = hover_mapping[0]["z"] + add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label) return hover_text diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py index 613b514c4..0d85928d3 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -15,7 +15,7 @@ def density_heatmap( y: str | None = None, z: str | None = None, labels: dict[str, str] | None = None, - color_continuous_scale: str = "Viridis", + color_continuous_scale: str | None = None, range_color: list[float] | None = None, color_continuous_midpoint: float | None = None, opacity: float = 1.0, @@ -33,7 +33,7 @@ def density_heatmap( unsafe_update_figure: Callable = default_callback, ) -> DeephavenFigure: """ - Create a density heatmap + A density heatmap creates a grid of colored bins. Each bin represents an aggregation of data points in that region. Args: table: A table to pull data from. 
diff --git a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py index d043c39c0..6b1a88315 100644 --- a/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py +++ b/plugins/plotly-express/test/deephaven/plot/express/BaseTest.py @@ -1,9 +1,26 @@ import unittest from unittest.mock import patch +import pandas as pd from deephaven_server import Server +def remap_types( + df: pd.DataFrame, +) -> None: + """ + Remap the types of the columns to the correct types + + Args: + df: The dataframe to remap the types of + """ + for col in df.columns: + if df[col].dtype == "int64": + df[col] = df[col].astype("Int64") + elif df[col].dtype == "float64": + df[col] = df[col].astype("Float64") + + class BaseTestCase(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py new file mode 100644 index 000000000..3aefe5e31 --- /dev/null +++ b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py @@ -0,0 +1,320 @@ +import unittest + +from ..BaseTest import BaseTestCase + + +class HeatmapTestCase(BaseTestCase): + def setUp(self) -> None: + from deephaven import new_table + from deephaven.column import int_col + + self.source = new_table( + [ + int_col("X", [1, 2, 3, 4, 5]), + int_col("Y", [1, 2, 3, 4, 5]), + int_col("Z", [1, 2, 3, 4, 5]), + ] + ) + + def test_basic_heatmap(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE + + chart = dx.density_heatmap(self.source, x="X", y="Y").to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "X=%{x}
<br>Y=%{y}<br>
count=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count"}}, + "colorscale": [ + [0.0, "#440154"], + [0.1111111111111111, "#482878"], + [0.2222222222222222, "#3e4989"], + [0.3333333333333333, "#31688e"], + [0.4444444444444444, "#26828e"], + [0.5555555555555556, "#1f9e89"], + [0.6666666666666666, "#35b779"], + [0.7777777777777778, "#6ece58"], + [0.8888888888888888, "#b5de2b"], + [1.0, "#fde725"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + "table": 0, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_heatmap_relabel_z(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={"X": "Column X", "Y": "Column Y", "Z": "Column Z"}, + ).to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Column X=%{x}
<br>Column Y=%{y}<br>
count of Column Z=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count of Column Z"}}, + "colorscale": [ + [0.0, "#440154"], + [0.1111111111111111, "#482878"], + [0.2222222222222222, "#3e4989"], + [0.3333333333333333, "#31688e"], + [0.4444444444444444, "#26828e"], + [0.5555555555555556, "#1f9e89"], + [0.6666666666666666, "#35b779"], + [0.7777777777777778, "#6ece58"], + [0.8888888888888888, "#b5de2b"], + [1.0, "#fde725"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "table": 0, + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_heatmap_relabel_agg_z(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={"X": "Column X", "Y": "Column Y", "count of Z": "count"}, + ).to_dict(self.exporter) + + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Column X=%{x}
Column Y=%{y}
count=%{z}", + "opacity": 1.0, + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + "type": "heatmap", + } + ] + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "colorbar": {"title": {"text": "count"}}, + "colorscale": [ + [0.0, "#440154"], + [0.1111111111111111, "#482878"], + [0.2222222222222222, "#3e4989"], + [0.3333333333333333, "#31688e"], + [0.4444444444444444, "#26828e"], + [0.5555555555555556, "#1f9e89"], + [0.6666666666666666, "#35b779"], + [0.7777777777777778, "#6ece58"], + [0.8888888888888888, "#b5de2b"], + [1.0, "#fde725"], + ], + }, + "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, + "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "table": 0, + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "count": ["/plotly/data/0/z"], + }, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + def test_full_heatmap(self): + import src.deephaven.plot.express as dx + from deephaven.constants import NULL_LONG, NULL_DOUBLE, NULL_INT + + chart = dx.density_heatmap( + self.source, + x="X", + y="Y", + z="Z", + labels={ + "X": "Column X", + "Y": "Column Y", + "Z": "Column Z", + "sum of Column Z": "sum", + }, + color_continuous_scale="Magma", + range_color=[0, 10], + color_continuous_midpoint=5, + opacity=0.5, + log_x=True, + log_y=False, + range_x=[0, 1], + range_y=[0, 10], + range_bins_x=[5, 10], + range_bins_y=[5, 10], + histfunc="sum", + nbinsx=2, + nbinsy=2, + title="Test Title", + ).to_dict(self.exporter) + plotly, deephaven = chart["plotly"], chart["deephaven"] + + # pop template as we currently do not modify it + plotly["layout"].pop("template") + + expected_data = [ + { + "coloraxis": "coloraxis", + 
"hovertemplate": "Column X=%{x}
Column Y=%{y}
sum=%{z}", + "opacity": 0.5, + "type": "heatmap", + "x": [NULL_DOUBLE], + "y": [NULL_DOUBLE], + "z": [NULL_LONG], + } + ] + + self.assertEqual(plotly["data"], expected_data) + + expected_layout = { + "coloraxis": { + "cmax": 10, + "cmid": 5, + "cmin": 0, + "colorbar": {"title": {"text": "sum"}}, + "colorscale": [ + [0.0, "#000004"], + [0.1111111111111111, "#180f3d"], + [0.2222222222222222, "#440f76"], + [0.3333333333333333, "#721f81"], + [0.4444444444444444, "#9e2f7f"], + [0.5555555555555556, "#cd4071"], + [0.6666666666666666, "#f1605d"], + [0.7777777777777778, "#fd9668"], + [0.8888888888888888, "#feca8d"], + [1.0, "#fcfdbf"], + ], + }, + "title": {"text": "Test Title"}, + "xaxis": { + "anchor": "y", + "range": [0, 1], + "side": "bottom", + "title": {"text": "Column X"}, + "type": "log", + }, + "yaxis": { + "anchor": "x", + "range": [0, 10], + "side": "left", + "title": {"text": "Column Y"}, + }, + } + + self.assertEqual(plotly["layout"], expected_layout) + + expected_mappings = [ + { + "data_columns": { + "X": ["/plotly/data/0/x"], + "Y": ["/plotly/data/0/y"], + "sum": ["/plotly/data/0/z"], + }, + "table": 0, + } + ] + + self.assertEqual(deephaven["mappings"], expected_mappings) + + self.assertEqual(deephaven["is_user_set_template"], False) + self.assertEqual(deephaven["is_user_set_color"], False) + + +if __name__ == "__main__": + unittest.main() diff --git a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py new file mode 100644 index 000000000..765ba39f7 --- /dev/null +++ b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py @@ -0,0 +1,317 @@ +import unittest + +import pandas as pd + +from ..BaseTest import BaseTestCase, remap_types + + +class HeatmapPreprocessorTestCase(BaseTestCase): + def setUp(self) -> None: + from deephaven import new_table, merge + from deephaven.column import int_col + + 
self.source = new_table( + [ + int_col("X", [0, 4, 0, 4]), + int_col("Y", [0, 0, 4, 4]), + int_col("Z", [1, 1, 2, 2]), + ] + ) + + # Variance and standard deviation require at least two rows to be non-null, so stack them + # for a resulting var and std of 0 in all grid cells + self.var_source = merge([self.source, self.source]) + + def tables_equal(self, args, expected_df, t=None): + if t is None: + t = self.source + + args_copy = args.copy() + + from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import ( + HeatmapPreprocessor, + ) + import deephaven.pandas as dhpd + + heatmap_preprocessor = HeatmapPreprocessor(args_copy) + + new_table_gen = heatmap_preprocessor.preprocess_partitioned_tables([t]) + new_table, _ = next(new_table_gen) + + new_df = dhpd.to_pandas(new_table) + + self.assertTrue(expected_df.equals(new_df)) + + def test_basic_preprocessor(self): + args = { + "x": "X", + "y": "Y", + "z": None, + "histfunc": "count", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "table": self.source, + } + + expected_df = pd.DataFrame( + { + "X": [1.0, 1.0, 3.0, 3.0], + "Y": [1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 1, 1], + } + ) + remap_types(expected_df) + + self.tables_equal(args, expected_df) + + def test_basic_preprocessor_z(self): + args = { + "x": "X", + "y": "Y", + "z": "Z", + "histfunc": "sum", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "table": self.source, + } + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "sum": [1, 2, 1, 2]} + ) + remap_types(expected_df) + + self.tables_equal(args, expected_df) + + def test_partial_range_preprocessor(self): + args = { + "x": "X", + "y": "Y", + "z": None, + "histfunc": "count", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": [None, 6], + "range_bins_y": None, + "table": self.source, + } + + expected_df = pd.DataFrame( + { + "X": [1.5, 1.5, 4.5, 4.5], + "Y": [1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 1, 1], + 
} + ) + remap_types(expected_df) + + self.tables_equal(args, expected_df) + + def test_full_preprocessor(self): + args = { + "x": "X", + "y": "Y", + "z": None, + "histfunc": "count", + "nbinsx": 2, + "nbinsy": 3, + "range_bins_x": [1, 5], + "range_bins_y": [-1, 5], + "table": self.source, + } + + expected_df = pd.DataFrame( + { + "X": [2.0, 2.0, 2.0, 4.0, 4.0, 4.0], + "Y": [0.0, 2.0, 4.0, 0.0, 2.0, 4.0], + "count": [pd.NA, pd.NA, pd.NA, 1, pd.NA, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + def test_full_preprocessor_z(self): + args = { + "x": "X", + "y": "Y", + "z": "Z", + "histfunc": "avg", + "nbinsx": 3, + "nbinsy": 2, + "range_bins_x": [-1, 5], + "range_bins_y": [1, 5], + "table": self.source, + } + + expected_df = pd.DataFrame( + { + "X": [0.0, 0.0, 2.0, 2.0, 4.0, 4.0], + "Y": [2.0, 4.0, 2.0, 4.0, 2.0, 4.0], + "avg": [pd.NA, 2.0, pd.NA, pd.NA, pd.NA, 2.0], + } + ) + remap_types(expected_df) + expected_df["avg"] = expected_df["avg"].astype("Float64") + + self.tables_equal(args, expected_df) + + def test_preprocessor_aggs(self): + args = { + "x": "X", + "y": "Y", + "z": "Z", + "histfunc": "abs_sum", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "table": self.source, + } + + expected_df = pd.DataFrame( + { + "X": [1.0, 1.0, 3.0, 3.0], + "Y": [1.0, 3.0, 1.0, 3.0], + "abs_sum": [1, 2, 1, 2], + } + ) + remap_types(expected_df) + expected_df["abs_sum"] = expected_df["abs_sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "avg" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "avg": [1, 2, 1, 2]} + ) + remap_types(expected_df) + expected_df["avg"] = expected_df["avg"].astype("Float64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "count" + + expected_df = pd.DataFrame( + { + "X": [1.0, 1.0, 3.0, 3.0], + "Y": [1.0, 3.0, 1.0, 3.0], + "count": 
[1, 1, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "count_distinct" + + expected_df = pd.DataFrame( + { + "X": [1.0, 1.0, 3.0, 3.0], + "Y": [1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "max" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "max": [1, 2, 1, 2]} + ) + remap_types(expected_df) + expected_df["max"] = expected_df["max"].astype("Int32") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "median" + + expected_df = pd.DataFrame( + { + "X": [1.0, 1.0, 3.0, 3.0], + "Y": [1.0, 3.0, 1.0, 3.0], + "median": [1, 2, 1, 2], + } + ) + remap_types(expected_df) + expected_df["median"] = expected_df["median"].astype("Float64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "min" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "min": [1, 2, 1, 2]} + ) + remap_types(expected_df) + expected_df["min"] = expected_df["min"].astype("Int32") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "std" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "std": [0, 0, 0, 0]} + ) + remap_types(expected_df) + expected_df["std"] = expected_df["std"].astype("Float64") + + self.tables_equal(args, expected_df, self.var_source) + + args["histfunc"] = "sum" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "sum": [1, 2, 1, 2]} + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "var" + + expected_df = pd.DataFrame( + {"X": [1.0, 1.0, 3.0, 3.0], "Y": [1.0, 3.0, 1.0, 3.0], "var": [0, 0, 0, 
0]} + ) + remap_types(expected_df) + expected_df["var"] = expected_df["var"].astype("Float64") + + self.tables_equal(args, expected_df, self.var_source) + + def test_histfunc_z_mismatch(self): + from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import ( + HeatmapPreprocessor, + ) + + args = { + "x": "X", + "y": "Y", + "z": None, + "histfunc": "sum", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "table": self.source, + } + + heatmap_preprocessor = HeatmapPreprocessor(args) + + new_table_gen = heatmap_preprocessor.preprocess_partitioned_tables( + [self.source] + ) + + self.assertRaises(ValueError, lambda: next(new_table_gen)) + + +if __name__ == "__main__": + unittest.main() From ee3dc8225b1e511e889f8405b566ccfe8cc03445 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Wed, 10 Jul 2024 14:11:45 -0500 Subject: [PATCH 09/11] wip --- .../express/communication/DeephavenFigureListener.py | 2 -- .../plot/express/deephaven_figure/generate.py | 1 - .../src/deephaven/plot/express/plots/heatmap.py | 12 ++++++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py index cab5def10..288d8dfa1 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py +++ b/plugins/plotly-express/src/deephaven/plot/express/communication/DeephavenFigureListener.py @@ -67,10 +67,8 @@ def _setup_listeners(self) -> None: Setup listeners for the partitioned tables """ for table, node in self._partitioned_tables.values(): - print(table, node) listen_func = partial(self._on_update, node) handle = listen(table, listen_func) - print(handle) self._handles.append(handle) self._liveness_scope.manage(handle) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py 
b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py index 9d84aea6a..63d20ad55 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/generate.py @@ -783,7 +783,6 @@ def get_hover_body( def hover_text_generator( hover_mapping: list[dict[str, str]], - # hover_data - todo, dependent on arrays supported in data mappings types: set[str] | None = None, current_partition: dict[str, str] | None = None, ) -> Generator[dict[str, Any], None, None]: diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py index 0d85928d3..5118a2437 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -46,18 +46,18 @@ def density_heatmap( color_continuous_midpoint: A number that is the midpoint of the color axis opacity: Opacity to apply to all markers. 0 is completely transparent and 1 is completely opaque. - log_x: A boolean or list of booleans that specify if - the corresponding axis is a log axis or not. The booleans loop, so if there - are more series than booleans, booleans will be reused. - log_y: A boolean or list of booleans that specify if - the corresponding axis is a log axis or not. The booleans loop, so if there - are more series than booleans, booleans will be reused. + log_x: A boolean that specifies if the corresponding axis is a log axis or not. + log_y: A boolean that specifies if the corresponding axis is a log axis or not. range_x: A list of two numbers that specify the range of the x axes. None can be specified for no range range_y: A list of two numbers that specify the range of the y axes. None can be specified for no range range_bins_x: A list of two numbers that specify the range of data that is used for x. 
+ None can be specified to use the min and max of the data. + None can also be specified for either of the list values to use the min or max of the data, respectively. range_bins_y: A list of two numbers that specify the range of data that is used for y. + None can be specified to use the min and max of the data. + None can also be specified for either of the list values to use the min or max of the data, respectively. histfunc: The function to use when aggregating within bins. One of 'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', 'sum', or 'var' From 7497eaf039084e3b69333694bd30340651227534 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Fri, 12 Jul 2024 10:35:09 -0500 Subject: [PATCH 10/11] wip --- .../express/deephaven_figure/custom_draw.py | 6 +- .../deephaven/plot/express/plots/heatmap.py | 12 +- .../express/preprocess/HeatmapPreprocessor.py | 8 +- .../plot/express/preprocess/utilities.py | 10 + .../plot/express/plots/test_heatmap.py | 63 ++--- .../preprocess/test_HeatmapPreprocessor.py | 231 +++++++++++++++++- 6 files changed, 287 insertions(+), 43 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py index e07e1bc4a..4d822dd55 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py +++ b/plugins/plotly-express/src/deephaven/plot/express/deephaven_figure/custom_draw.py @@ -119,8 +119,8 @@ def draw_density_heatmap( z: str, labels: dict[str, str] | None = None, range_color: list[float] | None = None, - color_continuous_scale: str | None = "Viridis", - color_continuous_midpoint: list[float] | None = None, + color_continuous_scale: str | list[str] | None = "plasma", + color_continuous_midpoint: float | None = None, opacity: float = 1.0, title: str | None = None, template: str | None = None, @@ -133,7 +133,7 @@ def draw_density_heatmap( y: The name of the 
column containing y-axis values z: The name of the column containing bin values labels: A dictionary of labels mapping columns to new labels - color_continuous_scale: A list of colors for a continuous scale + color_continuous_scale: A color scale or list of colors for a continuous scale range_color: A list of two numbers that form the endpoints of the color axis color_continuous_midpoint: A number that is the midpoint of the color axis opacity: Opacity to apply to all markers. 0 is completely transparent diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py index 5118a2437..f40afd023 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Literal from deephaven.plot.express.shared import default_callback @@ -15,7 +15,7 @@ def density_heatmap( y: str | None = None, z: str | None = None, labels: dict[str, str] | None = None, - color_continuous_scale: str | None = None, + color_continuous_scale: str | list[str] | None = None, range_color: list[float] | None = None, color_continuous_midpoint: float | None = None, opacity: float = 1.0, @@ -28,6 +28,7 @@ def density_heatmap( histfunc: str = "count", nbinsx: int = 10, nbinsy: int = 10, + empty_bin_default: float | Literal["NaN"] | None = None, title: str | None = None, template: str | None = None, unsafe_update_figure: Callable = default_callback, @@ -41,7 +42,7 @@ def density_heatmap( y: A column that contains y-axis values. z: A column that contains z-axis values. If not provided, the count of joint occurrences of x and y will be used. labels: A dictionary of labels mapping columns to new labels. 
- color_continuous_scale: A list of colors for a continuous scale + color_continuous_scale: A color scale or list of colors for a continuous scale range_color: A list of two numbers that form the endpoints of the color axis color_continuous_midpoint: A number that is the midpoint of the color axis opacity: Opacity to apply to all markers. 0 is completely transparent @@ -63,6 +64,11 @@ def density_heatmap( 'sum', or 'var' nbinsx: The number of bins to use for the x-axis nbinsy: The number of bins to use for the y-axis + empty_bin_default: The value to use for bins that have no data. + If None and histfunc is 'count' or 'count_distinct', 0 is used. + Otherwise, if None or 'NaN', NaN is used. + Note that if multiple points are required to color a bin, such as the case for a histfunc of 'std' or var, + the bin will still be NaN if less than the required number of points are present. title: The title of the chart template: The template for the chart. unsafe_update_figure: An update function that takes a plotly figure diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py index 9a0196ac8..241dc44cb 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/HeatmapPreprocessor.py @@ -38,6 +38,12 @@ def __init__(self, args: dict[str, Any]): self.nbinsy = args.pop("nbinsy") self.range_bins_x = args.pop("range_bins_x") self.range_bins_y = args.pop("range_bins_y") + self.empty_bin_default = args.pop("empty_bin_default") + if ( + self.histfunc in {"count", "count_distinct"} + and self.empty_bin_default is None + ): + self.empty_bin_default = 0 # create unique names for the columns to ensure no collisions self.names = get_unique_names( self.args["table"], @@ -132,7 +138,7 @@ def preprocess_partitioned_tables( ranged_bin_counts = 
bin_counts.join(range_table) bin_counts_with_midpoint = calculate_bin_locations( - ranged_bin_counts, self.names, histfunc_col + ranged_bin_counts, self.names, histfunc_col, self.empty_bin_default ) heatmap_agg_label = f"{self.histfunc} of {z}" if z else self.histfunc diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py index 6d20eb8f1..55d47de91 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py @@ -231,6 +231,7 @@ def calculate_bin_locations( ranged_bin_counts: Table, names: dict[str, str], histfunc_col: str, + empty_bin_default: float | str | None, ) -> Table: """ Compute the center of the bins for the x and y axes @@ -240,6 +241,7 @@ def calculate_bin_locations( bin_counts_ranged: A table that contains the bin counts and the range columns names: The names used for columns so that they don't collide histfunc_col: The column that contains the aggregated values + empty_bin_default: The default value to use for bins that have no data Returns: A table that contains the bin counts and the center of the bins @@ -256,6 +258,14 @@ def calculate_bin_locations( y = names["y"] agg_col = names["agg_col"] + # both "NaN" and None require no replacement + # it is assumed that default_bin_value has already been set to a number + # if needed, such as in the case of a histfunc of count or count_distinct + if empty_bin_default not in {"NaN", None}: + ranged_bin_counts = ranged_bin_counts.update_view( + f"{agg_col} = replaceIfNull({agg_col}, {empty_bin_default})" + ) + return ranged_bin_counts.update_view( [ f"{bin_min_x} = {range_x}.binMin({range_index_x})", diff --git a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py index 3aefe5e31..0c997e5ca 100644 --- 
a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py +++ b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py @@ -44,16 +44,16 @@ def test_basic_heatmap(self): "coloraxis": { "colorbar": {"title": {"text": "count"}}, "colorscale": [ - [0.0, "#440154"], - [0.1111111111111111, "#482878"], - [0.2222222222222222, "#3e4989"], - [0.3333333333333333, "#31688e"], - [0.4444444444444444, "#26828e"], - [0.5555555555555556, "#1f9e89"], - [0.6666666666666666, "#35b779"], - [0.7777777777777778, "#6ece58"], - [0.8888888888888888, "#b5de2b"], - [1.0, "#fde725"], + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], ], }, "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "X"}}, @@ -111,16 +111,16 @@ def test_heatmap_relabel_z(self): "coloraxis": { "colorbar": {"title": {"text": "count of Column Z"}}, "colorscale": [ - [0.0, "#440154"], - [0.1111111111111111, "#482878"], - [0.2222222222222222, "#3e4989"], - [0.3333333333333333, "#31688e"], - [0.4444444444444444, "#26828e"], - [0.5555555555555556, "#1f9e89"], - [0.6666666666666666, "#35b779"], - [0.7777777777777778, "#6ece58"], - [0.8888888888888888, "#b5de2b"], - [1.0, "#fde725"], + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], ], }, "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, @@ -179,22 +179,24 @@ def test_heatmap_relabel_agg_z(self): "coloraxis": { "colorbar": {"title": {"text": "count"}}, "colorscale": [ - [0.0, 
"#440154"], - [0.1111111111111111, "#482878"], - [0.2222222222222222, "#3e4989"], - [0.3333333333333333, "#31688e"], - [0.4444444444444444, "#26828e"], - [0.5555555555555556, "#1f9e89"], - [0.6666666666666666, "#35b779"], - [0.7777777777777778, "#6ece58"], - [0.8888888888888888, "#b5de2b"], - [1.0, "#fde725"], + [0.0, "#0d0887"], + [0.1111111111111111, "#46039f"], + [0.2222222222222222, "#7201a8"], + [0.3333333333333333, "#9c179e"], + [0.4444444444444444, "#bd3786"], + [0.5555555555555556, "#d8576b"], + [0.6666666666666666, "#ed7953"], + [0.7777777777777778, "#fb9f3a"], + [0.8888888888888888, "#fdca26"], + [1.0, "#f0f921"], ], }, "xaxis": {"anchor": "y", "side": "bottom", "title": {"text": "Column X"}}, "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, } + print(plotly["layout"]["coloraxis"]["colorscale"]) + self.assertEqual(plotly["layout"], expected_layout) expected_mappings = [ @@ -238,6 +240,7 @@ def test_full_heatmap(self): range_y=[0, 10], range_bins_x=[5, 10], range_bins_y=[5, 10], + empty_bin_default=0, histfunc="sum", nbinsx=2, nbinsy=2, diff --git a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py index 765ba39f7..1e1403fbe 100644 --- a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py +++ b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py @@ -1,5 +1,6 @@ import unittest +import numpy as np import pandas as pd from ..BaseTest import BaseTestCase, remap_types @@ -22,17 +23,25 @@ def setUp(self) -> None: # for a resulting var and std of 0 in all grid cells self.var_source = merge([self.source, self.source]) - def tables_equal(self, args, expected_df, t=None): - if t is None: - t = self.source - - args_copy = args.copy() + def tables_equal(self, args, expected_df, t=None, post_process=None) -> None: + """ + Compare the 
expected dataframe to the actual dataframe generated by the preprocessor + Args: + args: The arguments to pass to the preprocessor + expected_df: The expected dataframe + t: The table to preprocess, defaults to self.source + """ from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import ( HeatmapPreprocessor, ) import deephaven.pandas as dhpd + if t is None: + t = self.source + + args_copy = args.copy() + heatmap_preprocessor = HeatmapPreprocessor(args_copy) new_table_gen = heatmap_preprocessor.preprocess_partitioned_tables([t]) @@ -40,6 +49,9 @@ def tables_equal(self, args, expected_df, t=None): new_df = dhpd.to_pandas(new_table) + if post_process is not None: + post_process(new_df) + self.assertTrue(expected_df.equals(new_df)) def test_basic_preprocessor(self): @@ -52,6 +64,7 @@ def test_basic_preprocessor(self): "nbinsy": 2, "range_bins_x": None, "range_bins_y": None, + "empty_bin_default": None, "table": self.source, } @@ -76,6 +89,7 @@ def test_basic_preprocessor_z(self): "nbinsy": 2, "range_bins_x": None, "range_bins_y": None, + "empty_bin_default": None, "table": self.source, } @@ -96,6 +110,7 @@ def test_partial_range_preprocessor(self): "nbinsy": 2, "range_bins_x": [None, 6], "range_bins_y": None, + "empty_bin_default": None, "table": self.source, } @@ -120,6 +135,7 @@ def test_full_preprocessor(self): "nbinsy": 3, "range_bins_x": [1, 5], "range_bins_y": [-1, 5], + "empty_bin_default": None, "table": self.source, } @@ -127,7 +143,7 @@ def test_full_preprocessor(self): { "X": [2.0, 2.0, 2.0, 4.0, 4.0, 4.0], "Y": [0.0, 2.0, 4.0, 0.0, 2.0, 4.0], - "count": [pd.NA, pd.NA, pd.NA, 1, pd.NA, 1], + "count": [0, 0, 0, 1, 0, 1], } ) remap_types(expected_df) @@ -145,6 +161,7 @@ def test_full_preprocessor_z(self): "nbinsy": 2, "range_bins_x": [-1, 5], "range_bins_y": [1, 5], + "empty_bin_default": None, "table": self.source, } @@ -170,6 +187,7 @@ def test_preprocessor_aggs(self): "nbinsy": 2, "range_bins_x": None, "range_bins_y": None, + 
"empty_bin_default": None, "table": self.source, } @@ -301,6 +319,7 @@ def test_histfunc_z_mismatch(self): "nbinsy": 2, "range_bins_x": None, "range_bins_y": None, + "empty_bin_default": None, "table": self.source, } @@ -312,6 +331,206 @@ def test_histfunc_z_mismatch(self): self.assertRaises(ValueError, lambda: next(new_table_gen)) + def test_empty_bin_default(self): + args = { + "x": "X", + "y": "Y", + "z": "Z", + "histfunc": "sum", + "nbinsx": 4, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "empty_bin_default": 0, + "table": self.source, + } + + # sum, 0 - default to 0 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "sum": [1, 2, 0, 0, 0, 0, 1, 2], + } + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = "NaN" + + # sum, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2], + } + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # sum, None - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "sum": [1, 2, pd.NA, pd.NA, pd.NA, pd.NA, 1, 2], + } + ) + remap_types(expected_df) + expected_df["sum"] = expected_df["sum"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "var" + + # var, None - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [np.nan, np.nan, pd.NA, pd.NA, pd.NA, pd.NA, np.nan, np.nan], + } + ) + remap_types(expected_df) + + expected_df["var"] = 
expected_df["var"].astype("Float64") + + def na_conversion(df): + # convert to pandas.NA for whole df comparison, + # but verify that the nans are there because they come from there being no data in the bin + arr = df["var"].to_numpy() + for i in [0, 1, 6, 7]: + self.assertTrue(np.isnan(arr[i])) + df.at[i, "var"] = pd.NA + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["empty_bin_default"] = "NaN" + + # var, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA], + } + ) + remap_types(expected_df) + expected_df["var"] = expected_df["var"].astype("Float64") + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["empty_bin_default"] = 3 + + # var, 1 - default to 1 if at least two data points + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "var": [pd.NA, pd.NA, 3, 3, 3, 3, pd.NA, pd.NA], + } + ) + remap_types(expected_df) + expected_df["var"] = expected_df["var"].astype("Float64") + + self.tables_equal(args, expected_df, post_process=na_conversion) + + args["histfunc"] = "count" + + # count, 3 - default to 3 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 3, 3, 3, 3, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # count, None - default to 0 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, 0, 0, 0, 0, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, 
expected_df) + + args["empty_bin_default"] = "NaN" + + # count, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count"] = expected_df["count"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["histfunc"] = "count_distinct" + + # count_distinct, "NaN" - no default + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, pd.NA, pd.NA, pd.NA, pd.NA, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = None + + # count_distinct, None - default to 0 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, 0, 0, 0, 0, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64") + + self.tables_equal(args, expected_df) + + args["empty_bin_default"] = 2 + + # count_distinct, 2 - default to 2 + expected_df = pd.DataFrame( + { + "X": [0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5], + "Y": [1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0], + "count_distinct": [1, 1, 2, 2, 2, 2, 1, 1], + } + ) + remap_types(expected_df) + expected_df["count_distinct"] = expected_df["count_distinct"].astype("Int64") + + self.tables_equal(args, expected_df) + if __name__ == "__main__": unittest.main() From d1a59bca4a900fe4f1a12d5c0612812587b53668 Mon Sep 17 00:00:00 2001 From: Joe Numainville Date: Tue, 16 Jul 2024 16:34:15 -0500 Subject: [PATCH 11/11] wip --- .../deephaven/plot/express/plots/heatmap.py | 1 + .../plot/express/preprocess/utilities.py | 8 ++++-- 
.../plot/express/plots/test_heatmap.py | 2 -- .../preprocess/test_HeatmapPreprocessor.py | 25 +++++++++++++++++++ 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py index f40afd023..ba40fc27f 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py +++ b/plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py @@ -67,6 +67,7 @@ def density_heatmap( empty_bin_default: The value to use for bins that have no data. If None and histfunc is 'count' or 'count_distinct', 0 is used. Otherwise, if None or 'NaN', NaN is used. + 'NaN' forces the bin to be NaN if no data is present, even if histfunc is 'count' or 'count_distinct'. Note that if multiple points are required to color a bin, such as the case for a histfunc of 'std' or var, the bin will still be NaN if less than the required number of points are present. title: The title of the chart diff --git a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py index 55d47de91..4c6ac80aa 100644 --- a/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py +++ b/plugins/plotly-express/src/deephaven/plot/express/preprocess/utilities.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generator +from typing import Generator, Literal from deephaven import agg, empty_table from deephaven.plot.express.shared import get_unique_names @@ -231,7 +231,7 @@ def calculate_bin_locations( ranged_bin_counts: Table, names: dict[str, str], histfunc_col: str, - empty_bin_default: float | str | None, + empty_bin_default: float | Literal["NaN"] | None, ) -> Table: """ Compute the center of the bins for the x and y axes @@ -261,6 +261,10 @@ def calculate_bin_locations( # both "NaN" and None require no replacement # it is assumed that 
default_bin_value has already been set to a number # if needed, such as in the case of a histfunc of count or count_distinct + + if isinstance(empty_bin_default, str) and empty_bin_default != "NaN": + raise ValueError("empty_bin_default must be 'NaN' if it is a string") + if empty_bin_default not in {"NaN", None}: ranged_bin_counts = ranged_bin_counts.update_view( f"{agg_col} = replaceIfNull({agg_col}, {empty_bin_default})" diff --git a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py index 0c997e5ca..bc5ad0b1b 100644 --- a/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py +++ b/plugins/plotly-express/test/deephaven/plot/express/plots/test_heatmap.py @@ -195,8 +195,6 @@ def test_heatmap_relabel_agg_z(self): "yaxis": {"anchor": "x", "side": "left", "title": {"text": "Column Y"}}, } - print(plotly["layout"]["coloraxis"]["colorscale"]) - self.assertEqual(plotly["layout"], expected_layout) expected_mappings = [ diff --git a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py index 1e1403fbe..4e18dd7e4 100644 --- a/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py +++ b/plugins/plotly-express/test/deephaven/plot/express/preprocess/test_HeatmapPreprocessor.py @@ -531,6 +531,31 @@ def na_conversion(df): self.tables_equal(args, expected_df) + def test_bad_empty_bin_default(self): + from src.deephaven.plot.express.preprocess.HeatmapPreprocessor import ( + HeatmapPreprocessor, + ) + + args = { + "x": "X", + "y": "Y", + "z": None, + "histfunc": "count", + "nbinsx": 2, + "nbinsy": 2, + "range_bins_x": None, + "range_bins_y": None, + "empty_bin_default": "bad", + "table": self.source, + } + + process = HeatmapPreprocessor(args) + + self.assertRaises( + ValueError, + lambda: 
next(process.preprocess_partitioned_tables([self.source])), + ) + if __name__ == "__main__": unittest.main()