From ca552fff4684ab6a9f4e31abd1346f850bac2147 Mon Sep 17 00:00:00 2001 From: Brian Moore Date: Tue, 11 Jul 2023 13:22:02 -0400 Subject: [PATCH] Fixing #3277 (#3279) * fixing filter_keypoints() bug * adding dynamic doc test --- fiftyone/core/collections.py | 8 +- fiftyone/core/stages.py | 21 ++--- tests/unittests/view_tests.py | 146 ++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 18 deletions(-) diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index 80a2371ae8..cdcc194bc2 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -9787,7 +9787,7 @@ def _make_set_field_pipeline( self, field, expr, - embedded_root, + embedded_root=embedded_root, allow_missing=allow_missing, new_field=new_field, context=context, @@ -10367,7 +10367,9 @@ def _parse_field_name( other_list_fields = sorted(other_list_fields) def _replace(path): - return ".".join([new_field] + path.split(".")[1:]) + n = new_field.count(".") + 1 + chunks = path.split(".", n) + return new_field + "." + chunks[-1] if len(chunks) > n else new_field if new_field: field_name = _replace(field_name) @@ -10458,7 +10460,7 @@ def _make_set_field_pipeline( sample_collection, field, expr, - embedded_root, + embedded_root=False, allow_missing=False, new_field=None, context=None, diff --git a/fiftyone/core/stages.py b/fiftyone/core/stages.py index 6af6acdaf4..5c05bab244 100644 --- a/fiftyone/core/stages.py +++ b/fiftyone/core/stages.py @@ -2733,7 +2733,6 @@ def to_mongo(self, sample_collection): _, points_path = sample_collection._get_label_field_path( self._field, "points" ) - new_field = self._get_new_field(sample_collection) pipeline = [] @@ -2808,25 +2807,17 @@ def to_mongo(self, sample_collection): if self._only_matches: # Remove Keypoint objects with no points after filtering + has_points = ( + F("points").filter(F()[0] != float("nan")).length() > 0 + ) if is_list_field: - has_points = ( - F("points").filter(F()[0] != float("nan")).length() > 0 - ) - match_expr = F("keypoints").filter(has_points) + only_expr = F().filter(has_points) else: - field, _ = sample_collection._handle_frame_field(new_field) - has_points = ( - F(field + ".points") - .filter(F()[0] != float("nan")) - .length() - > 0 - ) - match_expr = has_points.if_else(F(field), None) + only_expr = has_points.if_else(F(), None) _pipeline, _ = sample_collection._make_set_field_pipeline( root_path, - match_expr, - embedded_root=True, + only_expr, allow_missing=True, new_field=self._new_field, ) diff --git a/tests/unittests/view_tests.py b/tests/unittests/view_tests.py index 441ba2ab85..603817d48b 100644 --- a/tests/unittests/view_tests.py +++ b/tests/unittests/view_tests.py @@ -2738,6 +2738,152 @@ def test_filter_keypoints(self): view = dataset.filter_keypoints("kps", labels=[]) self.assertEqual(len(view), 0) + def test_filter_keypoints_embedded_document(self): + sample1 = fo.Sample( + filepath="image1.jpg", + dynamic=fo.DynamicEmbeddedDocument( + kp=fo.Keypoint( + label="person", + points=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)], + confidence=[0.5, 0.6, 0.7, 0.8, 0.9], + ), + kps=fo.Keypoints( + keypoints=[ + fo.Keypoint( + label="person", + points=[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)], + confidence=[0.5, 0.6, 0.7, 0.8, 0.9], + ), + fo.Keypoint(), + ] + ), + ), + ) + + sample2 = fo.Sample(filepath="image2.jpg") + + dataset = fo.Dataset() + dataset.add_samples([sample1, sample2], dynamic=True) + + dataset.default_skeleton = fo.KeypointSkeleton( + labels=["nose", "left eye", "right eye", "left ear", "right ear"], + edges=[[0, 1, 2, 0], [0, 3], [0, 4]], + ) + + count_nans = lambda points: len([p for p in points if np.isnan(p[0])]) + + # + # Test `Keypoint` sample fields + # + + # only_matches=True + view = dataset.filter_keypoints( + "dynamic.kp", filter=F("confidence") > 0.75 + ) + self.assertEqual(len(view), 1) + sample = view.first() + self.assertEqual(len(sample["dynamic.kp"].points), 5) + self.assertEqual(count_nans(sample["dynamic.kp"].points), 3) + + # only_matches=False + view = dataset.filter_keypoints( + "dynamic.kp", filter=F("confidence") > 0.75, only_matches=False + ) + self.assertEqual(len(view), 2) + sample = view.first() + self.assertEqual(len(sample["dynamic.kp"].points), 5) + self.assertEqual(count_nans(sample["dynamic.kp"].points), 3) + + # view with no matches + view = dataset.filter_keypoints( + "dynamic.kp", filter=F("confidence") > 0.95 + ) + self.assertEqual(len(view), 0) + + # only_matches=True + view = dataset.filter_keypoints( + "dynamic.kp", labels=["left eye", "right eye"] + ) + self.assertEqual(len(view), 1) + sample = view.first() + self.assertEqual(len(sample["dynamic.kp"].points), 5) + self.assertEqual(count_nans(sample["dynamic.kp"].points), 3) + + # only_matches=False + view = dataset.filter_keypoints( + "dynamic.kp", labels=["left eye", "right eye"], only_matches=False + ) + self.assertEqual(len(view), 2) + sample = view.first() + self.assertEqual(len(sample["dynamic.kp"].points), 5) + self.assertEqual(count_nans(sample["dynamic.kp"].points), 3) + + # view with no matches + view = dataset.filter_keypoints("dynamic.kp", labels=[]) + self.assertEqual(len(view), 0) + + # + # Test `Keypoints` sample fields + # + + # only_matches=True + view = dataset.filter_keypoints( + "dynamic.kps", filter=F("confidence") > 0.75 + ) + self.assertEqual(len(view), 1) + self.assertEqual(view.count("dynamic.kps.keypoints"), 1) + sample = view.first() + self.assertEqual(len(sample["dynamic.kps"].keypoints[0].points), 5) + self.assertEqual( + count_nans(sample["dynamic.kps"].keypoints[0].points), 3 + ) + + # only_matches=False + view = dataset.filter_keypoints( + "dynamic.kps", filter=F("confidence") > 0.75, only_matches=False + ) + self.assertEqual(len(view), 2) + self.assertEqual(view.count("dynamic.kps.keypoints"), 2) + sample = view.first() + self.assertEqual(len(sample["dynamic.kps"].keypoints[0].points), 5) + self.assertEqual( + count_nans(sample["dynamic.kps"].keypoints[0].points), 3 + ) + + # view with no matches + view = dataset.filter_keypoints( + "dynamic.kps", filter=F("confidence") > 0.95 + ) + self.assertEqual(len(view), 0) + + # only_matches=True + view = dataset.filter_keypoints( + "dynamic.kps", labels=["left eye", "right eye"] + ) + self.assertEqual(len(view), 1) + self.assertEqual(view.count("dynamic.kps.keypoints"), 1) + sample = view.first() + self.assertEqual(len(sample["dynamic.kps"].keypoints[0].points), 5) + self.assertEqual( + count_nans(sample["dynamic.kps"].keypoints[0].points), 3 + ) + + # only_matches=False + view = dataset.filter_keypoints( + "dynamic.kps", labels=["left eye", "right eye"], only_matches=False + ) + self.assertEqual(len(view), 2) + self.assertEqual(view.count("dynamic.kps.keypoints"), 2) + sample = view.first() + self.assertEqual(len(sample["dynamic.kps"].keypoints[0].points), 5) + self.assertEqual( + count_nans(sample["dynamic.kps"].keypoints[0].points), 3 + ) + + # view with no matches + view = dataset.filter_keypoints("dynamic.kps", labels=[]) + self.assertEqual(len(view), 0) + def test_filter_keypoints_frames(self): sample1 = fo.Sample(filepath="video1.mp4") sample1.frames[1] = fo.Frame(