Skip to content

Commit

Permalink
Merge pull request #191 from Visual-Behavior/fix_clone
Browse files Browse the repository at this point in the history
fix: AugmentedTensor clone method
  • Loading branch information
thibo73800 authored Jul 5, 2022
2 parents 4b983be + 9160c67 commit 91db508
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions aloscene/tensors/augmented_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import numpy as np
from typing import *
import copy


class AugmentedTensor(torch.Tensor):
Expand Down Expand Up @@ -148,11 +149,7 @@ def __check_child_name_alignment(var):
elif not any(v in self.COMMON_DIM_NAMES for v in var.names):
return True
for dim_id, dim_name in enumerate(var.names):
if (
dim_id < len(self.names)
and dim_name in self.COMMON_DIM_NAMES
and self.names[dim_id] == dim_name
):
if dim_id < len(self.names) and dim_name in self.COMMON_DIM_NAMES and self.names[dim_id] == dim_name:
return True
raise Exception(
f"Impossible to align label name dim ({var.names}) with the tensor name dim ({self.names})."
Expand Down Expand Up @@ -319,9 +316,9 @@ def __setattr__(self, key, value):

def clone(self, *args, **kwargs):
n_frame = super().clone(*args, **kwargs)
n_frame._property_list = self._property_list
n_frame._children_list = self._children_list
n_frame._child_property = self._child_property
n_frame._property_list = copy.deepcopy(self._property_list)
n_frame._children_list = copy.deepcopy(self._children_list)
n_frame._child_property = copy.deepcopy(self._child_property)

for name in self._property_list:
setattr(n_frame, name, getattr(self, name))
Expand Down Expand Up @@ -393,9 +390,7 @@ def cuda(self, *args, **kwargs):
for name in self._children_list:
label = getattr(self, name)
if label is not None:
setattr(
n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs))
)
setattr(n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs)))
return n_frame

def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=True):
Expand All @@ -407,10 +402,7 @@ def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=Tru
# merge everything on the real target dimension.
target_dim = 0

if (
check_dim
and self.names[target_dim] not in self._child_property[label_name]["align_dim"]
):
if check_dim and self.names[target_dim] not in self._child_property[label_name]["align_dim"]:
raise Exception(
"Can only merge labeled tensor on the following dimension '{}'. \
\nDrop the labels before to apply such operations or convert your labeled tensor to tensor first.".format(
Expand Down Expand Up @@ -491,9 +483,7 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
label_value, label_name, key, labels_dict2list[label_name], kwargs
)
else:
self._merge_child(
label_value, label_name, label_name, labels_dict2list, kwargs
)
self._merge_child(label_value, label_name, label_name, labels_dict2list, kwargs)
else:
raise Exception("Can't merge none AugmentedTensor with AugmentedTensor")

Expand All @@ -519,7 +509,7 @@ def _squeeze_unsqueeze_dim(self, tensor, func, types, squeeze, args=(), kwargs=N
"""
dim = kwargs["dim"] if "dim" in kwargs else 0

if dim != 0 and dim !=1:
if dim != 0 and dim != 1:
raise Exception(
f"Impossible to expand the labeld tensor on the given dim: {dim}. Export your labeled tensor into tensor before to do it."
)
Expand All @@ -538,9 +528,7 @@ def _handle_expand_on_label(label, name):
for name in self._children_list:
label = getattr(tensor, name)
if label is not None:
results = self.apply_on_child(
label, lambda l: _handle_expand_on_label(l, name), on_list=False
)
results = self.apply_on_child(label, lambda l: _handle_expand_on_label(l, name), on_list=False)
setattr(tensor, name, results)

def __iter__(self):
Expand Down Expand Up @@ -732,12 +720,7 @@ def __repr__(self):
if isinstance(values[key], list):
cvalue = (
f"{key}:["
+ ", ".join(
[
"{}".format(len(k) if k is not None else None)
for k in values[key]
]
)
+ ", ".join(["{}".format(len(k) if k is not None else None) for k in values[key]])
+ "]"
)
content_value += f"{cvalue}, "
Expand Down Expand Up @@ -970,9 +953,7 @@ def pad(self, offset_y: tuple, offset_x: tuple, **kwargs):
offset_x = (offset_x[0] / self.W, offset_x[1] / self.W)

padded = self._pad(offset_y, offset_x, **kwargs)
padded.recursive_apply_on_children_(
lambda label: self._pad_label(label, offset_y, offset_x, **kwargs)
)
padded.recursive_apply_on_children_(lambda label: self._pad_label(label, offset_y, offset_x, **kwargs))
return padded

def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
Expand Down

0 comments on commit 91db508

Please sign in to comment.