From 9160c67f24e399a27c1a4f59fa4dfd1d031c5986 Mon Sep 17 00:00:00 2001 From: Julien Salotti Date: Wed, 15 Jun 2022 14:49:29 +0200 Subject: [PATCH] fix clone copy method --- aloscene/tensors/augmented_tensor.py | 43 ++++++++-------------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/aloscene/tensors/augmented_tensor.py b/aloscene/tensors/augmented_tensor.py index e9fe1b88..bc24c8cd 100644 --- a/aloscene/tensors/augmented_tensor.py +++ b/aloscene/tensors/augmented_tensor.py @@ -2,6 +2,7 @@ import inspect import numpy as np from typing import * +import copy class AugmentedTensor(torch.Tensor): @@ -148,11 +149,7 @@ def __check_child_name_alignment(var): elif not any(v in self.COMMON_DIM_NAMES for v in var.names): return True for dim_id, dim_name in enumerate(var.names): - if ( - dim_id < len(self.names) - and dim_name in self.COMMON_DIM_NAMES - and self.names[dim_id] == dim_name - ): + if dim_id < len(self.names) and dim_name in self.COMMON_DIM_NAMES and self.names[dim_id] == dim_name: return True raise Exception( f"Impossible to align label name dim ({var.names}) with the tensor name dim ({self.names})." @@ -319,9 +316,9 @@ def __setattr__(self, key, value): def clone(self, *args, **kwargs): n_frame = super().clone(*args, **kwargs) - n_frame._property_list = self._property_list - n_frame._children_list = self._children_list - n_frame._child_property = self._child_property + n_frame._property_list = copy.deepcopy(self._property_list) + n_frame._children_list = copy.deepcopy(self._children_list) + n_frame._child_property = copy.deepcopy(self._child_property) for name in self._property_list: setattr(n_frame, name, getattr(self, name)) @@ -393,9 +390,7 @@ def cuda(self, *args, **kwargs): for name in self._children_list: label = getattr(self, name) if label is not None: - setattr( - n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs)) - ) + setattr(n_frame, name, self.apply_on_child(label, lambda l: l.cuda(*args, **kwargs))) return n_frame def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=True): @@ -407,10 +402,7 @@ def _merge_child(self, label, label_name, key, dict_merge, kwargs, check_dim=Tru # merge everything on the real target dimension. target_dim = 0 - if ( - check_dim - and self.names[target_dim] not in self._child_property[label_name]["align_dim"] - ): + if check_dim and self.names[target_dim] not in self._child_property[label_name]["align_dim"]: raise Exception( "Can only merge labeled tensor on the following dimension '{}'. \ \nDrop the labels before to apply such operations or convert your labeled tensor to tensor first.".format( @@ -491,9 +483,7 @@ def _merge_tensor(self, n_tensor, tensor_list, func, types, args=(), kwargs=None label_value, label_name, key, labels_dict2list[label_name], kwargs ) else: - self._merge_child( - label_value, label_name, label_name, labels_dict2list, kwargs - ) + self._merge_child(label_value, label_name, label_name, labels_dict2list, kwargs) else: raise Exception("Can't merge none AugmentedTensor with AugmentedTensor") @@ -519,7 +509,7 @@ def _squeeze_unsqueeze_dim(self, tensor, func, types, squeeze, args=(), kwargs=N """ dim = kwargs["dim"] if "dim" in kwargs else 0 - if dim != 0 and dim !=1: + if dim != 0 and dim != 1: raise Exception( f"Impossible to expand the labeld tensor on the given dim: {dim}. Export your labeled tensor into tensor before to do it." ) @@ -538,9 +528,7 @@ def _handle_expand_on_label(label, name): for name in self._children_list: label = getattr(tensor, name) if label is not None: - results = self.apply_on_child( - label, lambda l: _handle_expand_on_label(l, name), on_list=False - ) + results = self.apply_on_child(label, lambda l: _handle_expand_on_label(l, name), on_list=False) setattr(tensor, name, results) def __iter__(self): @@ -732,12 +720,7 @@ def __repr__(self): if isinstance(values[key], list): cvalue = ( f"{key}:[" - + ", ".join( - [ - "{}".format(len(k) if k is not None else None) - for k in values[key] - ] - ) + + ", ".join(["{}".format(len(k) if k is not None else None) for k in values[key]]) + "]" ) content_value += f"{cvalue}, " @@ -970,9 +953,7 @@ def pad(self, offset_y: tuple, offset_x: tuple, **kwargs): offset_x = (offset_x[0] / self.W, offset_x[1] / self.W) padded = self._pad(offset_y, offset_x, **kwargs) - padded.recursive_apply_on_children_( - lambda label: self._pad_label(label, offset_y, offset_x, **kwargs) - ) + padded.recursive_apply_on_children_(lambda label: self._pad_label(label, offset_y, offset_x, **kwargs)) return padded def _spatial_shift_label(self, label, shift_y, shift_x, **kwargs):