diff --git a/keypoint_moseq/calibration.py b/keypoint_moseq/calibration.py index a07461d..6178c71 100644 --- a/keypoint_moseq/calibration.py +++ b/keypoint_moseq/calibration.py @@ -57,9 +57,7 @@ def sample_error_frames( for low, high in zip(thresholds[:-1], thresholds[1:]): samples_in_bin = [] for key, confs in confidences.items(): - for t, k in zip( - *np.nonzero((confs >= low) * (confs < high) * mask) - ): + for t, k in zip(*np.nonzero((confs >= low) * (confs < high) * mask)): samples_in_bin.append((key, t, bodyparts[k])) if len(samples_in_bin) > 0: @@ -67,9 +65,7 @@ def sample_error_frames( for i in np.random.choice(len(samples_in_bin), n, replace=False): sample_keys.append(samples_in_bin[i]) - sample_keys = [ - sample_keys[i] for i in np.random.permutation(len(sample_keys)) - ] + sample_keys = [sample_keys[i] for i in np.random.permutation(len(sample_keys))] return sample_keys @@ -106,8 +102,7 @@ def load_sampled_frames(sample_keys, video_dir, video_extension=None): ncols=72, ) return { - (key, frame, bodypart): readers[key][frame] - for key, frame, bodypart in pbar + (key, frame, bodypart): readers[key][frame] for key, frame, bodypart in pbar } @@ -176,18 +171,14 @@ def save_params(project_dir, estimator): ) -def _confs_and_dists_from_annotations( - coordinates, confidences, annotations, bodyparts -): +def _confs_and_dists_from_annotations(coordinates, confidences, annotations, bodyparts): confs, dists = [], [] for (key, frame, bodypart), xy in annotations.items(): if key in coordinates and key in confidences: k = bodyparts.index(bodypart) confs.append(confidences[key][frame][k]) dists.append( - np.sqrt( - ((coordinates[key][frame][k] - np.array(xy)) ** 2).sum() - ) + np.sqrt(((coordinates[key][frame][k] - np.array(xy)) ** 2).sum()) ) return confs, dists @@ -211,6 +202,7 @@ def _noise_calibration_widget( from holoviews.streams import Tap, Stream import holoviews as hv import panel as pn + from bokeh.models import GlyphRenderer, ImageRGBA, Scatter, GraphRenderer hv.extension("bokeh") @@ -219,13 +211,9 @@ def _noise_calibration_widget( edges = np.array(get_edges(bodyparts, skeleton)) conf_vals = np.hstack([v.flatten() for v in confidences.values()]) - min_conf, max_conf = np.nanpercentile(conf_vals, 0.01), np.nanmax( - conf_vals - ) + min_conf, max_conf = np.nanpercentile(conf_vals, 0.01), np.nanmax(conf_vals) - annotations_stream = Stream.define( - "Annotations", annotations=annotations - )() + annotations_stream = Stream.define("Annotations", annotations=annotations)() current_sample = Stream.define("Current sample", sample_ix=0)() estimator = Stream.define( "Estimator", @@ -275,9 +263,9 @@ def update_scatter(x, y, annotations): default_tools=[], ) - curve = hv.Curve( - [(xlim[0], xlim[0] * m + b), (xlim[1], xlim[1] * m + b)] - ).opts(xlim=xlim, ylim=ylim, axiswise=True, default_tools=[]) + curve = hv.Curve([(xlim[0], xlim[0] * m + b), (xlim[1], xlim[1] * m + b)]).opts( + xlim=xlim, ylim=ylim, axiswise=True, default_tools=[] + ) vline_label = hv.Text( x - (xlim[1] - xlim[0]) / 50, @@ -305,6 +293,19 @@ def update_scatter(x, y, annotations): ylabel="log10(error)", ) + def enforce_z_order_hook(plot, element): + bokeh_figure = plot.state + graph, scatter, rgb = None, None, None + for r in bokeh_figure.renderers: + if isinstance(r, GlyphRenderer): + if isinstance(r.glyph, ImageRGBA): + rgb = r + if isinstance(r.glyph, Scatter): + scatter = r + if isinstance(r, GraphRenderer): + graph = r + bokeh_figure.renderers = [rgb, graph, scatter] + def update_img(sample_ix, x, y): key, frame, bodypart = sample_key = sample_keys[sample_ix] image = sample_images[sample_key] @@ -358,9 +359,7 @@ def update_img(sample_ix, x, y): if len(masked_edges) > 0: edge_data = (*masked_edges.T, colorvals[masked_edges[:, 0]]) - sizes = np.where(np.arange(len(xys)) == keypoint_ix, 10, 6)[ - masked_nodes - ] + sizes = np.where(np.arange(len(xys)) == keypoint_ix, 10, 6)[masked_nodes] masked_bodyparts = [bodyparts[i] for i in masked_nodes] nodes = hv.Nodes( (*xys[masked_nodes].T, masked_nodes, masked_bodyparts, sizes), @@ -376,7 +375,11 @@ def update_img(sample_ix, x, y): ) return (rgb * graph * hv_point).opts( - data_aspect=1, xlim=xlim, ylim=ylim, toolbar=None + data_aspect=1, + xlim=xlim, + ylim=ylim, + toolbar=None, + hooks=[enforce_z_order_hook], ) def update_estimator_text(*, slope, intercept, conf_threshold):