Add nested dataset support #56

Merged: 4 commits, Sep 25, 2024
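This PR makes `apply_model_ds` aware of nested datasets: instead of copying a flat list of datasets, it walks the tree returned by `api.dataset.get_tree()` and recreates each dataset in the destination project with `parent_id` pointing at its parent's new ID. Images are then copied per dataset with `api.image.copy_batch`, presumably because the previously used `api.dataset.copy` only creates top-level datasets. A minimal sketch of the pattern, assuming the Supervisely SDK; `clone_tree` and the project IDs are illustrative, not part of this PR:

import supervisely as sly

api = sly.Api.from_env()

SRC_PROJECT_ID = 111  # placeholder: project that may contain nested datasets
DST_PROJECT_ID = 222  # placeholder: empty destination project

def clone_tree(ds_tree, parent_id=None):
    # get_tree() maps each top-level DatasetInfo to a dict of its child subtrees.
    for ds_info, children in ds_tree.items():
        dst = api.dataset.create(
            DST_PROJECT_ID, ds_info.name, ds_info.description, parent_id=parent_id
        )
        if children:
            clone_tree(children, dst.id)  # children are created under the new dataset

clone_tree(api.dataset.get_tree(SRC_PROJECT_ID))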
2 changes: 1 addition & 1 deletion project-dataset/local.env
@@ -6,4 +6,4 @@
# gera
TEAM_ID=431
WORKSPACE_ID=1019
-PROJECT_ID=38241
+PROJECT_ID=41933
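PROJECT_ID presumably now points the local dev environment at a test project that actually contains nested datasets; TEAM_ID and WORKSPACE_ID are unchanged.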
90 changes: 49 additions & 41 deletions project-dataset/src/ui/output_data.py
@@ -17,6 +17,8 @@
output_project_thumbnail = ProjectThumbnail()
output_project_thumbnail.hide()

+api: sly.Api = g.api
+
card = Card(
    "6️⃣ Output data",
    "New project with predictions will be created. Original project will not be modified.",
@@ -33,6 +35,35 @@
def apply_model_ds(
    src_project, dst_project, inference_settings, res_project_meta, save_imag_tags=False
):
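+    # Helpers for nested datasets: process_ds copies one dataset into the
+    # destination project, process_ds_tree recurses over its children.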
+    def process_ds(src_ds_info, parent_id):
+        t = time.time()
+        dst_dataset_info = api.dataset.create(
+            dst_project.id, src_ds_info.name, src_ds_info.description, parent_id=parent_id
+        )
+        dst_dataset_infos[src_ds_info] = dst_dataset_info
+
+        src_images = api.image.get_list(src_ds_info.id)
+        src_ds_image_infos_dict[src_ds_info.id] = {
+            image_info.id: image_info for image_info in src_images
+        }
+        src_image_ids = [image.id for image in src_images]
+        if len(src_image_ids) > 0:
+            dst_img_infos = api.image.copy_batch(dst_dataset_info.id, src_image_ids)
+        else:
+            dst_img_infos = []
+        for image_info in dst_img_infos:
+            dst_image_infos_dict.setdefault(dst_dataset_info.id, {})[image_info.name] = image_info
+        timer.setdefault(src_ds_info.id, {})["copy"] = time.time() - t
+
+        pbar.update(1)
+        return dst_dataset_info.id
+
+    def process_ds_tree(ds_tree, parent_id=None):
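+        # Depth-first walk: each dataset is created before its children,
+        # so parent_id always refers to an existing dataset.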
+        for ds_info, children in ds_tree.items():
+            current_ds_id = process_ds(ds_info, parent_id)
+            if children:
+                process_ds_tree(children, current_ds_id)
+
    import time

    timer = {}
@@ -46,35 +77,12 @@ def apply_model_ds(
        with inference_progress(
            message="Creating datasets...", total=len(selected_datasets)
        ) as pbar:
-
-            src_dataset_infos = g.api.dataset.get_list(src_project)
-            for src_dataset_info in src_dataset_infos:
-                t = time.time()
-                dst_dataset_info = g.api.dataset.copy(
-                    dst_project_id=dst_project.id,
-                    id=src_dataset_info.id,
-                    new_name=src_dataset_info.name,
-                )
-                dst_dataset_infos[src_dataset_info] = dst_dataset_info
-                timer.setdefault(src_dataset_info.id, {})["copy"] = time.time() - t
-                t = time.time()
-                for image_info in g.api.image.get_list(dst_dataset_info.id):
-                    dst_image_infos_dict.setdefault(dst_dataset_info.id, {})[
-                        image_info.name
-                    ] = image_info
-                timer.setdefault(src_dataset_info.id, {})["dst_image_infos"] = time.time() - t
-                t = time.time()
-                src_ds_image_infos_dict[src_dataset_info.id] = {
-                    image_info.id: image_info
-                    for image_info in g.api.image.get_list(src_dataset_info.id)
-                }
-                timer.setdefault(src_dataset_info.id, {})["src_image_infos"] = time.time() - t
-
-                pbar.update(1)
-
+            src_ds_tree = api.dataset.get_tree(src_project)
+            src_ds_tree = {k: v for k, v in src_ds_tree.items() if k.id in selected_datasets}
+            process_ds_tree(src_ds_tree)
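+            # The selected_datasets filter applies to top-level datasets only;
+            # children of a selected dataset are always processed.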
        # 2. Apply model to the datasets
        with inference_progress(message="Processing images...", total=len(g.input_images)) as pbar:
-            for src_dataset_info in src_dataset_infos:
+            for src_dataset_info in list(dst_dataset_infos.keys()):
                # iterating over batches of predictions
                t = time.time()
                for (
@@ -101,7 +109,7 @@ def apply_model_ds(
                    # Update project meta if needed
                    if res_project_meta != final_project_meta:
                        res_project_meta = final_project_meta
-                        g.api.project.update_meta(dst_project.id, res_project_meta.to_json())
+                        api.project.update_meta(dst_project.id, res_project_meta.to_json())
                    timer.setdefault(src_dataset_info.id, {}).setdefault("update_meta", 0)
                    timer[src_dataset_info.id]["update_meta"] += time.time() - t
                    t = time.time()
@@ -123,14 +131,14 @@ def apply_model_ds(
                    t = time.time()
                    # upload_annotations
                    try:
-                        g.api.annotation.upload_anns(
+                        api.annotation.upload_anns(
                            [image_info.id for image_info in dst_image_infos], dst_anns
                        )
                        pbar.update(len(dst_anns))
                    except:
                        for img_info, ann in zip(dst_image_infos, dst_anns):
                            try:
-                                g.api.annotation.upload_ann(img_info.id, ann)
+                                api.annotation.upload_ann(img_info.id, ann)
                            except Exception as e:
                                sly.logger.warn(
                                    msg=f"Image: {img_info.name} (Image ID: {img_info.id}) couldn't be uploaded, image will be skipped, error: {e}.",
@@ -149,7 +157,7 @@ def apply_model_ds(
                    timer[src_dataset_info.id]["upload_anns"] += time.time() - t
                    t = time.time()
    except Exception:
-        g.api.dataset.remove_batch([ds.id for ds in dst_dataset_infos.values()])
+        api.dataset.remove_batch([ds.id for ds in dst_dataset_infos.values()])
        raise
    finally:
        sly.logger.debug("Timer:", extra={"timer": timer})
@@ -170,10 +178,10 @@ def apply_model():
    )

    res_project_meta = g.project_meta.clone()
-    res_project = g.api.project.create(
+    res_project = api.project.create(
        g.workspace_id, output_project_name.get_value(), change_name_if_conflict=True
    )
-    g.api.project.update_meta(res_project.id, res_project_meta.to_json())
+    api.project.update_meta(res_project.id, res_project_meta.to_json())

    # -------------------------------------- Add Workflow Input -------------------------------------- #
    g.workflow.add_input(project_id=g.selected_project, session_id=g.model_session_id)
@@ -189,11 +197,11 @@ def apply_model():

    with inference_progress(message="Processing images...", total=len(g.input_images)) as pbar:
        for dataset_id in g.selected_datasets:
-            dataset_info = g.api.dataset.get_info_by_id(dataset_id)
-            res_dataset = g.api.dataset.create(
+            dataset_info = api.dataset.get_info_by_id(dataset_id)
+            res_dataset = api.dataset.create(
                res_project.id, dataset_info.name, dataset_info.description
            )
-            image_infos = g.api.image.get_list(dataset_info.id)
+            image_infos = api.image.get_list(dataset_info.id)

            for batched_image_infos in sly.batched(image_infos, batch_size=10):
                try:
@@ -236,18 +244,18 @@ def apply_model():

                    if res_project_meta != final_project_meta:
                        res_project_meta = final_project_meta
-                        g.api.project.update_meta(res_project.id, res_project_meta.to_json())
+                        api.project.update_meta(res_project.id, res_project_meta.to_json())

-                    res_images_infos = g.api.image.upload_ids(
+                    res_images_infos = api.image.upload_ids(
                        res_dataset.id, res_names, image_ids, metas=res_metas
                    )
                    res_ids = [image_info.id for image_info in res_images_infos]
                    try:
-                        g.api.annotation.upload_anns(res_ids, res_anns)
+                        api.annotation.upload_anns(res_ids, res_anns)
                    except:
                        for res_img_info, ann in zip(res_images_infos, res_anns):
                            try:
-                                g.api.annotation.upload_ann(res_img_info.id, ann)
+                                api.annotation.upload_ann(res_img_info.id, ann)
                            except Exception as e:
                                sly.logger.warn(
                                    msg=f"Image: {res_img_info.name} (Image ID: {res_img_info.id}) couldn't be uploaded, image will be skipped, error: {e}.",
@@ -260,7 +268,7 @@ def apply_model():
                                )
                                continue
                pbar.update(len(batched_image_infos))
-    output_project_thumbnail.set(g.api.project.get_info_by_id(res_project.id))
+    output_project_thumbnail.set(api.project.get_info_by_id(res_project.id))
    output_project_thumbnail.show()
    # -------------------------------------- Add Workflow Output ------------------------------------- #
    g.workflow.add_output(project_id=res_project.id)
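Both upload paths above share the same fallback strategy: try one batched `upload_anns` call, and if it fails, retry image-by-image so a single bad annotation only skips its own image. A standalone sketch of that pattern, assuming the Supervisely SDK (`upload_with_fallback` is an illustrative name, not part of this PR):

import supervisely as sly

api = sly.Api.from_env()

def upload_with_fallback(image_infos, anns):
    # Fast path: one batched request for the whole chunk.
    try:
        api.annotation.upload_anns([info.id for info in image_infos], anns)
    except Exception:
        # Slow path: per-image retry; a failing annotation skips only its image.
        for info, ann in zip(image_infos, anns):
            try:
                api.annotation.upload_ann(info.id, ann)
            except Exception as e:
                sly.logger.warn(f"Image {info.name} (ID: {info.id}) skipped: {e}")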