Skip to content

Commit

Permalink
refactor(annotations): move logic from images.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Sep 6, 2024
1 parent f2ee488 commit b0a42ee
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 30 deletions.
53 changes: 53 additions & 0 deletions src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from functools import lru_cache, partial
from typing import Sequence


class DeleteCallbackRef:
def __init__(self, del_callback, value):
self.del_callback = del_callback
self.value = value

def __del__(self):
self.del_callback()


ANNOTATION_CACHE_SIZE = 20


def get_annotations_from_dataset(
context, add_to_cache_callback, delete_from_cache_callback, dataset_id: str
):
dataset = context.dataset
annotations = [
annotation
for annotation in dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
add_to_cache_callback(dataset_id, annotations)
with_id = partial(delete_from_cache_callback, dataset_id)
return DeleteCallbackRef(with_id, annotations)


class GroundTruthAnnotations:
def __init__(
self,
context, # for dataset
add_to_cache_callback,
delete_from_cache_callback,
):
with_callbacks = partial(
get_annotations_from_dataset,
context,
add_to_cache_callback,
delete_from_cache_callback,
)
self.get_annotations_for_image = lru_cache(maxsize=ANNOTATION_CACHE_SIZE)(with_callbacks)

def get_annotations(self, dataset_ids: Sequence[str]):
return {
dataset_id: self.get_annotations_for_image(dataset_id).value
for dataset_id in dataset_ids
}

def cache_clear(self):
self.get_annotations_for_image.cache_clear()
25 changes: 0 additions & 25 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ def __del__(self):
state, context, ctrl = server.state, server.context, server.controller


# syncs trame state
def delete_from_state(key: str):
delete_state(state, key)


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_image(dataset_id: str):
image_path = server.controller.get_image_fpath(int(dataset_id))
Expand Down Expand Up @@ -158,26 +153,6 @@ def get_annotations(detector: ObjectDetector, id_to_image: Dict[str, Image.Image
return {id: predictions[id] for id in id_to_image.keys()}


def get_ground_truth_annotations(dataset_ids: Sequence[str]):
hits, misses = partition(annotation_cache.get_item, dataset_ids)

annotations = {
dataset_id: [
annotation
for annotation in context.dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
for dataset_id in misses
}

for id, boxes_for_image in annotations.items():
annotation_cache.add_item(id, boxes_for_image)
add_annotation_to_state(id, boxes_for_image)

annotations.update({id: annotation_cache.get_item(id) for id in hits})
return [annotations[dataset_id] for dataset_id in dataset_ids]


def get_image_state_keys(dataset_id: str):
return {
"ground_truth": image_id_to_result_id(dataset_id),
Expand Down
40 changes: 40 additions & 0 deletions src/nrtk_explorer/app/images/stateful_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from functools import partial
from typing import Any, Type
from .annotations import GroundTruthAnnotations
from trame.app import get_server
from trame.decorators import TrameApp, change
from nrtk_explorer.app.images.image_ids import (
image_id_to_result_id,
)
from nrtk_explorer.app.trame_utils import delete_state

server = get_server()
state, context, ctrl = server.state, server.context, server.controller


def add_annotation_to_state(state: Any, image_id: str, annotations: Any):
state[image_id_to_result_id(image_id)] = annotations


def remove_annotation_from_state(state: Any, image_id: str):
delete_state(state, image_id_to_result_id(image_id))


@TrameApp()
class StatefulAnnotations:
def __init__(
self,
annotations_factory_class: Type[GroundTruthAnnotations],
context=context, # for dataset
add_to_cache_callback=partial(add_annotation_to_state, state),
delete_from_cache_callback=partial(remove_annotation_from_state, state),
server=server,
):
self.annotations_factory = annotations_factory_class(
context, add_to_cache_callback, delete_from_cache_callback
)
self.server = server

@change("current_dataset")
def _on_dataset(self, **kwargs):
self.annotations_factory.cache_clear()
23 changes: 18 additions & 5 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
get_image,
get_transformed_image,
get_annotations,
get_ground_truth_annotations,
)
from nrtk_explorer.app.images.annotations import GroundTruthAnnotations
from nrtk_explorer.app.images.stately_annotations import (
StatefulAnnotations,
)

import nrtk_explorer.app.images.image_server # noqa module level side effects
Expand All @@ -48,9 +51,15 @@


class TransformsApp(Applet):
def __init__(self, server):
def __init__(
self,
server,
ground_truth_annotations=StatefulAnnotations(GroundTruthAnnotations),
):
super().__init__(server)

self.ground_truth_annotations = ground_truth_annotations.annotations_factory

self._parameters_app = ParametersApp(
server=server,
)
Expand Down Expand Up @@ -165,7 +174,9 @@ async def _update_transformed_images(self, dataset_ids):
{"original_detection_to_transformed_detection_score": score},
)

ground_truth_annotations = get_ground_truth_annotations(dataset_ids)
ground_truth_annotations = self.ground_truth_annotations.get_annotations(
dataset_ids
).values()
ground_truth_predictions = convert_from_ground_truth_to_first_arg(ground_truth_annotations)
scores = compute_score(
dataset_ids,
Expand Down Expand Up @@ -200,7 +211,9 @@ def compute_predictions_source_images(self, dataset_ids):
dataset_ids,
)

ground_truth_annotations = get_ground_truth_annotations(dataset_ids)
ground_truth_annotations = self.ground_truth_annotations.get_annotations(
dataset_ids
).values()
ground_truth_predictions = convert_from_ground_truth_to_second_arg(
ground_truth_annotations, self.context.dataset
)
Expand All @@ -216,7 +229,7 @@ def compute_predictions_source_images(self, dataset_ids):

async def _update_images(self, dataset_ids):
async with SetStateAsync(self.state):
get_ground_truth_annotations(dataset_ids) # updates state
self.ground_truth_annotations.get_annotations(dataset_ids) # updates state

async with SetStateAsync(self.state):
self.compute_predictions_source_images(dataset_ids)
Expand Down

0 comments on commit b0a42ee

Please sign in to comment.