diff --git a/project-dataset/src/globals.py b/project-dataset/src/globals.py index 3fcb1fb..83907ff 100644 --- a/project-dataset/src/globals.py +++ b/project-dataset/src/globals.py @@ -43,7 +43,11 @@ # region ui-consants deployed_nn_tags = ["deployed_nn"] inference_modes = ["full image", "sliding window"] -add_predictions_modes = ["merge with existing labels", "replace existing labels"] +add_predictions_modes = [ + "merge with existing labels", + "replace existing labels", + "replace existing labels and save image tags", +] # endregion # region caches diff --git a/project-dataset/src/ui/inference_preview.py b/project-dataset/src/ui/inference_preview.py index 2b6f781..f3a0896 100644 --- a/project-dataset/src/ui/inference_preview.py +++ b/project-dataset/src/ui/inference_preview.py @@ -392,6 +392,7 @@ def apply_model_to_datasets( else: all_image_infos = {image_info.id: image_info for image_info in image_infos} ann_info_batch = [] + add_mode = settings.add_predictions_mode.get_value() for i in inference_session.inference_project_id_async(project_id, dataset_ids): ann_info_batch.append(i) if len(ann_info_batch) < batch_size: @@ -409,7 +410,29 @@ def apply_model_to_datasets( res_ann_infos.append(ann_info) original_anns = None - add_mode = settings.add_predictions_mode.get_value() + if add_mode == "replace existing labels and save image tags": + img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo] + merged_anns = [] + for ann_info in res_ann_infos: + img_info = all_image_infos[ann_info.image_id] + img_infos_dict.setdefault(img_info.dataset_id, []).append(img_info) + + # download original ann infos + for dataset_id, ds_image_infos in img_infos_dict.items(): + original_ann_infos = g.api.annotation.download_batch( + dataset_id, [image_info.id for image_info in ds_image_infos] + ) + original_anns_dict = { + ann_info.image_id: ann_info for ann_info in original_ann_infos + } + for ann_info in res_ann_infos: + original_ann_info = original_anns_dict[ann_info.image_id] + orig_ann = sly.Annotation.from_json( + original_ann_info.annotation, res_project_meta + ) + pred_ann = sly.Annotation.from_json(ann_info.annotation, res_project_meta) + pred_ann = pred_ann.clone(img_tags=orig_ann.img_tags) + merged_anns.append(original_ann_info._replace(annotation=pred_ann.to_json())) if add_mode == "merge with existing labels": img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo] merged_anns = [] @@ -450,7 +473,29 @@ def apply_model_to_datasets( res_ann_infos.append(ann_info) original_anns = None - add_mode = settings.add_predictions_mode.get_value() + if add_mode == "replace existing labels and save image tags": + img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo] + merged_anns = [] + for ann_info in res_ann_infos: + img_info = all_image_infos[ann_info.image_id] + img_infos_dict.setdefault(img_info.dataset_id, []).append(img_info) + + # download original ann infos + for dataset_id, ds_image_infos in img_infos_dict.items(): + original_ann_infos = g.api.annotation.download_batch( + dataset_id, [image_info.id for image_info in ds_image_infos] + ) + original_anns_dict = { + ann_info.image_id: ann_info for ann_info in original_ann_infos + } + for ann_info in res_ann_infos: + original_ann_info = original_anns_dict[ann_info.image_id] + orig_ann = sly.Annotation.from_json( + original_ann_info.annotation, res_project_meta + ) + pred_ann = sly.Annotation.from_json(ann_info.annotation, res_project_meta) + pred_ann = pred_ann.clone(img_tags=orig_ann.img_tags) + merged_anns.append(original_ann_info._replace(annotation=pred_ann.to_json())) if add_mode == "merge with existing labels": img_infos_dict = {} # dataset_id -> image_id -> [ImageInfo] merged_anns = [] @@ -644,6 +689,15 @@ def get_inference_progress(inference_request_uuid): add_mode = settings.add_predictions_mode.get_value() original_anns = None + if add_mode == "replace existing labels and save image tags": + original_anns = g.api.annotation.download_batch(dataset_id, image_ids) + original_anns = [ + sly.Annotation.from_json(ann_info.annotation, g.project_meta) + for ann_info in original_anns + ] + merged_anns = [] + for ann, pred in zip(original_anns, res_anns): + merged_anns.append(pred.clone(img_tags=ann.img_tags)) if add_mode == "merge with existing labels": original_anns = g.api.annotation.download_batch(dataset_id, image_ids) original_anns = [ diff --git a/project-dataset/src/ui/output_data.py b/project-dataset/src/ui/output_data.py index 1bd1074..63b94d7 100644 --- a/project-dataset/src/ui/output_data.py +++ b/project-dataset/src/ui/output_data.py @@ -30,7 +30,9 @@ card.collapse() -def apply_model_ds(src_project, dst_project, inference_settings, res_project_meta): +def apply_model_ds( + src_project, dst_project, inference_settings, res_project_meta, save_imag_tags=False +): import time timer = {}