diff --git a/alonet/torch2trt/base_exporter.py b/alonet/torch2trt/base_exporter.py index 5cb7b07e..354cfa54 100644 --- a/alonet/torch2trt/base_exporter.py +++ b/alonet/torch2trt/base_exporter.py @@ -14,7 +14,8 @@ import tensorrt as trt import pycuda.driver as cuda prod_package_error = None -except Exception as prod_package_error: +except Exception as e: + prod_package_error = e pass @@ -93,10 +94,6 @@ def __init__( """ if prod_package_error is not None: raise prod_package_error - if prod_package_error is not None: - raise prod_package_error - if prod_package_error is not None: - raise prod_package_error if model is not None: assert hasattr(model, "tracing") and model.tracing, "Model must be instantiated with tracing=True" diff --git a/aloscene/tensors/augmented_tensor.py b/aloscene/tensors/augmented_tensor.py index d5a414b3..e9fe1b88 100644 --- a/aloscene/tensors/augmented_tensor.py +++ b/aloscene/tensors/augmented_tensor.py @@ -519,7 +519,7 @@ def _squeeze_unsqueeze_dim(self, tensor, func, types, squeeze, args=(), kwargs=N """ dim = kwargs["dim"] if "dim" in kwargs else 0 - if dim != 0: + if dim != 0 and dim !=1: raise Exception( f"Impossible to expand the labeld tensor on the given dim: {dim}. Export your labeled tensor into tensor before to do it." ) diff --git a/aloscene/tensors/spatial_augmented_tensor.py b/aloscene/tensors/spatial_augmented_tensor.py index e3029cae..b1049e4d 100644 --- a/aloscene/tensors/spatial_augmented_tensor.py +++ b/aloscene/tensors/spatial_augmented_tensor.py @@ -187,9 +187,14 @@ def relative_to_absolute(self, x, dim, assert_integer=False): assert x.is_integer(), f"relative coordinates {x} have produced non-integer absolute coordinates" return round(x) - def temporal(self): + def temporal(self, dim=0): """Add a temporal dimension on the tensor + Parameters + ---------- + dim : int + The dim on which to add the temporal dimension. 
Can be 0 or 1. + Returns ------- temporal_frame: aloscene.Frame @@ -198,13 +203,17 @@ def temporal(self): if "T" in self.names: # Already a temporal frame return self tensor = self.rename(None) - tensor = torch.unsqueeze(tensor, dim=0) - tensor.rename_(*tuple(["T"] + list(tensor._saved_names))) + tensor = torch.unsqueeze(tensor, dim=dim) + n_names = list(tensor._saved_names) + n_names.insert(dim, "T") + tensor.rename_(*tuple(n_names)) def batch_label(tensor, label, name): if tensor._child_property[name]["mergeable"]: - label.rename_(*tuple(["T"] + list(label._saved_names))) + n_label_names = list(label._saved_names) + n_label_names.insert(dim, "T") + label.rename_(*tuple(n_label_names)) for sub_name in label._children_list: sub_label = getattr(label, sub_name) if sub_label is not None: @@ -220,9 +229,14 @@ def batch_label(tensor, label, name): return tensor - def batch(self): + def batch(self, dim=0): """Add a batch dimension on the tensor + Parameters + ---------- + dim : int + The dim on which to add the batch dimension. Can be 0 or 1. + Returns ------- batch_frame: aloscene.Frame @@ -233,7 +247,10 @@ def batch(self): tensor = self.rename(None) - tensor = torch.unsqueeze(tensor, dim=0) + tensor = torch.unsqueeze(tensor, dim=dim) + + n_names = list(tensor._saved_names) + n_names.insert(dim, "B") + tensor.rename_(*tuple(n_names)) def batch_label(tensor, label, name): """ @@ -242,7 +259,9 @@ def batch_label(tensor, label, name): - else, the previous name are restored. """ if tensor._child_property[name]["mergeable"]: - label.rename_(*tuple(["B"] + list(label._saved_names))) + n_label_names = list(label._saved_names) + n_label_names.insert(dim, "B") + label.rename_(*tuple(n_label_names)) for sub_name in label._children_list: sub_label = getattr(label, sub_name) if sub_label is not None: