Skip to content

Commit

Permalink
refactor: remove images_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Aug 21, 2024
1 parent d0cd0ce commit e65f08e
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 144 deletions.
14 changes: 0 additions & 14 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset

Expand Down Expand Up @@ -54,8 +53,6 @@ def __init__(self, server=None):
self.input_paths = known_args.dataset
self.state.current_dataset = str(Path(self.input_paths[0]).resolve())

self.context["images_manager"] = images_manager.ImagesManager()

self.state.collapse_dataset = False
self.state.collapse_embeddings = False
self.state.collapse_filter = False
Expand Down Expand Up @@ -111,7 +108,6 @@ def on_server_ready(self, *args, **kwargs):

def on_dataset_change(self, **kwargs):
# Reset cache
self.context.images_manager = images_manager.ImagesManager()
self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.random_sampling_disabled = False
Expand Down Expand Up @@ -158,16 +154,6 @@ def reload_images(self):
else:
selected_images = images

paths = list()
for image in selected_images:
paths.append(
os.path.join(
os.path.dirname(self.state.current_dataset),
image["file_name"],
)
)

self.context.paths = paths
self.state.dataset_ids = [str(img["id"]) for img in selected_images]

def _build_ui(self):
Expand Down
22 changes: 6 additions & 16 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.app.trame_utils import SetStateAsync
from nrtk_explorer.app.applet import Applet
Expand All @@ -12,6 +11,7 @@
dataset_id_to_image_id,
is_transformed,
)
from nrtk_explorer.app.images.images import get_image

import os

Expand All @@ -36,9 +36,6 @@ def __init__(self, server):

self._ui = None
self.reducer = dimension_reducers.DimReducerManager()
self.is_standalone_app = self.server.state.parent is None
if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

if self.state.current_dataset is None:
self.state.current_dataset = DATASET_DIRS[0]
Expand Down Expand Up @@ -70,7 +67,7 @@ def on_server_ready(self, *args, **kwargs):
def on_feature_extraction_model_change(self, **kwargs):
feature_extraction_model = self.state.feature_extraction_model
self.extractor = embeddings_extractor.EmbeddingsExtractor(
model_name=feature_extraction_model, manager=self.context.images_manager
model_name=feature_extraction_model
)

def on_current_dataset_change(self, **kwargs):
Expand All @@ -82,9 +79,6 @@ def on_current_dataset_change(self, **kwargs):
self.state.num_elements_max = len(self.images)
self.state.num_elements_disabled = False

if self.is_standalone_app:
self.context.images_manager = images_manager.ImagesManager()

def compute_points(self, fit_features, features):
if self.state.tab == "PCA":
return self.reducer.reduce(
Expand Down Expand Up @@ -119,8 +113,9 @@ async def compute_source_points(self):
async with SetStateAsync(self.state):
self.state.is_loading = True

images = [get_image(id) for id in self.state.dataset_ids]
self.features = self.extractor.extract(
paths=self.context.paths,
images,
batch_size=int(self.state.model_batch_size),
)

Expand All @@ -147,19 +142,14 @@ def on_run_clicked(self):
self.update_points()

def on_run_transformations(self, id_to_image):
extractor_prepped = {
id: self.context.images_manager.prepare_for_model(image)
for id, image in id_to_image.items()
}
ids = extractor_prepped.keys()
transformation_features = self.extractor.extract(
paths=ids,
content=extractor_prepped,
id_to_image.values(),
batch_size=int(self.state.model_batch_size),
)

points = self.compute_points(self.features, transformation_features)

ids = id_to_image.keys()
updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)}
self.state.points_transformations = {**self.state.points_transformations, **updated_points}

Expand Down
10 changes: 9 additions & 1 deletion src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Callable, Dict, Sequence
from collections import OrderedDict
import base64
import io
from PIL import Image
from trame.app import get_server
from nrtk_explorer.app.images.image_ids import (
Expand All @@ -10,12 +12,18 @@
from nrtk_explorer.app.images.image_meta import dataset_id_to_meta, update_image_meta
from nrtk_explorer.library.dataset import get_image_path
from nrtk_explorer.app.trame_utils import delete_state, change_checker
from nrtk_explorer.library.images_manager import convert_to_base64
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.library.transforms import ImageTransform
from nrtk_explorer.library.coco_utils import partition


def convert_to_base64(img: Image) -> str:
"""Convert image to base64 string"""
buf = io.BytesIO()
img.save(buf, format="png")
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


class BufferCache:
"""Least recently accessed item is removed when the cache is full."""

Expand Down
36 changes: 15 additions & 21 deletions src/nrtk_explorer/library/embeddings_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
import numpy as np
import timm
import torch
from PIL.Image import Image

from nrtk_explorer.library import images_manager
from torch.utils.data import DataLoader, Dataset

IMAGE_MODEL_RESOLUTION = (224, 224)


def prepare_for_model(img):
"""Prepare image for model input"""


# Create a dataset for images
class ImagesDataset(Dataset):
Expand All @@ -21,10 +27,7 @@ def __getitem__(self, i):


class EmbeddingsExtractor:
def __init__(
self, model_name="resnet50d", manager=images_manager.ImagesManager(), force_cpu=False
):
self.manager = manager
def __init__(self, model_name="resnet50d", force_cpu=False):
self.device = "cuda" if torch.cuda.is_available() and not force_cpu else "cpu"
self.model = model_name

Expand Down Expand Up @@ -53,27 +56,18 @@ def model(self, model_name):
**timm.data.resolve_model_data_config(self._model.pretrained_cfg)
)

def transform_image(self, img):
def transform_image(self, image: Image):
"""Transform image to fit model input size and format"""
img = image.resize(IMAGE_MODEL_RESOLUTION).convert("RGB")
return self._model_transformer(img).unsqueeze(0)

def extract(self, paths, content=None, batch_size=32):
"""Extract features from images in paths"""
if len(paths) == 0:
return None
def extract(self, images, batch_size=32):
"""Extract features from images"""
if len(images) == 0:
return []

features = list()
transformed_images = list()

# Load images and transform them
for path in paths:
img = None
if content and path in content:
img = content[path]
else:
img = self.manager.load_image_for_model(path)

transformed_images.append(self.transform_image(img))
transformed_images = [self.transform_image(img) for img in images]

# Extract features from images
adjusted_batch_size = batch_size
Expand Down
55 changes: 0 additions & 55 deletions src/nrtk_explorer/library/images_manager.py

This file was deleted.

Loading

0 comments on commit e65f08e

Please sign in to comment.