diff --git a/alonet/__init__.py b/alonet/__init__.py index cb9f618a..510f1064 100644 --- a/alonet/__init__.py +++ b/alonet/__init__.py @@ -11,4 +11,4 @@ from . import detr_panoptic from . import deformable_detr_panoptic -from . import torch2trt + diff --git a/alonet/torch2trt/calibrator.py b/alonet/torch2trt/calibrator.py index 5bd7bfaf..fe7a0a23 100644 --- a/alonet/torch2trt/calibrator.py +++ b/alonet/torch2trt/calibrator.py @@ -65,8 +65,6 @@ class DataBatchStreamer: >>> s_dataStreamer = DataBatchStreamer(dataset=s_calib) >>> m_dataStreamer = DataBatchStreamer(dataset=m_calib) """ - FTYPES = ["torch.Tensor", "ndarray", "aloscene.Frame"] - def __init__( self, dataset=None, @@ -76,7 +74,8 @@ def __init__( ): for sample in dataset[0]: if not isinstance(sample, (torch.Tensor, np.ndarray, Frame)): - raise TypeError(f"unknown sample type, expected samples to be instance of {' or '.join(self.FTYPES)} got {sample.__class__.__name__} instead") + ftypes = ["torch.Tensor", "ndarray", "aloscene.Frame"] + raise TypeError(f"unknown sample type, expected samples to be instance of {' or '.join(ftypes)} got {sample.__class__.__name__} instead") self.batch_idx = 0 self.dataset = dataset @@ -105,7 +104,8 @@ def convert_frame(frame): elif isinstance(frame, np.ndarray): pass else: - raise TypeError(f"Unknown sample type, expected samples to be instance of {' or '.join(self.FTYPES)} got {frame.__class__.__name__}.") + ftypes = ["torch.Tensor", "ndarray", "aloscene.Frame"] + raise TypeError(f"Unknown sample type, expected samples to be instance of {' or '.join(ftypes)} got {frame.__class__.__name__}.") return frame def next_(self): @@ -127,7 +127,7 @@ def next_(self): return None def __len__(self): - return max_batch + return self.max_batch class BaseCalibrator: