Skip to content

Commit

Permalink
Merge pull request #123 from dattalab/calibration_zorder_hotfix
Browse files Browse the repository at this point in the history
enforce z order with bokeh hook
  • Loading branch information
calebweinreb authored Jan 17, 2024
2 parents 1c80236 + 7d7ef23 commit bf462f5
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions keypoint_moseq/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,15 @@ def sample_error_frames(
for low, high in zip(thresholds[:-1], thresholds[1:]):
samples_in_bin = []
for key, confs in confidences.items():
for t, k in zip(
*np.nonzero((confs >= low) * (confs < high) * mask)
):
for t, k in zip(*np.nonzero((confs >= low) * (confs < high) * mask)):
samples_in_bin.append((key, t, bodyparts[k]))

if len(samples_in_bin) > 0:
n = min(num_samples // num_bins, len(samples_in_bin))
for i in np.random.choice(len(samples_in_bin), n, replace=False):
sample_keys.append(samples_in_bin[i])

sample_keys = [
sample_keys[i] for i in np.random.permutation(len(sample_keys))
]
sample_keys = [sample_keys[i] for i in np.random.permutation(len(sample_keys))]
return sample_keys


Expand Down Expand Up @@ -106,8 +102,7 @@ def load_sampled_frames(sample_keys, video_dir, video_extension=None):
ncols=72,
)
return {
(key, frame, bodypart): readers[key][frame]
for key, frame, bodypart in pbar
(key, frame, bodypart): readers[key][frame] for key, frame, bodypart in pbar
}


Expand Down Expand Up @@ -176,18 +171,14 @@ def save_params(project_dir, estimator):
)


def _confs_and_dists_from_annotations(
coordinates, confidences, annotations, bodyparts
):
def _confs_and_dists_from_annotations(coordinates, confidences, annotations, bodyparts):
confs, dists = [], []
for (key, frame, bodypart), xy in annotations.items():
if key in coordinates and key in confidences:
k = bodyparts.index(bodypart)
confs.append(confidences[key][frame][k])
dists.append(
np.sqrt(
((coordinates[key][frame][k] - np.array(xy)) ** 2).sum()
)
np.sqrt(((coordinates[key][frame][k] - np.array(xy)) ** 2).sum())
)
return confs, dists

Expand All @@ -211,6 +202,7 @@ def _noise_calibration_widget(
from holoviews.streams import Tap, Stream
import holoviews as hv
import panel as pn
from bokeh.models import GlyphRenderer, ImageRGBA, Scatter, GraphRenderer

hv.extension("bokeh")

Expand All @@ -219,13 +211,9 @@ def _noise_calibration_widget(

edges = np.array(get_edges(bodyparts, skeleton))
conf_vals = np.hstack([v.flatten() for v in confidences.values()])
min_conf, max_conf = np.nanpercentile(conf_vals, 0.01), np.nanmax(
conf_vals
)
min_conf, max_conf = np.nanpercentile(conf_vals, 0.01), np.nanmax(conf_vals)

annotations_stream = Stream.define(
"Annotations", annotations=annotations
)()
annotations_stream = Stream.define("Annotations", annotations=annotations)()
current_sample = Stream.define("Current sample", sample_ix=0)()
estimator = Stream.define(
"Estimator",
Expand Down Expand Up @@ -275,9 +263,9 @@ def update_scatter(x, y, annotations):
default_tools=[],
)

curve = hv.Curve(
[(xlim[0], xlim[0] * m + b), (xlim[1], xlim[1] * m + b)]
).opts(xlim=xlim, ylim=ylim, axiswise=True, default_tools=[])
curve = hv.Curve([(xlim[0], xlim[0] * m + b), (xlim[1], xlim[1] * m + b)]).opts(
xlim=xlim, ylim=ylim, axiswise=True, default_tools=[]
)

vline_label = hv.Text(
x - (xlim[1] - xlim[0]) / 50,
Expand Down Expand Up @@ -305,6 +293,19 @@ def update_scatter(x, y, annotations):
ylabel="log10(error)",
)

def enforce_z_order_hook(plot, element):
bokeh_figure = plot.state
graph, scatter, rgb = None, None, None
for r in bokeh_figure.renderers:
if isinstance(r, GlyphRenderer):
if isinstance(r.glyph, ImageRGBA):
rgb = r
if isinstance(r.glyph, Scatter):
scatter = r
if isinstance(r, GraphRenderer):
graph = r
bokeh_figure.renderers = [rgb, graph, scatter]

def update_img(sample_ix, x, y):
key, frame, bodypart = sample_key = sample_keys[sample_ix]
image = sample_images[sample_key]
Expand Down Expand Up @@ -358,9 +359,7 @@ def update_img(sample_ix, x, y):
if len(masked_edges) > 0:
edge_data = (*masked_edges.T, colorvals[masked_edges[:, 0]])

sizes = np.where(np.arange(len(xys)) == keypoint_ix, 10, 6)[
masked_nodes
]
sizes = np.where(np.arange(len(xys)) == keypoint_ix, 10, 6)[masked_nodes]
masked_bodyparts = [bodyparts[i] for i in masked_nodes]
nodes = hv.Nodes(
(*xys[masked_nodes].T, masked_nodes, masked_bodyparts, sizes),
Expand All @@ -376,7 +375,11 @@ def update_img(sample_ix, x, y):
)

return (rgb * graph * hv_point).opts(
data_aspect=1, xlim=xlim, ylim=ylim, toolbar=None
data_aspect=1,
xlim=xlim,
ylim=ylim,
toolbar=None,
hooks=[enforce_z_order_hook],
)

def update_estimator_text(*, slope, intercept, conf_threshold):
Expand Down

0 comments on commit bf462f5

Please sign in to comment.