add obj detect model ui #49

Merged · 3 commits · Apr 5, 2024
Changes from all commits
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
@@ -9,9 +9,9 @@ jobs:
     name: ubuntu-latest-linters-python
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
       - name: Install dependencies
@@ -32,7 +32,7 @@ jobs:
       run:
         working-directory: vue-components
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: actions/setup-node@v4
         with:
           node-version: 18
@@ -53,14 +53,14 @@ jobs:
       matrix:
         python-version: ["3.9", "3.10", "3.11"]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install -e '.[dev]'
       - name: Invoke PyTest
-        run: pytest .
+        run: pytest -v .
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -32,12 +32,14 @@ dependencies = [
     "Pillow",
     "scikit-learn==1.4.1.post1",
     "smqtk-classifier==0.19.0",
+    "accelerate",
     "smqtk-core==0.19.0",
     "smqtk-dataprovider==0.18.0",
     "smqtk-descriptors==0.19.0",
     "smqtk-detection[torch,centernet]==0.20.1",
     "smqtk-image-io==0.17.1",
     "tabulate",
+    "transformers",
     "timm",
     "torch",
     "torchvision",
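The newly added `transformers` and `accelerate` entries (together with the existing `timm`/`torch` stack) back the Hugging Face object-detection model wired into the UI below. A rough, illustrative sketch of the kind of pipeline these packages enable; the `hustvl/yolos-tiny` checkpoint name comes from the transforms.py change in this PR, while the script itself and the `example.jpg` path are assumptions, not code from the repository:

```python
# Illustrative only: a plain Hugging Face object-detection pipeline, similar in
# spirit to what nrtk_explorer.library.object_detector wraps in this PR.
from transformers import pipeline
from PIL import Image

detector = pipeline("object-detection", model="hustvl/yolos-tiny")

image = Image.open("example.jpg")  # hypothetical input image
predictions = detector(image)

# Each prediction looks like:
# {"score": 0.97, "label": "car", "box": {"xmin": 10, "ymin": 20, "xmax": 110, "ymax": 90}}
for pred in predictions:
    print(pred["label"], round(pred["score"], 2), pred["box"])
```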
12 changes: 8 additions & 4 deletions src/nrtk_explorer/app/core.py
@@ -37,9 +37,8 @@

 DIR_NAME = os.path.dirname(nrtk_explorer.test_data.__file__)
 DEFAULT_DATASETS = [
-    f"{DIR_NAME}/coco-od-2017/mini_val2017.json",
-    f"{DIR_NAME}/OIRDS_v1_0/oirds.json",
     f"{DIR_NAME}/OIRDS_v1_0/oirds_test.json",
+    f"{DIR_NAME}/OIRDS_v1_0/oirds_train.json",
 ]

@@ -97,14 +96,18 @@ def __init__(self, server=None):
         self.state.client_only("horizontal_split", "vertical_split")

         transforms_translator = Translator()
-        transforms_translator.add_translation("current_model", "current_transforms_model")
+        transforms_translator.add_translation(
+            "feature_extraction_model", "current_transforms_model"
+        )

         self._transforms_app = TransformsApp(
             server=self.server.create_child_server(translator=transforms_translator)
         )

         embeddings_translator = Translator()
-        embeddings_translator.add_translation("current_model", "current_embeddings_model")
+        embeddings_translator.add_translation(
+            "feature_extraction_model", "current_embeddings_model"
+        )

         self._embeddings_app = EmbeddingsApp(
             server=self.server.create_child_server(translator=embeddings_translator),
@@ -319,6 +322,7 @@ def ui(self, *args, **kwargs):
                 with html.Template(v_slot_after=True):
                     with quasar.QSplitter(
                         v_model=("vertical_split",),
+                        limits=("[0,100]",),
                         horizontal=True,
                         classes="inherit-height zero-height",
                         before_class="q-pa-md",
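Both child apps now share the single state name `feature_extraction_model`; the `Translator` calls above map that shared name onto `current_transforms_model` and `current_embeddings_model` in the parent server, so the two selections stay independent. A simplified, self-contained stand-in for the key-translation idea (this is not trame's `Translator` implementation, just a sketch of the concept):

```python
# Conceptual sketch only: map one logical state key onto distinct parent keys,
# mimicking what the Translator wiring above achieves for the child servers.
class KeyTranslator:
    def __init__(self):
        self._map = {}

    def add_translation(self, key, translated_key):
        self._map[key] = translated_key

    def translate(self, key):
        return self._map.get(key, key)


parent_state = {}

transforms_translator = KeyTranslator()
transforms_translator.add_translation("feature_extraction_model", "current_transforms_model")

embeddings_translator = KeyTranslator()
embeddings_translator.add_translation("feature_extraction_model", "current_embeddings_model")

# Each child app writes the same logical key ...
parent_state[transforms_translator.translate("feature_extraction_model")] = "model-a"
parent_state[embeddings_translator.translate("feature_extraction_model")] = "model-b"

# ... but the values land under separate keys in the parent state.
print(parent_state)
# {'current_transforms_model': 'model-a', 'current_embeddings_model': 'model-b'}
```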
14 changes: 7 additions & 7 deletions src/nrtk_explorer/app/embeddings.py
@@ -40,22 +40,22 @@ def __init__(self, server):
         self.features = None

         self.state.client_only("camera_position")
-        self.state.current_model = "resnet50.a1_in1k"
+        self.state.feature_extraction_model = "resnet50.a1_in1k"

         self.server.controller.add("on_server_ready")(self.on_server_ready)
         self.transformed_images_cache = {}

     def on_server_ready(self, *args, **kwargs):
         # Bind instance methods to state change
         self.on_current_dataset_change()
-        self.on_current_model_change()
+        self.on_feature_extraction_model_change()
         self.state.change("current_dataset")(self.on_current_dataset_change)
-        self.state.change("current_model")(self.on_current_model_change)
+        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

-    def on_current_model_change(self, **kwargs):
-        current_model = self.state.current_model
+    def on_feature_extraction_model_change(self, **kwargs):
+        feature_extraction_model = self.state.feature_extraction_model
         self.extractor = embeddings_extractor.EmbeddingsExtractor(
-            model_name=current_model, manager=self.context.images_manager
+            model_name=feature_extraction_model, manager=self.context.images_manager
         )

     def on_current_dataset_change(self, **kwargs):
@@ -228,7 +228,7 @@ def settings_widget(self):

             quasar.QSelect(
                 label="Embeddings Model",
-                v_model=("current_model",),
+                v_model=("feature_extraction_model",),
                 options=(
                     [
                         {"label": "ResNet50", "value": "resnet50.a1_in1k"},
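The renamed `feature_extraction_model` value (defaulting to `resnet50.a1_in1k` above) is what `EmbeddingsExtractor` receives. The extractor's implementation is not part of this diff; as a hedged sketch, pulling pooled features from a timm backbone with that model name could look roughly like this (the `example.jpg` path is hypothetical):

```python
# Rough sketch, not nrtk_explorer's EmbeddingsExtractor: load a timm backbone by
# name and pull a pooled feature vector for one image.
import timm
import torch
from PIL import Image

model_name = "resnet50.a1_in1k"
model = timm.create_model(model_name, pretrained=True, num_classes=0)  # num_classes=0 -> features
model.eval()

# Build the preprocessing transform that matches the checkpoint (recent timm API).
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
with torch.no_grad():
    features = model(transform(image).unsqueeze(0))

print(features.shape)  # e.g. torch.Size([1, 2048]) for resnet50
```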
88 changes: 64 additions & 24 deletions src/nrtk_explorer/app/transforms.py
@@ -12,7 +12,7 @@

 import nrtk_explorer.library.transforms as trans
 import nrtk_explorer.library.nrtk_transforms as nrtk_trans
-from nrtk_explorer.library import images_manager
+from nrtk_explorer.library import images_manager, object_detector
 from nrtk_explorer.app.ui.image_list import image_list_component
 from nrtk_explorer.app.applet import Applet
 from nrtk_explorer.app.parameters import ParametersApp
@@ -79,7 +79,7 @@ def __init__(self, server):

         self._on_transform_fn = None
         self.state.models = [k for k in self.models.keys()]
-        self.state.current_model = self.state.models[0]
+        self.state.feature_extraction_model = self.state.models[0]

         self._transforms: Dict[str, trans.ImageTransform] = {
             "identity": trans.IdentityTransform(),
@@ -106,10 +106,11 @@

         self.server.controller.add("on_server_ready")(self.on_server_ready)
         self._on_hover_fn = None
+        self.detector = object_detector.ObjectDetector(model_name="hustvl/yolos-tiny")

     def on_server_ready(self, *args, **kwargs):
         # Bind instance methods to state change
-        self.state.change("current_model")(self.on_current_model_change)
+        self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
         self.state.change("current_dataset")(self.on_current_dataset_change)
         self.state.change("current_num_elements")(self.on_current_num_elements_change)

@@ -128,7 +129,6 @@ def on_apply_transform(self, *args, **kwargs):
         current_transform = self.state.current_transform
         transformed_image_ids = []
         transform = self._transforms[current_transform]
-
         for image_id in self.state.source_image_ids:
             image = self.context["image_objects"][image_id]

@@ -148,13 +148,46 @@
             self.state.hovered_id = -1

         self.state.transformed_image_ids = transformed_image_ids
-
-        self.update_model_result(self.state.transformed_image_ids, self.state.current_model)
+        self.compute_annotations(transformed_image_ids)

         # Only invoke callbacks when we transform images
         if len(transformed_image_ids) > 0:
             self.on_transform(transformed_image_ids)

+    def compute_annotations(self, ids):
+        """Compute annotations for the given image ids using the object detector model."""
+        if len(ids) == 0:
+            return
+
+        for id_ in ids:
+            self.context["annotations"][id_] = []
+
+        prediction = self.detector.eval(paths=ids, content=self.context.image_objects)
+
+        for id_, annotations in zip(ids, prediction):
+            image_annotations = self.context["annotations"].setdefault(id_, [])
+            for prediction in annotations:
+                category_id = 0
+                for cat_id, cat in self.state.annotation_categories.items():
+                    if cat["name"] == prediction["label"]:
+                        category_id = cat_id
+
+                bbox = prediction["box"]
+                image_annotations.append(
+                    {
+                        "category_id": category_id,
+                        "id": category_id,
+                        "bbox": [
+                            bbox["xmin"],
+                            bbox["ymin"],
+                            bbox["xmax"] - bbox["xmin"],
+                            bbox["ymax"] - bbox["ymin"],
+                        ],
+                    }
+                )
+
+        self.update_model_result(ids, self.state.feature_extraction_model)
+
     def on_current_num_elements_change(self, current_num_elements, **kwargs):
         with open(self.state.current_dataset) as f:
             dataset = json.load(f)
@@ -183,7 +216,7 @@ def on_selected_images_change(self, selected_ids):

             image_filename = os.path.join(current_dir, image_metadata["file_name"])

-            img = self.context.images_manager.load_thumbnail(image_filename)
+            img = self.context.images_manager.load_image(image_filename)

             self.state[image_id] = images_manager.convert_to_base64(img)
             self.state[meta_id] = {
@@ -195,8 +228,8 @@ def on_selected_images_change(self, selected_ids):
             self.context.image_objects[image_id] = img

         self.state.source_image_ids = source_image_ids
-
-        self.update_model_result(self.state.source_image_ids, self.state.current_model)
+        self.compute_annotations(source_image_ids)
+        self.update_model_result(self.state.source_image_ids, self.state.feature_extraction_model)
         self.on_apply_transform()

     def reset_data(self):
@@ -260,27 +293,18 @@ def on_current_dataset_change(self, current_dataset, **kwargs):
         for i, image in enumerate(dataset["images"]):
             self.context.image_id_to_index[image["id"]] = i

-        for annotation in dataset["annotations"]:
-            image_id = f"img_{annotation['image_id']}"
-            image_annotations = self.context["annotations"].setdefault(image_id, [])
-            image_annotations.append(annotation)
-
-            transformed_image_id = f"transformed_{image_id}"
-            image_annotations = self.context["annotations"].setdefault(transformed_image_id, [])
-            image_annotations.append(annotation)
-
         if self.is_standalone_app:
             self.context.images_manager = images_manager.ImagesManager()

-    def on_current_model_change(self, **kwargs):
-        logger.info(f">>> ENGINE(a): on_current_model_change change {self.state}")
+    def on_feature_extraction_model_change(self, **kwargs):
+        logger.info(f">>> ENGINE(a): on_feature_extraction_model_change change {self.state}")

-        current_model = self.state.current_model
+        feature_extraction_model = self.state.feature_extraction_model

-        self.update_model_result(self.state.source_image_ids, current_model)
-        self.update_model_result(self.state.transformed_image_ids, current_model)
+        self.update_model_result(self.state.source_image_ids, feature_extraction_model)
+        self.update_model_result(self.state.transformed_image_ids, feature_extraction_model)

-    def update_model_result(self, image_ids, current_model):
+    def update_model_result(self, image_ids, feature_extraction_model):
         for image_id in image_ids:
             result_id = image_id_to_result(image_id)
             self.state[result_id] = self.context["annotations"].get(image_id, [])
@@ -301,6 +325,22 @@ def on_hover(self, hover_event):
     def settings_widget(self):
         with html.Div(trame_server=self.server):
             with html.Div(classes="col"):
+                quasar.QSelect(
+                    label="Object detection Model",
+                    v_model=("object_detection_model", "facebook/detr-resnet-50"),
+                    options=(
+                        [
+                            {
+                                "label": "facebook/detr-resnet-50",
+                                "value": "facebook/detr-resnet-50",
+                            },
+                        ],
+                    ),
+                    filled=True,
+                    emit_value=True,
+                    map_options=True,
+                )
+
                 self._parameters_app.transform_select_ui()

                 with html.Div(
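compute_annotations above consumes detector results as dicts carrying a label and a corner-based box, and converts each box into a COCO-style bbox of [x, y, width, height] keyed to a category id. A standalone illustration of that conversion with made-up values (the category id 3 is hypothetical):

```python
# Standalone illustration of the box conversion done in compute_annotations:
# corner coordinates (xmin/ymin/xmax/ymax) become a COCO-style [x, y, width, height] bbox.
prediction = {
    "label": "car",
    "score": 0.92,
    "box": {"xmin": 34, "ymin": 120, "xmax": 214, "ymax": 255},
}

bbox = prediction["box"]
annotation = {
    "category_id": 3,  # hypothetical id, looked up by label in annotation_categories
    "id": 3,
    "bbox": [
        bbox["xmin"],                 # x
        bbox["ymin"],                 # y
        bbox["xmax"] - bbox["xmin"],  # width
        bbox["ymax"] - bbox["ymin"],  # height
    ],
}

print(annotation)
# {'category_id': 3, 'id': 3, 'bbox': [34, 120, 180, 135]}
```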
5 changes: 3 additions & 2 deletions src/nrtk_explorer/library/images_manager.py
@@ -2,6 +2,7 @@
 from PIL import Image as ImageModule

 import base64
+import copy
 import io

 # Resolution for images to be used in model
@@ -39,15 +40,15 @@ def load_image(self, path):
     def load_image_for_model(self, path):
         """Load image for model from path and store it in cache if not already loaded"""
         if path not in self.images_for_model:
-            img = self.load_thumbnail(path)
+            img = copy.copy(self.load_image(path))
             self.images_for_model[path] = self.prepare_for_model(img)

         return self.images_for_model[path]

     def load_thumbnail(self, path):
         """Load thumbnail from path and store it in cache if not already loaded"""
         if path not in self.thumbnails:
-            img = self.load_image(path)
+            img = copy.copy(self.load_image(path))
             img.thumbnail(THUMBNAIL_RESOLUTION)
             self.thumbnails[path] = img

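The `copy.copy` calls matter because `PIL.Image.Image.thumbnail()` resizes the image in place and returns `None`; without a copy, generating a thumbnail would also shrink the cached full-resolution image shared with the model path. A minimal demonstration (the file name is illustrative):

```python
# Why the copy is needed: PIL's thumbnail() mutates the image in place.
import copy

from PIL import Image

THUMBNAIL_RESOLUTION = (64, 64)  # illustrative value; see images_manager.py

original = Image.open("example.jpg")   # hypothetical path
thumb = copy.copy(original)            # shallow copy of the Image object
thumb.thumbnail(THUMBNAIL_RESOLUTION)  # resizes thumb in place, returns None

print(original.size)  # unchanged, full resolution
print(thumb.size)     # at most 64x64
```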