Skip to content

Commit

Permalink
Fix OneFormer post_process_instance_segmentation for panoptic tasks (
Browse files Browse the repository at this point in the history
…#29304)

* 🐛 Fix oneformer instance post processing when using panoptic task type

* ✅ Add unit test for oneformer instance post processing panoptic bug

---------

Co-authored-by: Nick DeGroot <1966472+nickthegroot@users.noreply.github.com>
  • Loading branch information
nickthegroot and nickthegroot authored Mar 4, 2024
1 parent 81220cb commit 8ef9862
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1244,8 +1244,8 @@ def post_process_instance_segmentation(
# if this is panoptic segmentation, we only keep the "thing" classes
if task_type == "panoptic":
keep = torch.zeros_like(scores_per_image).bool()
for i, lab in enumerate(labels_per_image):
keep[i] = lab in self.metadata["thing_ids"]
for j, lab in enumerate(labels_per_image):
keep[j] = lab in self.metadata["thing_ids"]

scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
Expand All @@ -1258,8 +1258,8 @@ def post_process_instance_segmentation(
continue

if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
for i in range(labels_per_image.shape[0]):
labels_per_image[i] = self.metadata["thing_ids"].index(labels_per_image[i].item())
for j in range(labels_per_image.shape[0]):
labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())

# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
Expand Down
13 changes: 13 additions & 0 deletions tests/models/oneformer/test_image_processing_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,19 @@ def test_post_process_instance_segmentation(self):
el["segmentation"].shape, (self.image_processor_tester.height, self.image_processor_tester.width)
)

segmentation_with_opts = image_processor.post_process_instance_segmentation(
outputs,
threshold=0,
target_sizes=[(1, 4) for _ in range(self.image_processor_tester.batch_size)],
task_type="panoptic",
)
self.assertTrue(len(segmentation_with_opts) == self.image_processor_tester.batch_size)
for el in segmentation_with_opts:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (1, 4))

def test_post_process_panoptic_segmentation(self):
image_processor = self.image_processing_class(
num_labels=self.image_processor_tester.num_classes,
Expand Down

0 comments on commit 8ef9862

Please sign in to comment.