Skip to content

Commit

Permalink
Add option to save img_tags
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolaiPetukhov committed Aug 14, 2024
1 parent ad0364a commit 0688c96
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 4 deletions.
6 changes: 5 additions & 1 deletion project-dataset/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
# region ui-consants
deployed_nn_tags = ["deployed_nn"]
inference_modes = ["full image", "sliding window"]
add_predictions_modes = ["merge with existing labels", "replace existing labels"]
add_predictions_modes = [
"merge with existing labels",
"replace existing labels",
"replace existing labels and save image tags",
]
# endregion

# region caches
Expand Down
58 changes: 56 additions & 2 deletions project-dataset/src/ui/inference_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def apply_model_to_datasets(
else:
all_image_infos = {image_info.id: image_info for image_info in image_infos}
ann_info_batch = []
add_mode = settings.add_predictions_mode.get_value()
for i in inference_session.inference_project_id_async(project_id, dataset_ids):
ann_info_batch.append(i)
if len(ann_info_batch) < batch_size:
Expand All @@ -409,7 +410,29 @@ def apply_model_to_datasets(
res_ann_infos.append(ann_info)

original_anns = None
add_mode = settings.add_predictions_mode.get_value()
if add_mode == "replace existing labels and save image tags":
img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo]
merged_anns = []
for ann_info in res_ann_infos:
img_info = all_image_infos[ann_info.image_id]
img_infos_dict.setdefault(img_info.dataset_id, []).append(img_info)

# download original ann infos
for dataset_id, ds_image_infos in img_infos_dict.items():
original_ann_infos = g.api.annotation.download_batch(
dataset_id, [image_info.id for image_info in ds_image_infos]
)
original_anns_dict = {
ann_info.image_id: ann_info for ann_info in original_ann_infos
}
for ann_info in res_ann_infos:
original_ann_info = original_anns_dict[ann_info.image_id]
orig_ann = sly.Annotation.from_json(
original_ann_info.annotation, res_project_meta
)
pred_ann = sly.Annotation.from_json(ann_info.annotation, res_project_meta)
pred_ann = pred_ann.clone(img_tags=orig_ann.img_tags)
merged_anns.append(original_ann_info._replace(annotation=pred_ann.to_json()))
if add_mode == "merge with existing labels":
img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo]
merged_anns = []
Expand Down Expand Up @@ -450,7 +473,29 @@ def apply_model_to_datasets(
res_ann_infos.append(ann_info)

original_anns = None
add_mode = settings.add_predictions_mode.get_value()
if add_mode == "replace existing labels and save image tags":
img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo]
merged_anns = []
for ann_info in res_ann_infos:
img_info = all_image_infos[ann_info.image_id]
img_infos_dict.setdefault(img_info.dataset_id, []).append(img_info)

# download original ann infos
for dataset_id, ds_image_infos in img_infos_dict.items():
original_ann_infos = g.api.annotation.download_batch(
dataset_id, [image_info.id for image_info in ds_image_infos]
)
original_anns_dict = {
ann_info.image_id: ann_info for ann_info in original_ann_infos
}
for ann_info in res_ann_infos:
original_ann_info = original_anns_dict[ann_info.image_id]
orig_ann = sly.Annotation.from_json(
original_ann_info.annotation, res_project_meta
)
pred_ann = sly.Annotation.from_json(ann_info.annotation, res_project_meta)
pred_ann = pred_ann.clone(img_tags=orig_ann.img_tags)
merged_anns.append(original_ann_info._replace(annotation=pred_ann.to_json()))
if add_mode == "merge with existing labels":
img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo]
merged_anns = []
Expand Down Expand Up @@ -644,6 +689,15 @@ def get_inference_progress(inference_request_uuid):
add_mode = settings.add_predictions_mode.get_value()

original_anns = None
if add_mode == "replace existing labels and save image tags":
original_anns = g.api.annotation.download_batch(dataset_id, image_ids)
original_anns = [
sly.Annotation.from_json(ann_info.annotation, g.project_meta)
for ann_info in original_anns
]
merged_anns = []
for ann, pred in zip(original_anns, res_anns):
merged_anns.append(pred.clone(img_tags=ann.img_tags))
if add_mode == "merge with existing labels":
original_anns = g.api.annotation.download_batch(dataset_id, image_ids)
original_anns = [
Expand Down
4 changes: 3 additions & 1 deletion project-dataset/src/ui/output_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
card.collapse()


def apply_model_ds(src_project, dst_project, inference_settings, res_project_meta):
def apply_model_ds(
src_project, dst_project, inference_settings, res_project_meta, save_imag_tags=False
):
import time

timer = {}
Expand Down

0 comments on commit 0688c96

Please sign in to comment.