Skip to content

Commit

Permalink
Fix merge skeletons
Browse files Browse the repository at this point in the history
  • Loading branch information
gitttt-1234 committed Dec 20, 2024
1 parent 6061951 commit 12081d7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
58 changes: 33 additions & 25 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,34 +466,42 @@ def _update_from_labels(self, merge: bool = False):
self.videos.extend(list(new_videos))

# Ditto for skeletons
if merge or len(self.skeletons) == 0:

if not self.skeletons:
# if `labels.skeletons` is empty, then add all new skeletons
self.skeletons = list(
set(self.skeletons).union(
{
instance.skeleton
for label in self.labels
for instance in label.instances
}
)
if len(self.skeletons) == 0:
# if `labels.skeletons` is empty, then add all new skeletons
self.skeletons = list(
set(self.skeletons).union(
{
instance.skeleton
for label in self.labels
for instance in label.instances
}
)
)

else:
for lf in self.labels:
for instance in lf.instances:
for skeleton in self.skeletons:
# check if the new skeleton is already in `labels.skeletons`
if not skeleton.matches(instance.skeleton):
self.skeletons.append(instance.skeleton)
else:
# assign the existing skeleton if the instance has duplicate skeleton
instance.skeleton = skeleton

# Ditto for nodes
if merge or len(self.nodes) == 0:
if len(self.nodes) == 0:
self.nodes = list(
set().union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
)
)

if merge:

# remove duplicate skeletons during merge
skeletons = [self.skeletons[0]]
for lf in self.labels:
for instance in lf.instances:
for skeleton in skeletons:
# check if the new skeleton is already in `labels.skeletons`
if not skeleton.matches(instance.skeleton):
skeletons.append(instance.skeleton)
else:
# assign the existing skeleton if the instance has duplicate skeleton
instance.skeleton = skeleton

self.skeletons = skeletons

# updates nodes after removing duplicate skeletons
self.nodes = list(
set().union(
{node for skeleton in self.skeletons for node in skeleton.nodes}
Expand Down
1 change: 1 addition & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,7 @@ def test_dont_unify_skeletons():

skeleton_a = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
skeleton_b = Skeleton.load_json("tests/data/skeleton/fly_skeleton_legs.json")
# skeleton_b.add_node("foo")

lf_a = LabeledFrame(vid, frame_idx=2, instances=[Instance(skeleton_a)])
lf_b = LabeledFrame(vid, frame_idx=3, instances=[Instance(skeleton_b)])
Expand Down

0 comments on commit 12081d7

Please sign in to comment.