Skip to content

Commit

Permalink
refactor(images): move module level funcs to class
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Sep 23, 2024
1 parent 3bedfd0 commit 8291cb2
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 126 deletions.
7 changes: 6 additions & 1 deletion src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset, get_image_fpath

from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
Expand Down Expand Up @@ -52,11 +53,15 @@ def __init__(self, server=None):
self.state.current_dataset = str(Path(self.input_paths[0]).resolve())

self.ctrl.get_image_fpath = lambda i: get_image_fpath(i, self.state.current_dataset)
images = Images(server=self.server)

self._transforms_app = TransformsApp(server=self.server.create_child_server())
self._transforms_app = TransformsApp(
server=self.server.create_child_server(), images=images
)

self._embeddings_app = EmbeddingsApp(
server=self.server.create_child_server(),
images=images,
)

filtering_translator = Translator()
Expand Down
15 changes: 10 additions & 5 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
dataset_id_to_image_id,
is_transformed,
)
from nrtk_explorer.app.images.images import get_image
from nrtk_explorer.app.images.images import Images

from pathlib import Path

Expand All @@ -20,10 +20,16 @@


class EmbeddingsApp(Applet):
def __init__(self, server, datasets=None):
def __init__(
self,
server,
datasets=None,
images=None,
):
super().__init__(server)

self._dataset_paths = datasets
self.images = images or Images(server)
self._on_hover_fn = None
self._ui = None
self.reducer = dimension_reducers.DimReducerManager()
Expand Down Expand Up @@ -71,8 +77,7 @@ def on_current_dataset_change(self, **kwargs):
if self.context.dataset is None:
self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True)

self.images = list(self.context.dataset.imgs.values())
self.state.num_elements_max = len(self.images)
self.state.num_elements_max = len(list(self.context.dataset.imgs))
self.state.num_elements_disabled = False

def compute_points(self, fit_features, features):
Expand Down Expand Up @@ -116,7 +121,7 @@ async def compute_source_points(self):
# Don't lock server before enabling the spinner on client
await self.server.network_completion

