Skip to content

Commit

Permalink
Consider symmetric keypoints when computing OKS (DeepLabCut#1551)
Browse files Browse the repository at this point in the history
* Handle symmetric keypoints for OKS

* Add a couple of unit tests
  • Loading branch information
jeylau authored Oct 20, 2021
1 parent 858ccfe commit d7db8d3
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,9 @@ def evaluate_multianimal_full(
data_path.replace("_full.", "_meta."),
n_graphs=n_graphs,
paf_inds=paf_inds,
oks_sigma=dlc_cfg.get("oks_sigma", 0.1),
margin=dlc_cfg.get("bbox_margin", 0),
symmetric_kpts=dlc_cfg.get("symmetric_kpts"),
)
df = results[1].copy()
df.loc(axis=0)[('mAP_train', 'mean')] = [d[0]['mAP'] for d in results[2]]
Expand Down
26 changes: 24 additions & 2 deletions deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def _benchmark_paf_graphs(
identity_only=False,
calibration_file="",
oks_sigma=0.1,
margin=0,
symmetric_kpts=None,
split_inds=None,
):
metadata = data.pop("metadata")
Expand Down Expand Up @@ -246,9 +248,23 @@ def _benchmark_paf_graphs(
oks = []
for inds in split_inds:
assemblies = {k: v for k, v in ass.assemblies.items() if k in inds}
oks.append(evaluate_assembly(assemblies, ass_true_dict, oks_sigma))
oks.append(
evaluate_assembly(
assemblies,
ass_true_dict,
oks_sigma,
margin=margin,
symmetric_kpts=symmetric_kpts,
)
)
else:
oks = evaluate_assembly(ass.assemblies, ass_true_dict, oks_sigma)
oks = evaluate_assembly(
ass.assemblies,
ass_true_dict,
oks_sigma,
margin=margin,
symmetric_kpts=symmetric_kpts,
)
all_metrics.append(oks)
scores = np.full((len(image_paths), 2), np.nan)
for i, imname in enumerate(tqdm(image_paths)):
Expand Down Expand Up @@ -359,12 +375,15 @@ def cross_validate_paf_graphs(
metadata_file,
output_name="",
pcutoff=0.1,
oks_sigma=0.1,
margin=0,
greedy=False,
add_discarded=True,
calibrate=False,
overwrite_config=True,
n_graphs=10,
paf_inds=None,
symmetric_kpts=None,
):
cfg = auxiliaryfunctions.read_config(config)
inf_cfg = auxiliaryfunctions.read_plainconfig(inference_config)
Expand Down Expand Up @@ -406,6 +425,9 @@ def cross_validate_paf_graphs(
paf_inds,
greedy,
add_discarded,
oks_sigma=oks_sigma,
margin=margin,
symmetric_kpts=symmetric_kpts,
calibration_file=calibration_file,
split_inds=[
metadata["data"]["trainIndices"],
Expand Down
62 changes: 50 additions & 12 deletions deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,23 +806,53 @@ def to_pickle(self, output_name):
pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)


def calc_object_keypoint_similarity(xy_pred, xy_true, sigma):
def calc_object_keypoint_similarity(
xy_pred,
xy_true,
sigma,
margin=0,
symmetric_kpts=None,
):
visible_gt = ~np.isnan(xy_true).all(axis=1)
if visible_gt.sum() < 2: # At least 2 points needed to calculate scale
return np.nan
true = xy_true[visible_gt]
pred = xy_pred[visible_gt]
pred[np.isnan(pred)] = np.inf
dist_squared = np.sum((pred - true) ** 2, axis=1)
scale_squared = np.product(np.ptp(true, axis=0) + np.spacing(1))
scale_squared = np.product(np.ptp(true, axis=0) + np.spacing(1) + margin * 2)
if np.isclose(scale_squared, 0):
return np.nan
k_squared = (2 * sigma) ** 2
oks = np.exp(-dist_squared / (2 * scale_squared * k_squared))
return np.mean(oks)


def match_assemblies(ass_pred, ass_true, sigma):
denom = 2 * scale_squared * k_squared
if symmetric_kpts is None:
pred = xy_pred[visible_gt]
pred[np.isnan(pred)] = np.inf
dist_squared = np.sum((pred - true) ** 2, axis=1)
oks = np.exp(-dist_squared / denom)
return np.mean(oks)
else:
oks = []
xy_preds = [xy_pred]
combos = (pair for l in range(len(symmetric_kpts))
for pair in itertools.combinations(symmetric_kpts, l + 1))
for pairs in combos:
# Swap corresponding keypoints
tmp = xy_pred.copy()
for pair in pairs:
tmp[pair, :] = tmp[pair[::-1], :]
xy_preds.append(tmp)
for xy_pred in xy_preds:
pred = xy_pred[visible_gt]
pred[np.isnan(pred)] = np.inf
dist_squared = np.sum((pred - true) ** 2, axis=1)
oks.append(np.mean(np.exp(-dist_squared / denom)))
return max(oks)

def match_assemblies(
ass_pred,
ass_true,
sigma,
margin=0,
symmetric_kpts=None,
):
inds_true = list(range(len(ass_true)))
inds_pred = np.argsort(
[ins.affinity if ins.n_links else ins.confidence for ins in ass_pred]
Expand All @@ -833,7 +863,11 @@ def match_assemblies(ass_pred, ass_true, sigma):
oks = []
for ind_true in inds_true:
xy_true = ass_true[ind_true].xy
oks.append(calc_object_keypoint_similarity(xy_pred, xy_true, sigma))
oks.append(
calc_object_keypoint_similarity(
xy_pred, xy_true, sigma, margin, symmetric_kpts,
)
)
if np.all(np.isnan(oks)):
continue
ind_best = np.nanargmax(oks)
Expand Down Expand Up @@ -898,6 +932,8 @@ def evaluate_assembly(
ass_true_dict,
oks_sigma=0.072,
oks_thresholds=np.linspace(0.5, 0.95, 10),
margin=0,
symmetric_kpts=None,
):
# sigma is taken as the median of all COCO keypoint standard deviations
all_matched = []
Expand All @@ -906,7 +942,9 @@ def evaluate_assembly(
ass_true = ass_true_dict.get(ind)
if ass_true is None:
continue
matched, unmatched = match_assemblies(ass_pred, ass_true, oks_sigma)
matched, unmatched = match_assemblies(
ass_pred, ass_true, oks_sigma, margin, symmetric_kpts,
)
all_matched.extend(matched)
all_unmatched.extend(unmatched)
if not all_matched:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_inferenceutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def test_calc_object_keypoint_similarity(real_assemblies):
assert inferenceutils.calc_object_keypoint_similarity(xy3, xy1, sigma) == 0
assert np.isnan(inferenceutils.calc_object_keypoint_similarity(xy1, xy3, sigma))

# Test flipped keypoints
xy4 = xy1.copy()
symmetric_pair = [0, 11]
xy4[symmetric_pair] = xy4[symmetric_pair[::-1]]
assert inferenceutils.calc_object_keypoint_similarity(xy1, xy4, sigma) != 1
assert inferenceutils.calc_object_keypoint_similarity(
xy1, xy4, sigma, symmetric_kpts=[symmetric_pair]
) == 1


def test_match_assemblies(real_assemblies):
assemblies = real_assemblies[0]
Expand Down Expand Up @@ -65,6 +74,17 @@ def test_evaluate_assemblies(real_assemblies):
assert dict_["precisions"].shape[1] == 101
np.testing.assert_allclose(dict_["precisions"], 1)

dict_ = inferenceutils.evaluate_assembly(
assemblies,
assemblies,
oks_thresholds=thresholds,
symmetric_kpts=[(0, 5), (1, 4)]
)
assert dict_["mAP"] == dict_["mAR"] == 1
assert len(dict_["precisions"]) == len(dict_["recalls"]) == n_thresholds
assert dict_["precisions"].shape[1] == 101
np.testing.assert_allclose(dict_["precisions"], 1)


def test_link():
pos1 = 1, 1
Expand Down

0 comments on commit d7db8d3

Please sign in to comment.