perf(object_detector): reuse last successful batch size
PaulHax committed Aug 7, 2024
1 parent c629860 commit d0cd0ce
Showing 5 changed files with 25 additions and 20 deletions.
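
The heart of the change is an adaptive batch-size pattern: start at a default, halve on CUDA out-of-memory errors, and keep the last size that worked for subsequent calls instead of rediscovering it every time. A minimal standalone sketch of that pattern (class and function names here are illustrative, not the repo's code):

STARTING_BATCH_SIZE = 32


class AdaptiveBatcher:
    """Remembers the last batch size that ran without going out of memory."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Start over, e.g. when the dataset changes.
        self.batch_size = STARTING_BATCH_SIZE

    def run(self, items, fn):
        # Halve the batch size on OOM until the work fits; the surviving
        # size stays on self, so the next call skips the failed attempts.
        while True:
            try:
                results = []
                for i in range(0, len(items), self.batch_size):
                    results.extend(fn(items[i : i + self.batch_size]))
                return results
            except RuntimeError as e:
                if "out of memory" in str(e) and self.batch_size > 1:
                    self.batch_size //= 2
                else:
                    raise

The trade-off: a transient OOM (say, from another process on the GPU) lowers the batch size until reset() is called. The commit accepts that and ties reset() to dataset changes, as the transforms.py diff below shows.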
1 change: 0 additions & 1 deletion src/nrtk_explorer/app/images/images.py
@@ -137,7 +137,6 @@ def get_annotations(detector: ObjectDetector, id_to_image: Dict[str, Image.Image
     to_detect = {id: id_to_image[id] for id in misses}
     predictions = detector.eval(
         to_detect,
-        batch_size=int(state.object_detection_batch_size),
     )
     for id, annotations in predictions.items():
         annotation_cache.add_item(id, annotations)
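
With the detector owning its batch size, this call site just drops the argument. Judging from the new eval() signature further down, an explicit size can still be passed (0 means auto), so both of these should work:

predictions = detector.eval(to_detect)                 # reuse last successful batch size
predictions = detector.eval(to_detect, batch_size=16)  # explicit override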
4 changes: 4 additions & 0 deletions src/nrtk_explorer/app/transforms.py
@@ -92,11 +92,15 @@ def transformed_became_visible(old, new):
     def on_server_ready(self, *args, **kwargs):
         self.state.change("object_detection_model")(self.on_object_detection_model_change)
         self.on_object_detection_model_change(self.state.object_detection_model)
+        self.state.change("current_dataset")(self.reset_detector)

     def on_object_detection_model_change(self, model_name, **kwargs):
         self.detector = object_detector.ObjectDetector(model_name=model_name)
         # TODO clear detection results and rerun detection

+    def reset_detector(self, **kwargs):
+        self.detector.reset()
+
     def set_on_transform(self, fn):
         self._on_transform_fn = fn
4 changes: 2 additions & 2 deletions src/nrtk_explorer/app/ui/image_list.py
@@ -112,7 +112,7 @@ def __init__(self, on_scroll, on_hover, **kwargs):
             classes="full-height sticky-header",
             flat=True,
             hide_bottom=("image_list_view_mode !== 'grid'", True),
-            title="Selected Images",
+            title="Sampled Images",
             grid=("image_list_view_mode === 'grid'", False),
             filter=("image_list_search", ""),
             id="image-list",  # set id so that the ImageDetection component can select the container for tooltip positioning
@@ -297,7 +297,7 @@ def __init__(self, on_scroll, on_hover, **kwargs):
             v_slot_top=True,
             __properties=[("v_slot_top", "v-slot:top='props'")],
         ):
-            html.Span("Selected Images", classes="col q-table__title")
+            html.Span("Sampled Images", classes="col q-table__title")
            quasar.QSelect(
                v_model=("visible_columns"),
                multiple=True,
11 changes: 2 additions & 9 deletions src/nrtk_explorer/app/ui/layout.py
@@ -30,7 +30,7 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
     ) = ui.card("collapse_dataset")

     with dataset_title_slot:
-        html.Span("Dataset Selection", classes="text-h6")
+        html.Span("Dataset", classes="text-h6")

     with dataset_content_slot:
         quasar.QSelect(
@@ -57,7 +57,7 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
         quasar.QToggle(
            v_model=("random_sampling", False),
            dense=False,
-           label="Random selection",
+           label="Random sampling",
         )

     (
@@ -96,13 +96,6 @@ def parameters(dataset_paths=[], embeddings_app=None, filtering_app=None, transf
            emit_value=True,
            map_options=True,
         )
-        quasar.QInput(
-            v_model=("object_detection_batch_size", 32),
-            filled=True,
-            stack_label=True,
-            label="Batch Size",
-            type="number",
-        )

     filter_title_slot, filter_content_slot, filter_actions_slot = ui.card("collapse_filter")
25 changes: 17 additions & 8 deletions src/nrtk_explorer/library/object_detector.py
@@ -14,6 +14,9 @@ class ImageWithId(NamedTuple):
     image: Image


+STARTING_BATCH_SIZE = 32
+
+
 class ObjectDetector:
     """Object detection using Hugging Face's transformers library"""

@@ -26,6 +29,7 @@ def __init__(
         self.task = task
         self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
         self.pipeline = model_name
+        self.reset()

     @property
     def device(self) -> str:
@@ -55,10 +59,13 @@ def pipeline(self, model_name: str):
         # Do not display warnings
         transformers.utils.logging.set_verbosity_error()

+    def reset(self):
+        self.batch_size = STARTING_BATCH_SIZE
+
     def eval(
         self,
         images: Dict[str, Image],
-        batch_size: int = 32,
+        batch_size: int = 0,  # 0 means auto
     ) -> ImageIdToAnnotations:
         """Compute object recognition. Returns Annotations grouped by input image paths."""

@@ -72,15 +79,16 @@ def eval(
             batches.setdefault(size, [])
             batches[size].append(image)

-        adjusted_batch_size = batch_size
-        while adjusted_batch_size > 0:
+        if batch_size != 0:
+            self.batch_size = batch_size
+        while self.batch_size > 0:
             try:
                 predictions_in_baches = [
                     zip(
                         [image.id for image in imagesInBatch],
                         self.pipeline(
                             [image.image for image in imagesInBatch],
-                            batch_size=adjusted_batch_size,
+                            batch_size=self.batch_size,
                         ),
                     )
                     for imagesInBatch in batches.values()
@@ -94,11 +102,12 @@ def eval(
                 return predictions_by_image_id

             except RuntimeError as e:
-                if "out of memory" in str(e) and adjusted_batch_size > 1:
-                    previous_batch_size = adjusted_batch_size
-                    adjusted_batch_size = adjusted_batch_size // 2
+                if "out of memory" in str(e) and self.batch_size > 1:
+                    previous_batch_size = self.batch_size
+                    self.batch_size = self.batch_size // 2
                     print(
-                        f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={adjusted_batch_size}"
+                        f"OOM (Pytorch exception {e}) due to batch_size={previous_batch_size}, setting batch_size={self.batch_size}"
                     )
                 else:
                     raise
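
Putting the pieces together, the detector's lifecycle now looks roughly like this (a usage sketch; the model name and image path are illustrative examples, not values this commit pins):

from PIL import Image

from nrtk_explorer.library.object_detector import ObjectDetector

images = {"img1": Image.open("img1.jpg")}  # id -> PIL image, matching eval()'s input

detector = ObjectDetector(model_name="facebook/detr-resnet-50")  # example model name
annotations = detector.eval(images)  # first call: start at 32, halve on OOM
annotations = detector.eval(images)  # later calls: reuse the size that last worked
detector.reset()                     # on dataset change (wired in transforms.py above):
                                     # batch size returns to STARTING_BATCH_SIZE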
