Merge pull request #162 from Visual-Behavior/temporal-batch-on-dim
add batch(dim) and temporal(dim)
thibo73800 authored Apr 5, 2022
2 parents 26ba306 + 4657ec1 commit 89205f3
Showing 3 changed files with 32 additions and 13 deletions.
7 changes: 2 additions & 5 deletions alonet/torch2trt/base_exporter.py
@@ -14,7 +14,8 @@
     import tensorrt as trt
     import pycuda.driver as cuda
     prod_package_error = None
-except Exception as prod_package_error:
+except Exception as e:
+    prod_package_error = e
     pass


@@ -93,10 +94,6 @@ def __init__(
         """
         if prod_package_error is not None:
             raise prod_package_error
-        if prod_package_error is not None:
-            raise prod_package_error
-        if prod_package_error is not None:
-            raise prod_package_error
 
         if model is not None:
             assert hasattr(model, "tracing") and model.tracing, "Model must be instantiated with tracing=True"
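Context for the change above (not part of the diff): in Python 3 the variable bound by `except ... as` is deleted when the except block ends, so the old `except Exception as prod_package_error:` left `prod_package_error` undefined whenever the optional imports failed, and the later guard raised a NameError instead of re-raising the captured import error. A minimal standalone sketch of the corrected pattern (the module name below is a placeholder, not the repository's actual dependency):

```python
# Standalone sketch of the optional-import guard; "optional_gpu_backend" is a
# placeholder name, not one of the real alonet dependencies.
try:
    import optional_gpu_backend  # e.g. tensorrt / pycuda in base_exporter.py

    prod_package_error = None
except Exception as e:
    # The name bound by "except ... as e" is cleared when this block exits,
    # so the exception is copied into a longer-lived variable first.
    prod_package_error = e


def export_engine():
    # Fail only when the optional feature is actually used.
    if prod_package_error is not None:
        raise prod_package_error
```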
2 changes: 1 addition & 1 deletion aloscene/tensors/augmented_tensor.py
@@ -519,7 +519,7 @@ def _squeeze_unsqueeze_dim(self, tensor, func, types, squeeze, args=(), kwargs=N
         """
         dim = kwargs["dim"] if "dim" in kwargs else 0
 
-        if dim != 0:
+        if dim != 0 and dim !=1:
             raise Exception(
                 f"Impossible to expand the labeld tensor on the given dim: {dim}. Export your labeled tensor into tensor before to do it."
             )
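With the relaxed guard, a labeled tensor can now be expanded on dim 1 as well as dim 0, which the temporal(dim) and batch(dim) changes below rely on. A plain-PyTorch sketch of the underlying named-tensor bookkeeping (illustrative only, not the aloscene API; the axis names are arbitrary examples):

```python
# Plain-PyTorch illustration of inserting a new named axis at position 0 or 1.
import torch

x = torch.zeros(3, 4, 5, names=("C", "H", "W"))

# Names are dropped before unsqueezing (as the aloscene code does), then
# re-applied with the new axis label inserted at the requested position.
expanded = torch.unsqueeze(x.rename(None), dim=1)
names = list(("C", "H", "W"))
names.insert(1, "B")
expanded = expanded.refine_names(*names)

print(expanded.names)  # ('C', 'B', 'H', 'W')
print(expanded.shape)  # torch.Size([3, 1, 4, 5])
```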
36 changes: 29 additions & 7 deletions aloscene/tensors/spatial_augmented_tensor.py
@@ -187,9 +187,14 @@ def relative_to_absolute(self, x, dim, assert_integer=False):
         assert x.is_integer(), f"relative coordinates {x} have produced non-integer absolute coordinates"
         return round(x)
 
-    def temporal(self):
+    def temporal(self, dim=0):
         """Add a temporal dimension on the tensor
+
+        Parameters
+        ----------
+        dim : int
+            The dim on which to add the temporal dimension. Can be 0 or 1
+
         Returns
         -------
         temporal_frame: aloscene.Frame
@@ -198,13 +203,20 @@ def batch_label(tensor, label, name):
         if "T" in self.names:  # Already a temporal frame
             return self
 
+        def set_n_names(names):
+            pass
+
         tensor = self.rename(None)
-        tensor = torch.unsqueeze(tensor, dim=0)
-        tensor.rename_(*tuple(["T"] + list(tensor._saved_names)))
+        tensor = torch.unsqueeze(tensor, dim=dim)
+        n_names = list(tensor._saved_names)
+        n_names.insert(dim, "T")
+        tensor.rename_(*tuple(n_names))
 
         def batch_label(tensor, label, name):
             if tensor._child_property[name]["mergeable"]:
-                label.rename_(*tuple(["T"] + list(label._saved_names)))
+                n_label_names = list(label._saved_names)
+                n_label_names.insert(dim, "T")
+                label.rename_(*tuple(n_label_names))
                 for sub_name in label._children_list:
                     sub_label = getattr(label, sub_name)
                     if sub_label is not None:
@@ -220,9 +232,14 @@
 
         return tensor
 
-    def batch(self):
+    def batch(self, dim=0):
         """Add a batch dimension on the tensor
+
+        Parameters
+        ----------
+        dim : int
+            The dim on which to add the batch dimension. Can be 0 or 1.
+
         Returns
         -------
         batch_frame: aloscene.Frame
@@ -233,7 +250,10 @@ def batch(self):
 
         tensor = self.rename(None)
         tensor = torch.unsqueeze(tensor, dim=0)
-        tensor.rename_(*tuple(["B"] + list(tensor._saved_names)))
+
+        n_names = list(tensor._saved_names)
+        n_names.insert(dim, "B")
+        tensor.rename_(*tuple(n_names))
 
         def batch_label(tensor, label, name):
             """
@@ -242,7 +262,9 @@ def batch_label(tensor, label, name):
             - else, the previous name are restored.
             """
             if tensor._child_property[name]["mergeable"]:
-                label.rename_(*tuple(["B"] + list(label._saved_names)))
+                n_label_names = list(label._saved_names)
+                n_label_names.insert(dim, "B")
+                label.rename_(*tuple(n_label_names))
                 for sub_name in label._children_list:
                     sub_label = getattr(label, sub_name)
                     if sub_label is not None:
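A hedged usage sketch of the new arguments (not taken from the repository's docs; it assumes aloscene.Frame can be constructed from a numpy array with explicit dim names):

```python
# Usage sketch for temporal(dim)/batch(dim); the Frame constructor call is an
# assumption about the aloscene API, not something shown in this diff.
import numpy as np
import aloscene

frame = aloscene.Frame(np.zeros((3, 64, 64), dtype=np.float32), names=("C", "H", "W"))

b_frame = frame.batch()             # default dim=0, as before: ("B", "C", "H", "W")
bt_frame = b_frame.temporal(dim=1)  # new: "T" is inserted after "B": ("B", "T", "C", "H", "W")
```

With the default dim=0 both methods behave as before; dim=1 lets the new axis be placed after an existing leading dimension, for example adding a temporal axis to an already batched frame.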
