Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: AugmentedTensor clone method #191

Merged
merged 1 commit into from
Jul 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 12 additions & 31 deletions aloscene/tensors/augmented_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import numpy as np
from typing import *
import copy


class AugmentedTensor(torch.Tensor):
Expand Down Expand Up @@ -148,11 +149,7 @@ def __check_child_name_alignment(var):
elif not any(v in self.COMMON_DIM_NAMES for v in var.names):
return True
for dim_id, dim_name in enumerate(var.names):
if (
dim_id < len(self.names)
and dim_name in self.COMMON_DIM_NAMES
and self.names[dim_id] == dim_name
):
if dim_id < len(self.names) and dim_name in self.COMMON_DIM_NAMES and self.names[dim_id] == dim_name:
return True
raise Exception(
f"Impossible to align label name dim ({var.names}) with the tensor name dim ({self.names})."
Expand Down Expand Up @@ -319,9 +316,9 @@ def __setattr__(self, key, value):

def clone(self, *args, **kwargs):
n_frame = super().clone(*args, **kwargs)
n_frame._property_list = self._property_list
n_frame._children_list = self._children_list
n_frame._child_property = self._child_property
n_frame._property_list = copy.deepcopy(self._property_list)
n_frame._children_list = copy.deepcopy(self._children_list)
n_frame._child_property = copy.deepcopy(self._child_property)

for name in self._property_list:
setattr(n_frame, name, getattr(self, name))
Expand Down Expand Up @@ -393,9 +390,7 @@ def cuda(self, *args, **kwargs):
for name in self._children_list:
label = getattr(self, name)
if label is not None:
setattr(
n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs))
)
setattr(n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs)))
return n_frame

def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=True):
Expand All @@ -407,10 +402,7 @@ def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=Tru
# merge everything on the real target dimension.
target_dim = 0

if (
check_dim
and self.names[target_dim] not in self._child_property[label_name]["align_dim"]
):
if check_dim and self.names[target_dim] not in self._child_property[label_name]["align_dim"]:
raise Exception(
"Can only merge labeled tensor on the following dimension '{}'. \
\nDrop the labels before to apply such operations or convert your labeled tensor to tensor first.".format(
Expand Down Expand Up @@ -491,9 +483,7 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None
label_value, label_name, key, labels_dict2list[label_name], kwargs
)
else:
self._merge_child(
label_value, label_name, label_name, labels_dict2list, kwargs
)
self._merge_child(label_value, label_name, label_name, labels_dict2list, kwargs)
else:
raise Exception("Can't merge none AugmentedTensor with AugmentedTensor")

Expand All @@ -519,7 +509,7 @@ def _squeeze_unsqueeze_dim(self, tensor, func, types, squeeze, args=(), kwargs=N
"""
dim = kwargs["dim"] if "dim" in kwargs else 0

if dim != 0 and dim !=1:
if dim != 0 and dim != 1:
raise Exception(
f"Impossible to expand the labeld tensor on the given dim: {dim}. Export your labeled tensor into tensor before to do it."
)
Expand All @@ -538,9 +528,7 @@ def _handle_expand_on_label(label, name):
for name in self._children_list:
label = getattr(tensor, name)
if label is not None:
results = self.apply_on_child(
label, lambda l: _handle_expand_on_label(l, name), on_list=False
)
results = self.apply_on_child(label, lambda l: _handle_expand_on_label(l, name), on_list=False)
setattr(tensor, name, results)

def __iter__(self):
Expand Down Expand Up @@ -732,12 +720,7 @@ def __repr__(self):
if isinstance(values[key], list):
cvalue = (
f"{key}:["
+ ", ".join(
[
"{}".format(len(k) if k is not None else None)
for k in values[key]
]
)
+ ", ".join(["{}".format(len(k) if k is not None else None) for k in values[key]])
+ "]"
)
content_value += f"{cvalue}, "
Expand Down Expand Up @@ -970,9 +953,7 @@ def pad(self, offset_y: tuple, offset_x: tuple, **kwargs):
offset_x = (offset_x[0] / self.W, offset_x[1] / self.W)

padded = self._pad(offset_y, offset_x, **kwargs)
padded.recursive_apply_on_children_(
lambda label: self._pad_label(label, offset_y, offset_x, **kwargs)
)
padded.recursive_apply_on_children_(lambda label: self._pad_label(label, offset_y, offset_x, **kwargs))
return padded

def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):
Expand Down