images = [get_image(id) for id in self.state.dataset_ids]
images = [self.images.get_image(id) for id in self.state.dataset_ids]
self.features = self.extractor.extract(
images,
batch_size=int(self.state.model_batch_size),
Expand Down
1 change: 1 addition & 0 deletions src/nrtk_explorer/app/images/image_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def is_transformed(image_id: str):

def get_image_state_keys(dataset_id: str):
return {
"original_image": dataset_id_to_image_id(dataset_id),
"ground_truth": image_id_to_result_id(dataset_id),
"original_image_detection": image_id_to_result_id(dataset_id_to_image_id(dataset_id)),
"transformed_image": dataset_id_to_transformed_image_id(dataset_id),
Expand Down
46 changes: 0 additions & 46 deletions src/nrtk_explorer/app/images/image_server.py

This file was deleted.

94 changes: 43 additions & 51 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
import io
from functools import lru_cache
from PIL import Image
from trame.app import get_server
from trame.decorators import TrameApp, change, controller
from nrtk_explorer.app.images.image_ids import (
dataset_id_to_image_id,
dataset_id_to_transformed_image_id,
)
from nrtk_explorer.app.images.annotations import DeleteCallbackRef
from nrtk_explorer.app.trame_utils import delete_state
from nrtk_explorer.library.transforms import ImageTransform


IMAGE_CACHE_SIZE = 50
IMAGE_CACHE_SIZE = 200


def convert_to_base64(img: Image.Image) -> str:
Expand All @@ -20,52 +22,42 @@ def convert_to_base64(img: Image.Image) -> str:
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


server = get_server()
state, context, ctrl = server.state, server.context, server.controller


class RefCountedState:
def __init__(self, key, value):
self.key = key
self.value = value

def __del__(self):
delete_state(state, self.key)


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_image(dataset_id: str):
image_path = server.controller.get_image_fpath(int(dataset_id))
image = Image.open(image_path)
return image


@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def get_cached_transformed_image(transform: ImageTransform, dataset_id: str):
key = dataset_id_to_transformed_image_id(dataset_id)

original = get_image(dataset_id)
transformed = transform.execute(original)
if original.size != transformed.size:
# Resize so pixel-wise annotation similarity score works
transformed = transformed.resize(original.size)

state[key] = convert_to_base64(transformed)
return RefCountedState(key, transformed)


def get_transformed_image(transform: ImageTransform, dataset_id: str):
return get_cached_transformed_image(transform, dataset_id).value


@state.change("current_dataset")
def clear_all(**kwargs):
get_image.cache_clear()
get_cached_transformed_image.cache_clear()


def clear_transformed(**kwargs):
get_cached_transformed_image.cache_clear()


ctrl.apply_transform.add(clear_transformed)
@TrameApp()
class Images:
def __init__(self, server):
self.server = server

@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def _get_cached_image(self, dataset_id: str):
image_path = self.server.controller.get_image_fpath(int(dataset_id))
image = Image.open(image_path)
image_id = dataset_id_to_image_id(dataset_id)
self.server.state[image_id] = convert_to_base64(image)
return DeleteCallbackRef(lambda: delete_state(self.server.state, image_id), image)

def get_image(self, dataset_id: str):
return self._get_cached_image(dataset_id).value

@lru_cache(maxsize=IMAGE_CACHE_SIZE)
def _get_cached_transformed_image(self, transform: ImageTransform, dataset_id: str):
original = self.get_image(dataset_id)
transformed = transform.execute(original)
if original.size != transformed.size:
# Resize so pixel-wise annotation similarity score works
transformed = transformed.resize(original.size)

image_id = dataset_id_to_transformed_image_id(dataset_id)
self.server.state[image_id] = convert_to_base64(transformed)
return DeleteCallbackRef(lambda: delete_state(self.server.state, image_id), transformed)

def get_transformed_image(self, transform: ImageTransform, dataset_id: str):
return self._get_cached_transformed_image(transform, dataset_id).value

@change("current_dataset")
def clear_all(self, **kwargs):
self._get_cached_image.cache_clear()
self.clear_transformed()

@controller.add("apply_transform")
def clear_transformed(self, **kwargs):
self._get_cached_transformed_image.cache_clear()
28 changes: 18 additions & 10 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,13 @@
dataset_id_to_transformed_image_id,
)
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.app.images.images import (
get_image,
get_transformed_image,
)
from nrtk_explorer.app.images.images import Images
from nrtk_explorer.app.images.stateful_annotations import (
make_stateful_annotations,
make_stateful_predictor,
)
from nrtk_explorer.app.ui.image_list import TRANSFORM_COLUMNS, ORIGINAL_COLUMNS

import nrtk_explorer.app.images.image_server # noqa module level side effects


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -92,12 +87,15 @@ class TransformsApp(Applet):
def __init__(
self,
server,
images=None,
ground_truth_annotations=None,
original_detection_annotations=None,
transformed_detection_annotations=None,
):
super().__init__(server)

self.images = images or Images(server)

ground_truth_annotations = ground_truth_annotations or make_stateful_annotations(server)
self.ground_truth_annotations = ground_truth_annotations.annotations_factory

Expand Down Expand Up @@ -237,7 +235,7 @@ async def _update_transformed_images(self, dataset_ids):
id_to_matching_size_img = {}
for id in dataset_ids:
with self.state:
transformed = get_transformed_image(transform, id)
transformed = self.images.get_transformed_image(transform, id)
id_to_matching_size_img[dataset_id_to_transformed_image_id(id)] = transformed
await self.server.network_completion

Expand Down Expand Up @@ -279,7 +277,9 @@ async def _update_transformed_images(self, dataset_ids):
)

id_to_image = {
dataset_id_to_transformed_image_id(id): get_transformed_image(transform, id)
dataset_id_to_transformed_image_id(id): self.images.get_transformed_image(
transform, id
)
for id in dataset_ids
}

Expand All @@ -291,7 +291,9 @@ def compute_predictions_original_images(self, dataset_ids):
if not self.state.predictions_original_images_enabled:
return

