From b2f65ceb3d4046b92363fff1f449d6a30b6ccbdf Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Fri, 27 Sep 2024 14:37:55 -0400 Subject: [PATCH] feat(transforms): add 3 more object detection models --- .../app/images/stateful_annotations.py | 4 +-- src/nrtk_explorer/app/transforms.py | 28 +++++++++++-------- src/nrtk_explorer/app/ui/layout.py | 12 ++++++++ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/nrtk_explorer/app/images/stateful_annotations.py b/src/nrtk_explorer/app/images/stateful_annotations.py index 35dd3f2d..55ea3080 100644 --- a/src/nrtk_explorer/app/images/stateful_annotations.py +++ b/src/nrtk_explorer/app/images/stateful_annotations.py @@ -67,8 +67,8 @@ def __init__( add_to_cache_callback, delete_from_cache_callback ) - @change("current_dataset") - def _on_dataset(self, **kwargs): + @change("current_dataset", "object_detection_model") + def _cache_clear(self, **kwargs): self.annotations_factory.cache_clear() diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index fb1790a7..a8dcfcbf 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -172,25 +172,31 @@ def delete_meta_state(old_ids, new_ids): feature_enabled_state_key="transform_enabled", gui_switch_key="transform_enabled_switch", column_name=TRANSFORM_COLUMNS[0], - enabled_callback=self.schedule_transformed_images, + enabled_callback=self._start_transformed_images, ) self.server.controller.on_server_ready.add(self.on_server_ready) self.server.controller.apply_transform.add(self.on_apply_transform) self._on_hover_fn = None + self.visible_dataset_ids = [] # set by ImageList via self.on_scroll callback + @property def get_image_fpath(self): return self.server.controller.get_image_fpath def on_server_ready(self, *args, **kwargs): self.state.change("object_detection_model")(self.on_object_detection_model_change) - self.on_object_detection_model_change(self.state.object_detection_model) + self.on_object_detection_model_change() self.state.change("current_dataset")(self.reset_detector) - def on_object_detection_model_change(self, model_name, **kwargs): - self.detector = object_detector.ObjectDetector(model_name=model_name) - # TODO clear detection results and rerun detection + def on_object_detection_model_change(self, **kwargs): + self.original_detection_annotations.cache_clear() + self.transformed_detection_annotations.cache_clear() + self.detector = object_detector.ObjectDetector( + model_name=self.state.object_detection_model + ) + self._start_update_images() def reset_detector(self, **kwargs): self.detector.reset() @@ -205,10 +211,10 @@ def on_transform(self, *args, **kwargs): def on_apply_transform(self, **kwargs): # Turn on switch if user clicked lower apply button self.state.transform_enabled_switch = True - self.schedule_transformed_images() + self._start_transformed_images() - def schedule_transformed_images(self, *args, **kwargs): - logger.debug("schedule_transformed_images") + def _start_transformed_images(self, *args, **kwargs): + logger.debug("_start_transformed_images") if self._updating_images(): if self._updating_transformed_images: # computing stale transformed images, restart task @@ -216,7 +222,7 @@ def schedule_transformed_images(self, *args, **kwargs): else: return # update_images will call update_transformed_images() at the end self._update_task = asynchronous.create_task( - self.update_transformed_images(self.visible_ids) + self.update_transformed_images(self.visible_dataset_ids) ) async def update_transformed_images(self, dataset_ids): @@ -343,13 +349,13 @@ async def _update_images(self, dataset_ids): def _start_update_images(self): if hasattr(self, "_update_task"): self._update_task.cancel() - self._update_task = asynchronous.create_task(self._update_images(self.visible_ids)) + self._update_task = asynchronous.create_task(self._update_images(self.visible_dataset_ids)) def _updating_images(self): return hasattr(self, "_update_task") and not self._update_task.done() def on_scroll(self, visible_ids): - self.visible_ids = visible_ids + self.visible_dataset_ids = visible_ids self._start_update_images() def on_image_hovered(self, id): diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py index 6b91c190..cea6d57e 100644 --- a/src/nrtk_explorer/app/ui/layout.py +++ b/src/nrtk_explorer/app/ui/layout.py @@ -74,6 +74,18 @@ def __init__( "label": "facebook/detr-resnet-50", "value": "facebook/detr-resnet-50", }, + { + "label": "facebook/detr-resnet-50-dc5", + "value": "facebook/detr-resnet-50-dc5", + }, + { + "label": "hustvl/yolos-tiny", + "value": "hustvl/yolos-tiny", + }, + { + "label": "valentinafeve/yolos-fashionpedia", + "value": "valentinafeve/yolos-fashionpedia", + }, ], ), filled=True,