Skip to content

Commit

Permalink
refactor(transforms): dont duplicate annotations on context
Browse files Browse the repository at this point in the history
Stashing them on state once is enough
  • Loading branch information
PaulHax committed Jul 15, 2024
1 parent 6b434ee commit 65e5c8f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Iterable
from pathlib import Path

from trame.widgets import html
from trame_server.utils.namespace import Translator
Expand All @@ -13,7 +14,7 @@
from nrtk_explorer.app.applet import Applet
from nrtk_explorer.app import ui
import nrtk_explorer.test_data
from pathlib import Path
from nrtk_explorer.app.image_ids import image_id_to_result_id

import os

Expand Down Expand Up @@ -56,7 +57,6 @@ def __init__(self, server=None):

self.context["image_objects"] = {}
self.context["images_manager"] = images_manager.ImagesManager()
self.context["annotations"] = {}

self.state.collapse_dataset = False
self.state.collapse_embeddings = False
Expand Down Expand Up @@ -141,7 +141,7 @@ def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs):
for index, image_id in enumerate(self.state.images_ids):
image_annotations_categories = map(
lambda annotation: annotation["category_id"],
self.context["annotations"].get(f"img_{image_id}", []),
self.state.get(image_id_to_result_id(image_id), []),
)

include = filter.evaluate(image_annotations_categories)
Expand Down
58 changes: 20 additions & 38 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, server):

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
# self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.state.change("current_dataset")(self.on_current_dataset_change)
self.state.change("current_num_elements")(self.on_current_num_elements_change)

Expand Down Expand Up @@ -151,16 +151,13 @@ def on_apply_transform(self, *args, **kwargs):
if len(transformed_image_ids) == 0:
return

# Erase current annotations
dataset_ids = [image_id_to_dataset_id(id) for id in self.state.source_image_ids]
for ann in self.context.dataset.anns.values():
if str(ann["image_id"]) in dataset_ids:
transformed_id = f"transformed_img_{ann['image_id']}"
if transformed_id in self.context["annotations"]:
del self.context["annotations"][transformed_id]
result_ids = [image_id_to_result_id(id) for id in transformed_image_ids]
for id in result_ids:
delete_state(self.state, id)

annotations = self.compute_annotations(transformed_image_ids)

dataset_ids = [image_id_to_dataset_id(id) for id in transformed_image_ids]
predictions = convert_from_predictions_to_second_arg(annotations)
scores = compute_score(
dataset_ids,
Expand All @@ -174,7 +171,7 @@ def on_apply_transform(self, *args, **kwargs):
{"original_detection_to_transformed_detection_score": score},
)

ground_truth_annotations = [self.context["annotations"][id] for id in dataset_ids]
ground_truth_annotations = [self.state[image_id_to_result_id(id)] for id in dataset_ids]
ground_truth_predictions = convert_from_ground_truth_to_first_arg(ground_truth_annotations)
scores = compute_score(
dataset_ids,
Expand Down Expand Up @@ -218,38 +215,28 @@ def compute_annotations(self, ids):
],
}
)
self.context["annotations"][id_] = image_annotations
self.state[image_id_to_result_id(id_)] = image_annotations

self.sync_annotations_to_state(ids)
return predictions

def on_current_num_elements_change(self, current_num_elements, **kwargs):
ids = [img["id"] for img in self.context.dataset.imgs.values()]
return self.set_source_images(ids[:current_num_elements])

def delete_annotations(self, ids):
for id in ids:
if id in self.context["annotations"]:
del self.context["annotations"][id]

def load_ground_truth_annotations(self, dataset_ids):
# collect annotations for each dataset_id
annotations = {
dataset_id: [
image_id_to_result_id(dataset_id): [
annotation
for annotation in self.context.dataset.anns.values()
if str(annotation["image_id"]) == dataset_id
]
for dataset_id in dataset_ids
}
for dataset_id, ground_truth_annotations in annotations.items():
self.context["annotations"][dataset_id] = ground_truth_annotations

self.sync_annotations_to_state(dataset_ids)
self.state.update(annotations)

def compute_predictions_source_images(self, ids):
"""Compute the predictions for the source images."""
dataset_ids = [image_id_to_dataset_id(id) for id in ids]

if len(ids) == 0:
return
Expand All @@ -262,7 +249,8 @@ def compute_predictions_source_images(self, ids):
ids,
)

ground_truth_annotations = [self.context["annotations"][id] for id in dataset_ids]
dataset_ids = [image_id_to_dataset_id(id) for id in ids]
ground_truth_annotations = [self.state[image_id_to_result_id(id)] for id in dataset_ids]
ground_truth_predictions = convert_from_ground_truth_to_second_arg(
ground_truth_annotations, self.context.dataset
)
Expand Down Expand Up @@ -324,13 +312,11 @@ def set_selected_dataset_ids(self, selected_dataset_ids: Sequence[int]):
self._start_update_images()

def delete_computed_image_data(self):
image_ids = self.state.source_image_ids + self.state.transformed_image_ids
for image_id in image_ids:
source_and_transformed = self.state.source_image_ids + self.state.transformed_image_ids
for image_id in source_and_transformed:
delete_state(self.state, image_id)
if image_id in self.context["image_objects"]:
del self.context["image_objects"][image_id]
result_id = image_id_to_result_id(image_id)
delete_state(self.state, result_id)

for dataset_id in self.context.selected_dataset_ids:
delete_image_meta(self.server.state, dataset_id)
Expand All @@ -340,7 +326,8 @@ def delete_computed_image_data(self):
+ self.state.source_image_ids
+ self.state.transformed_image_ids
)
self.delete_annotations(ids_with_annotations)
for id in ids_with_annotations:
delete_state(self.state, image_id_to_result_id(id))

self.state.source_image_ids = []
self.state.transformed_image_ids = []
Expand All @@ -365,16 +352,11 @@ def on_current_dataset_change(self, current_dataset, **kwargs):
if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

def on_feature_extraction_model_change(self, **kwargs):
logger.debug(f">>> on_feature_extraction_model_change change {self.state}")

self.sync_annotations_to_state(self.state.source_image_ids)
self.sync_annotations_to_state(self.state.transformed_image_ids)

def sync_annotations_to_state(self, image_ids):
for image_id in image_ids:
result_id = image_id_to_result_id(image_id)
self.state[result_id] = self.context["annotations"].get(image_id, [])
# No GUI, not tested
# def on_feature_extraction_model_change(self, **kwargs):
# logger.debug(f">>> on_feature_extraction_model_change change {self.state}")
# self.delete_computed_image_data()
# self._start_update_images()

def on_image_hovered(self, id):
self.state.hovered_id = id
Expand Down

0 comments on commit 65e5c8f

Please sign in to comment.