feat(bokeh): Server side HoverTool for rasterized/datashaded plots with selector (#6422)
hoxbro authored Feb 20, 2025
1 parent d301861 commit 02258ce
Showing 7 changed files with 405 additions and 108 deletions.
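
For context, a minimal sketch of how the feature added here might be exercised. It assumes the `selector` argument that this commit's hover support builds on; the data and column names are illustrative:

```python
import datashader as ds
import holoviews as hv
import numpy as np
import pandas as pd
from holoviews.operation.datashader import rasterize

hv.extension("bokeh")

# Synthetic point data with extra columns to surface in the hover tooltip.
df = pd.DataFrame({
    "x": np.random.randn(10_000),
    "y": np.random.randn(10_000),
    "value": np.random.rand(10_000),
    "label": np.random.choice(["a", "b", "c"], 10_000),
})
points = hv.Points(df, kdims=["x", "y"], vdims=["value", "label"])

# The aggregator bins the data; the selector picks one representative row
# per pixel, whose remaining columns feed the server-side hover tool.
img = rasterize(points, aggregator=ds.mean("value"), selector=ds.min("value"))
img.opts(tools=["hover"])
```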
34 changes: 26 additions & 8 deletions holoviews/element/raster.py
@@ -649,7 +649,7 @@ def load_image(cls, filename, height=1, array=False, bounds=None, bare=False, **
try:
from PIL import Image
except ImportError:
raise ImportError("RGB.load_image requires PIL (or Pillow).") from None
raise ImportError(f"{cls.__name__}.load_image requires PIL (or Pillow).") from None

with open(filename, 'rb') as f:
data = np.array(Image.open(f))
@@ -687,15 +687,33 @@ def __init__(self, data, kdims=None, vdims=None, **params):
else:
vdims = list(vdims) if isinstance(vdims, list) else [vdims]

alpha = self.alpha_dimension
if ((hasattr(data, 'shape') and data.shape[-1] == 4 and len(vdims) == 3) or
(isinstance(data, tuple) and isinstance(data[-1], np.ndarray) and data[-1].ndim == 3
and data[-1].shape[-1] == 4 and len(vdims) == 3) or
(isinstance(data, dict) and (*map(dimension_name, vdims), alpha.name) in data)):
# Handle all forms of packed value dimensions
vdims.append(alpha)
if self._has_alpha_dimension(data, vdims):
vdims.append(self.alpha_dimension)
super().__init__(data, kdims=kdims, vdims=vdims, **params)

def _has_alpha_dimension(self, data, vdims) -> bool:
# Handle all forms of packed value dimensions
if len(vdims) != 3:
return False

alpha = self.alpha_dimension

if hasattr(data, "shape") and data.shape[-1] == 4:
return True

if isinstance(data, tuple):
last = data[-1]
if isinstance(last, np.ndarray) and last.ndim == 3 and last.shape[-1] == 4:
return True

if isinstance(data, dict) and (*map(dimension_name, vdims), alpha.name) in data:
return True

if str(alpha) in getattr(data, "data_vars", []):
return True

return False
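
For illustration, one data form the refactored check recognizes is a packed array whose last axis has length 4 while only the three default value dimensions are declared; the alpha dimension is then appended automatically. A small sketch:

```python
import numpy as np
import holoviews as hv

rgba = np.random.rand(8, 8, 4)  # last axis of length 4 => packed RGBA
img = hv.RGB(rgba)              # starts with three vdims (R, G, B)
print(img.vdims)                # [Dimension('R'), ..., Dimension('A')]
```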


class HSV(RGB):
"""
186 changes: 137 additions & 49 deletions holoviews/operation/datashader.py
@@ -1,3 +1,4 @@
import enum
import warnings
from collections.abc import Callable, Iterable
from functools import partial
@@ -238,6 +239,25 @@ class LineAggregationOperation(AggregationOperation):



class AggState(enum.Enum):
AGG_ONLY = 0 # Aggregator only
AGG_BY = 1 # Aggregator only, where the aggregator is ds.by
AGG_SEL = 2 # Selector and aggregator
AGG_SEL_BY = 3 # Selector and aggregator, where the aggregator is ds.by

def get_state(agg_fn, sel_fn):
if isinstance(agg_fn, ds.by):
return AggState.AGG_SEL_BY if sel_fn else AggState.AGG_BY
else:
return AggState.AGG_SEL if sel_fn else AggState.AGG_ONLY

def has_sel(state):
return state in (AggState.AGG_SEL, AggState.AGG_SEL_BY)

def has_by(state):
return state in (AggState.AGG_BY, AggState.AGG_SEL_BY)
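
How the four states map out, sketched with assumed datashader reductions (and assuming `AggState` is importable as defined above):

```python
import datashader as ds

# No selector:
assert AggState.get_state(ds.mean("z"), None) is AggState.AGG_ONLY
assert AggState.get_state(ds.by("cat"), None) is AggState.AGG_BY
# With a selector:
assert AggState.get_state(ds.mean("z"), ds.max("w")) is AggState.AGG_SEL
assert AggState.get_state(ds.by("cat"), ds.max("w")) is AggState.AGG_SEL_BY
```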


class aggregate(LineAggregationOperation):
"""
aggregate implements 2D binning for any valid HoloViews Element
@@ -391,16 +411,14 @@ def _process(self, element, key=None):
dfdata = PandasInterface.as_dframe(data)
cvs_fn = getattr(cvs, glyph)

if sel_fn:
agg_state = AggState.get_state(agg_fn, sel_fn)
if AggState.has_sel(agg_state):
if isinstance(params["vdims"], (Dimension, str)):
params["vdims"] = [params["vdims"]]
sum_agg = ds.summary(**{str(params["vdims"][0]): agg_fn, "__index__": ds.where(sel_fn)})
agg = self._apply_datashader(dfdata, cvs_fn, sum_agg, agg_kwargs, x, y)
_ignore = [*params["vdims"], "__index__"]
sel_vdims = [s for s in agg if s not in _ignore]
params["vdims"] = [*params["vdims"], *sel_vdims]
agg = self._apply_datashader(dfdata, cvs_fn, sum_agg, agg_kwargs, x, y, agg_state)
else:
agg = self._apply_datashader(dfdata, cvs_fn, agg_fn, agg_kwargs, x, y)
agg = self._apply_datashader(dfdata, cvs_fn, agg_fn, agg_kwargs, x, y, agg_state)

if 'x_axis' in agg.coords and 'y_axis' in agg.coords:
agg = agg.rename({'x_axis': x, 'y_axis': y})
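
The `ds.summary` construction above pairs the user's aggregator with a row-index selector. Standalone, it looks roughly like this (column names assumed; `ds.where` with only a selector yields the index of the selected row):

```python
import datashader as ds

agg_fn = ds.mean("value")  # the aggregation the user asked for
sel_fn = ds.max("other")   # the selector choosing a representative row
sum_agg = ds.summary(value=agg_fn, __index__=ds.where(sel_fn))
```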
@@ -409,14 +427,16 @@ def _process(self, element, key=None):
if ytype == 'datetime':
agg[y.name] = agg[y.name].astype('datetime64[ns]')

if isinstance(agg, xr.Dataset) or agg.ndim == 2:
# Replacing x and y coordinates to avoid numerical precision issues
if not AggState.has_by(agg_state):
return self.p.element_type(agg, **params)
else:
elif agg_state == AggState.AGG_BY:
params['vdims'] = list(map(str, agg.coords[agg_fn.column].data))
return ImageStack(agg, **params)
elif agg_state == AggState.AGG_SEL_BY:
params['vdims'] = [d for d in agg.data_vars if d not in agg.attrs["selector_columns"]]
return ImageStack(agg, **params)

def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y, agg_state: AggState):
# Suppress numpy warning emitted by dask:
# https://github.com/dask/dask/issues/8439
with warnings.catch_warnings():
@@ -427,18 +447,26 @@ def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
agg = cvs_fn(dfdata, x.name, y.name, agg_fn, **agg_kwargs)

is_where_index = DATASHADER_GE_0_15_1 and isinstance(agg_fn, ds.where) and isinstance(agg_fn.column, rd.SpecialColumn)
is_summary_index = isinstance(agg_fn, ds.summary) and "__index__" in agg
is_summary_index = AggState.has_sel(agg_state)
if is_where_index or is_summary_index:
if is_where_index:
data = agg.data
index = agg.data
agg = agg.to_dataset(name="__index__")
else: # summary index
data = agg["__index__"].data
neg1 = data == -1
index = agg["__index__"].data
if agg_state == AggState.AGG_SEL_BY:
main_dim = next(k for k in agg if k != "__index__")
# Take the values from the main dimension and expand them
# into a new dataset
agg = agg[main_dim].to_dataset(dim=list(agg.sizes)[2])
agg["__index__"] = ((y.name, x.name), index)

neg1 = index == -1
agg.attrs["selector_columns"] = sel_cols = ["__index__"]
for col in dfdata.columns:
if col in agg.coords:
continue
val = dfdata[col].values[data]
val = dfdata[col].values[index]
if val.dtype.kind == 'f':
val[neg1] = np.nan
elif isinstance(val.dtype, pd.CategoricalDtype):
@@ -452,8 +480,9 @@ def _apply_datashader(self, dfdata, cvs_fn, agg_fn, agg_kwargs, x, y):
val = val.astype(np.float64)
val[neg1] = np.nan
agg[col] = ((y.name, x.name), val)
sel_cols.append(col)

if isinstance(agg_fn, ds.by):
if agg_state == AggState.AGG_BY:
col = agg_fn.column
if '' in agg.coords[col]:
agg = agg.drop_sel(**{col: ''})
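
The column lookup above fancy-indexes each dataframe column with the per-pixel row index, where -1 marks empty pixels. A minimal numpy sketch of that masking:

```python
import numpy as np

index = np.array([[0, -1], [2, 1]])  # -1 => no data fell in this pixel
col = np.array([10.0, 20.0, 30.0])   # one dataframe column
val = col[index]                     # -1 silently picks the last row,
val[index == -1] = np.nan            # so empty pixels are reset to NaN
print(val)                           # [[10. nan] [30. 20.]]
```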
@@ -1234,7 +1263,9 @@ def uint32_to_uint8(cls, img):
"""
Cast uint32 RGB image to 4 uint8 channels.
"""
return np.flipud(img.view(dtype=np.uint8).reshape((*img.shape, 4)))
new_array = np.flipud(img.view(dtype=np.uint8).reshape((*img.shape, 4)))
new_array[new_array[:,:,3] == 0] = 0 # Zero out all channels where alpha is 0
return new_array
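
The `view(dtype=np.uint8)` call above reinterprets each packed 32-bit pixel as four bytes without copying. On a little-endian machine:

```python
import numpy as np

img = np.array([[0xFF0000FF]], dtype=np.uint32)  # one packed RGBA pixel
channels = img.view(dtype=np.uint8).reshape((*img.shape, 4))
print(channels)  # [[[255   0   0 255]]] -> opaque red
```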


@classmethod
@@ -1243,6 +1274,7 @@ def uint32_to_uint8_xr(cls, img):
Cast uint32 xarray DataArray to 4 uint8 channels.
"""
new_array = img.values.view(dtype=np.uint8).reshape((*img.shape, 4))
new_array[new_array[:,:,3] == 0] = 0 # Same alpha masking as in uint32_to_uint8
coords = dict(img.coords, band=[0, 1, 2, 3])
return xr.DataArray(new_array, coords=coords, dims=(*img.dims, 'band'))

@@ -1277,6 +1309,33 @@ def to_xarray(cls, element):
xdensity=element.xdensity,
ydensity=element.ydensity)

@classmethod
def _extract_data(cls, element):
vdims = element.vdims
vdim = vdims[0].name if len(vdims) == 1 else None
if isinstance(element, ImageStack):
array = element.data
main_dims = element.data.sizes
# Dropping data related to selector columns
if sel_cols := array.attrs.get("selector_columns"):
array = array.drop_vars(sel_cols)
# If the data is an xarray Dataset it has to be converted to a
# DataArray, either by selecting the single value dimension
# or by adding a z-dimension
if not element.interface.packed(element):
if vdim:
array = array[vdim]
else:
array = array.to_array("z")
# If data is 3D then we have one extra constant dimension
if array.ndim > 3:
drop = set(array.dims) - {*main_dims, 'z'}
array = array.squeeze(dim=drop)
array = array.transpose(*main_dims, ...)
else:
array = element.data[vdim]

return array

def _process(self, element, key=None):
element = element.map(self.to_xarray, Image)
@@ -1297,26 +1356,7 @@ def _process(self, element, key=None):
element = element.clone(datatype=['xarray'])

kdims = element.kdims
if isinstance(element, ImageStack):
vdim = element.vdims
array = element.data
# If data is a xarray Dataset it has to be converted to a
# DataArray, either by selecting the singular value
# dimension or by adding a z-dimension
kdims = [kdim.name for kdim in kdims]
if not element.interface.packed(element):
if len(vdim) == 1:
array = array[vdim[0].name]
else:
array = array.to_array("z")
# If data is 3D then we have one extra constant dimension
if array.ndim > 3:
drop = [d for d in array.dims if d not in [*kdims, 'z']]
array = array.squeeze(dim=drop)
array = array.transpose(*kdims, ...)
else:
vdim = element.vdims[0].name
array = element.data[vdim]
array = self._extract_data(element)

# Dask is not supported by shade so materialize it
array = array.compute()
@@ -1372,12 +1412,26 @@ def _process(self, element, key=None):
coords = {xd.name: element.data.coords[xd.name],
yd.name: element.data.coords[yd.name],
'band': [0, 1, 2, 3]}
img = xr.DataArray(arr, coords=coords, dims=(yd.name, xd.name, 'band'))
return RGB(img, **params)
img_data = xr.DataArray(arr, coords=coords, dims=(yd.name, xd.name, 'band'))
img_data = self.add_selector_data(img_data=img_data, sel_data=element.data)
return RGB(img_data, **params)
else:
img = tf.shade(array, **shade_opts)
return RGB(self.uint32_to_uint8_xr(img), **params)
img_data = self.uint32_to_uint8_xr(img)
img_data = self.add_selector_data(img_data=img_data, sel_data=element.data)
return RGB(img_data, **params)

@classmethod
def add_selector_data(cls, *, img_data, sel_data):
if "selector_columns" in sel_data.attrs:
if {"R", "G", "B", "A"} & set(sel_data.attrs["selector_columns"]):
msg = "Cannot use 'R', 'G', 'B', or 'A' as columns, when using datashade with selector"
raise ValueError(msg)
img_data.coords["band"] = ["R", "G", "B", "A"]
img_data = img_data.to_dataset(dim="band")
img_data.update({k: sel_data[k] for k in sel_data.attrs["selector_columns"]})
img_data.attrs["selector_columns"] = sel_data.attrs["selector_columns"]
return img_data
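
A sketch of the band-to-variables conversion used in `add_selector_data`: the shaded RGBA `DataArray` becomes a `Dataset` with one variable per channel, onto which the selector columns can then be merged:

```python
import numpy as np
import xarray as xr

rgba = xr.DataArray(
    np.zeros((2, 2, 4), dtype=np.uint8),
    dims=("y", "x", "band"),
    coords={"band": ["R", "G", "B", "A"]},
)
ds_img = rgba.to_dataset(dim="band")
print(list(ds_img.data_vars))  # ['R', 'G', 'B', 'A']
```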


class geometry_rasterize(LineAggregationOperation):
@@ -1654,7 +1708,7 @@ def uint8_to_uint32(cls, img):
rgb = img.reshape((flat_shape, 4)).view('uint32').reshape(shape[:2])
return rgb

def _apply_spreading(self, array):
def _apply_spreading(self, array, how=None):
"""Apply the spread function using the indicated parameters."""
raise NotImplementedError

@@ -1669,16 +1723,50 @@ def _process(self, element, key=None):
if isinstance(element, RGB):
rgb = element.rgb
data = self._preprocess_rgb(rgb)
elif isinstance(element, ImageStack):
data = element.data
elif isinstance(element, Image):
data = element.clone(datatype=['xarray']).data[element.vdims[0].name]
if element.interface.datatype != 'xarray':
element = element.clone(datatype=['xarray'])
data = shade._extract_data(element)
else:
raise ValueError('spreading can only be applied to Image or RGB Elements. '
f'Received object of type {type(element)!s}')

kwargs = {}
array = self._apply_spreading(data)
if "selector_columns" in getattr(element.data, "attrs", ()):
new_data = element.data.copy()
index = new_data["__index__"].copy()
mask = np.arange(index.size).reshape(index.shape)
mask[index == -1] = 0
index.data = mask
index = self._apply_spreading(index, how="source")
sel_data = {
sc: new_data[sc].data.ravel()[index].reshape(index.shape)
for sc in new_data.attrs["selector_columns"]
}

if isinstance(element, RGB):
img = datashade.uint32_to_uint8(array.data)[::-1]
for idx, k in enumerate("RGBA"):
new_data[k].data = img[:, :, idx]
elif isinstance(element, ImageStack):
for k in map(str, element.vdims):
new_data[k].data = array.sel(z=k)
elif isinstance(element, Image):
new_data[element.vdims[0].name].data = array
else:
msg = f"{type(element).__name__} currently does not support spreading with selector_columns"
raise NotImplementedError(msg)

for k, v in sel_data.items():
new_data[k].data = v

# TODO: Investigate why this does not work
# element = element.clone(data=new_data, kdims=element.vdims.copy(), vdims=element.vdims.copy())
element = element.clone()
element.data = new_data
return element
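
The `how="source"` spreading above propagates flat pixel positions rather than values, so each output pixel records which input pixel it came from, and the selector columns can be gathered to match the spread image. A small numpy illustration of building that position mask (shapes assumed):

```python
import numpy as np

index = np.array([[-1, 5], [-1, -1]])              # per-pixel row index, -1 = empty
mask = np.arange(index.size).reshape(index.shape)  # flat position of every pixel
mask[index == -1] = 0                              # pin empty pixels to position 0
# After tf.spread(mask, how="source"), each output pixel holds the flat
# position of its source pixel, so a column can be gathered with
# column.ravel()[spread_mask].reshape(spread_mask.shape)
```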

kwargs = {}
if isinstance(element, RGB):
img = datashade.uint32_to_uint8(array.data)[::-1]
new_data = {
Expand Down Expand Up @@ -1710,8 +1798,8 @@ class spread(SpreadingOperation):
px = param.Integer(default=1, doc="""
Number of pixels to spread on all sides.""")

def _apply_spreading(self, array):
return tf.spread(array, px=self.p.px, how=self.p.how, shape=self.p.shape)
def _apply_spreading(self, array, how=None):
return tf.spread(array, px=self.p.px, how=how or self.p.how, shape=self.p.shape)


class dynspread(SpreadingOperation):
@@ -1737,10 +1825,10 @@ class dynspread(SpreadingOperation):
Higher values give more spreading, up to the max_px
allowed.""")

def _apply_spreading(self, array):
def _apply_spreading(self, array, how=None):
return tf.dynspread(
array, max_px=self.p.max_px, threshold=self.p.threshold,
how=self.p.how, shape=self.p.shape
how=how or self.p.how, shape=self.p.shape
)


(The remaining 5 changed files are not rendered here.)
