Skip to content

Commit

Permalink
fix(embeddings): add transformed img point as computed
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Aug 6, 2024
1 parent 81d858b commit da0bb37
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 21 deletions.
20 changes: 12 additions & 8 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,19 @@ def __init__(self, server):
"is_transformed": True,
}

self.state.highlighted_image_id = ""
self.state.highlighted_image_is_transformed = False

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.on_current_dataset_change()
self.on_feature_extraction_model_change()
self.state.change("current_dataset")(self.on_current_dataset_change)

self.on_feature_extraction_model_change()
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

self.update_points()
self.state.change("dataset_ids")(self.update_points)

self.server.controller.apply_transform.add(self.clear_points_transformations)

def on_feature_extraction_model_change(self, **kwargs):
feature_extraction_model = self.state.feature_extraction_model
self.extractor = embeddings_extractor.EmbeddingsExtractor(
Expand Down Expand Up @@ -112,6 +113,9 @@ def compute_points(self, fit_features, features):
**args,
)

def clear_points_transformations(self, **kwargs):
self.state.points_transformations = {} # ID to points

async def compute_source_points(self):
async with SetStateAsync(self.state):
self.state.is_loading = True
Expand All @@ -127,7 +131,8 @@ async def compute_source_points(self):
id: point for id, point in zip(self.state.dataset_ids, points)
}

self.state.points_transformations = {} # ID to points
self.clear_points_transformations()

self.state.user_selected_ids = []
self.state.camera_position = []

Expand Down Expand Up @@ -156,9 +161,8 @@ def on_run_transformations(self, id_to_image):

points = self.compute_points(self.features, transformation_features)

self.state.points_transformations = {
image_id_to_dataset_id(id): point for id, point in zip(ids, points)
}
updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)}
self.state.points_transformations = {**self.state.points_transformations, **updated_points}

def on_select(self, image_ids):
self.state.user_selected_ids = image_ids
Expand Down
5 changes: 4 additions & 1 deletion src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def init_state(old, new):
state.hovered_id = None


def clear_transformed():
def clear_transformed(**kwargs):
for id in state.dataset_ids:
keys = get_image_state_keys(id)
image_cache.clear_item(keys["transformed_image"])
Expand All @@ -225,3 +225,6 @@ def clear_transformed():
"ground_truth_to_transformed_detection_score": 0,
},
)


ctrl.apply_transform.add(clear_transformed)
4 changes: 1 addition & 3 deletions src/nrtk_explorer/app/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def __init__(self, server):
self.state.transforms = [k for k in self._transforms.keys()]
self.state.current_transform = self.state.transforms[0]

self.on_apply_transform = lambda: None

self.server.controller.add("on_server_ready")(self.on_server_ready)

self._ui = None
Expand Down Expand Up @@ -76,7 +74,7 @@ def transform_apply_ui(self):
with html.Div(trame_server=self.server):
quasar.QBtn(
"Apply",
click=(self.on_apply_transform,),
click=(self.server.controller.apply_transform),
classes="full-width",
flat=True,
)
Expand Down
12 changes: 3 additions & 9 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def __init__(self, server):
server=server,
)

self._parameters_app.on_apply_transform = self.on_apply_transform

self._ui = None

self._on_transform_fn = None
Expand Down Expand Up @@ -89,6 +87,7 @@ def transformed_became_visible(old, new):
)

self.server.controller.add("on_server_ready")(self.on_server_ready)
self.server.controller.apply_transform.add(self.schedule_transformed_images)
self._on_hover_fn = None

def on_server_ready(self, *args, **kwargs):
Expand All @@ -106,13 +105,8 @@ def on_transform(self, *args, **kwargs):
if self._on_transform_fn:
self._on_transform_fn(*args, **kwargs)

def on_apply_transform(self, *args, **kwargs):
"""Parameters changed"""
logger.debug("on_apply_transform")
clear_transformed()
self.schedule_transformed_images()

def schedule_transformed_images(self, *args):
def schedule_transformed_images(self, *args, **kwargs):
logger.debug("schedule_transformed_images")
if self._updating_images():
if self._updating_transformed_images:
# computing stale transformed images, restart task
Expand Down

0 comments on commit da0bb37

Please sign in to comment.