image_id_to_image = {dataset_id_to_image_id(id): get_image(id) for id in dataset_ids}
image_id_to_image = {
dataset_id_to_image_id(id): self.images.get_image(id) for id in dataset_ids
}
annotations = self.original_detection_annotations.get_annotations(
self.detector, image_id_to_image
)
Expand Down Expand Up @@ -319,8 +321,14 @@ def compute_predictions_original_images(self, dataset_ids):
)

async def _update_images(self, dataset_ids):
# load images on state for ImageList
with self.state:
for id in dataset_ids:
self.images.get_image(id)
await self.server.network_completion

with self.state:
self.ground_truth_annotations.get_annotations(dataset_ids) # updates state
self.ground_truth_annotations.get_annotations(dataset_ids)
await self.server.network_completion

with self.state:
Expand Down
28 changes: 15 additions & 13 deletions src/nrtk_explorer/app/ui/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,25 @@ def column_toggler(old_columns, new_columns):
make_dependent_columns_handler(state, TRANSFORM_COLUMNS)


state.client_only("image_list_ids", "image_size_image_list")
state.client_only("image_size_image_list")


# create reactive annotation variables so ImageDetection component has live Refs
@state.change("dataset_ids")
def init_state(**kwargs):
for id in state.dataset_ids:
def set_image_list_ids(dataset_ids):
# create reactive variables so ImageDetection components have live Refs
for id in dataset_ids:
keys = get_image_state_keys(id)
for key in keys.values():
if not state.has(key):
state[key] = None
state.image_list_ids = dataset_ids


@state.change("dataset_ids", "user_selected_ids")
def update_image_list_ids(**kwargs):
if len(state.user_selected_ids) > 0:
state.image_list_ids = state.user_selected_ids
set_image_list_ids(state.user_selected_ids)
else:
state.image_list_ids = state.dataset_ids
set_image_list_ids(state.dataset_ids)


state.pagination = {}
Expand All @@ -119,7 +119,7 @@ def update_pagination(**kwargs):
state.pagination = {**state.pagination, "rowsPerPage": 12}
ctrl.get_visible_ids()
else:
state.pagination = {**state.pagination, "rowsPerPage": 0}
state.pagination = {**state.pagination, "rowsPerPage": 0} # show all rows


class ImageWithSpinner(html.Div):
Expand Down Expand Up @@ -202,16 +202,18 @@ def __init__(self, on_scroll, on_hover, **kwargs):
r"""image_list_ids.map((id) =>
{
const meta = get(`meta_${id}`)?.value ?? {original_ground_to_original_detection_score: 0, ground_truth_to_transformed_detection_score: 0, original_detection_to_transformed_detection_score: 0}
const original_id = `img_${id}`
const transformed_id = `transformed_img_${id}`
return {
...meta,
original_ground_to_original_detection_score: meta.original_ground_to_original_detection_score.toFixed(2),
ground_truth_to_transformed_detection_score: meta.ground_truth_to_transformed_detection_score.toFixed(2),
original_detection_to_transformed_detection_score: meta.original_detection_to_transformed_detection_score.toFixed(2),
id,
original: `img_${id}`,
original_src: `original-image/${id}`,
transformed: `transformed_img_${id}`,
transformed_src: get(`transformed_img_${id}`).value,
original: original_id,
original_src: get(original_id).value,
transformed: transformed_id,
transformed_src: get(transformed_id).value,
groundTruthAnnotations: get(`result_${id}`),
originalAnnotations: get(`result_img_${id}`),
transformedAnnotations: get(`result_transformed_img_${id}`),
Expand Down Expand Up @@ -350,7 +352,7 @@ def __init__(self, on_scroll, on_hover, **kwargs):
)
ImageWithSpinner(
identifier=("props.row.transformed",),
src=("get(props.row.transformed).value",),
src=("props.row.transformed_src",),
annotations=("props.row.transformedAnnotations",),
categories=("annotation_categories",),
selected=("(props.row.transformed == hovered_id)",),
Expand Down

0 comments on commit 8291cb2

Please sign in to comment.