From d0cd0ced6d2139eeba6053e0e1711cdb0e10914e Mon Sep 17 00:00:00 2001 From: Paul Elliott Date: Wed, 7 Aug 2024 11:53:18 -0400 Subject: [PATCH] perf(object_detector): reuse last successful batch size --- src/nrtk_explorer/app/images/images.py | 1 - src/nrtk_explorer/app/transforms.py | 4 ++++ src/nrtk_explorer/app/ui/image_list.py | 4 ++-- src/nrtk_explorer/app/ui/layout.py | 11 ++------- src/nrtk_explorer/library/object_detector.py | 25 +++++++++++++------- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/nrtk_explorer/app/images/images.py b/src/nrtk_explorer/app/images/images.py index 7f6b450e..18c0b2a6 100644 --- a/src/nrtk_explorer/app/images/images.py +++ b/src/nrtk_explorer/app/images/images.py @@ -137,7 +137,6 @@ def get_annotations(detector: ObjectDetector, id_to_image: Dict[str, Image.Image to_detect = {id: id_to_image[id] for id in misses} predictions = detector.eval( to_detect, - batch_size=int(state.object_detection_batch_size), ) for id, annotations in predictions.items(): annotation_cache.add_item(id, annotations) diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index 9baef998..234e9f4d 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -92,11 +92,15 @@ def transformed_became_visible(old, new): def on_server_ready(self, *args, **kwargs): self.state.change("object_detection_model")(self.on_object_detection_model_change) self.on_object_detection_model_change(self.state.object_detection_model) + self.state.change("current_dataset")(self.reset_detector) def on_object_detection_model_change(self, model_name, **kwargs): self.detector = object_detector.ObjectDetector(model_name=model_name) # TODO clear detection results and rerun detection + def reset_detector(self, **kwargs): + self.detector.reset() + def set_on_transform(self, fn): self._on_transform_fn = fn diff --git a/src/nrtk_explorer/app/ui/image_list.py b/src/nrtk_explorer/app/ui/image_list.py index 
435fef6c..536a081b 100644 --- a/src/nrtk_explorer/app/ui/image_list.py +++ b/src/nrtk_explorer/app/ui/image_list.py @@ -112,7 +112,7 @@ def __init__(self, on_scroll, on_hover, **kwargs): classes="full-height sticky-header", flat=True, hide_bottom=("image_list_view_mode !== 'grid'", True), - title="Selected Images", + title="Sampled Images", grid=("image_list_view_mode === 'grid'", False), filter=("image_list_search", ""), id="image-list", # set id so that the ImageDetection component can select the container for tooltip positioning @@ -297,7 +297,7 @@ def __init__(self, on_scroll, on_hover, **kwargs): v_slot_top=True, __properties=[("v_slot_top", "v-slot:top='props'")], ): - html.Span("Selected Images", classes="col q-table__title") + html.Span("Sampled Images", classes="col q-table__title") quasar.QSelect( v_model=("visible_columns"), multiple=True, diff --git a/src/nrtk_explorer/app/ui/layout.py b/src/nrtk_explorer/app/ui/layout.py index fd6bb3f1..c762dc7c 100644 --- a/src/nrtk_explorer/app/ui/layout.py +++ b/src/nrtk_explorer/app/ui/layout.py @@ -30,7 +30,7 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf ) = ui.card("collapse_dataset") with dataset_title_slot: - html.Span("Dataset Selection", classes="text-h6") + html.Span("Dataset", classes="text-h6") with dataset_content_slot: quasar.QSelect( @@ -57,7 +57,7 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf quasar.QToggle( v_model=("random_sampling", False), dense=False, - label="Random selection", + label="Random sampling", ) ( @@ -96,13 +96,6 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf emit_value=True, map_options=True, ) - quasar.QInput( - v_model=("object_detection_batch_size", 32), - filled=True, - stack_label=True, - label="Batch Size", - type="number", - ) filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter") diff --git 
a/src/nrtk_explorer/library/object_detector.py b/src/nrtk_explorer/library/object_detector.py index 95a828bb..0f8a2352 100644 --- a/src/nrtk_explorer/library/object_detector.py +++ b/src/nrtk_explorer/library/object_detector.py @@ -14,6 +14,9 @@ class ImageWithId(NamedTuple): image: Image +STARTING_BATCH_SIZE = 32 + + class ObjectDetector: """Object detection using Hugging Face's transformers library""" @@ -26,6 +29,7 @@ def __init__( self.task = task self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu" self.pipeline = model_name + self.reset() @property def device(self) -> str: @@ -55,10 +59,13 @@ def pipeline(self, model_name: str): # Do not display warnings transformers.utils.logging.set_verbosity_error() + def reset(self): + self.batch_size = STARTING_BATCH_SIZE + def eval( self, images: Dict[str, Image], - batch_size: int = 32, + batch_size: int = 0, # 0 means auto ) -> ImageIdToAnnotations: """Compute object recognition. Returns Annotations grouped by input image paths.""" @@ -72,15 +79,16 @@ batches.setdefault(size, []) batches[size].append(image) - adjusted_batch_size = batch_size - while adjusted_batch_size > 0: + if batch_size != 0: + self.batch_size = batch_size + while self.batch_size > 0: try: predictions_in_baches = [ zip( [image.id for image in imagesInBatch], self.pipeline( [image.image for image in imagesInBatch], - batch_size=adjusted_batch_size, + batch_size=self.batch_size, ), ) for imagesInBatch in batches.values() ] @@ -94,11 +102,12 @@ return predictions_by_image_id except RuntimeError as e: - if "out of memory" in str(e) and adjusted_batch_size > 1: - previous_batch_size = adjusted_batch_size - adjusted_batch_size = adjusted_batch_size // 2 + if "out of memory" in str(e) and self.batch_size > 1: + previous_batch_size = self.batch_size + self.batch_size = self.batch_size // 2 + # keep the reduced batch size for subsequent eval() calls print( - f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, 
setting batch_size={adjusted_batch_size}" + f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={self.batch_size}" ) else: raise