From ce01409d9c84f9785fdc4d9bdb09be532036af35 Mon Sep 17 00:00:00 2001 From: Taha Date: Wed, 24 Aug 2022 16:10:46 +0200 Subject: [PATCH] fix quatization arguments init values --- alonet/torch2trt/base_exporter.py | 2 +- alonet/torch2trt/calibrator.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/alonet/torch2trt/base_exporter.py b/alonet/torch2trt/base_exporter.py index cfa4701c..e83f5e30 100644 --- a/alonet/torch2trt/base_exporter.py +++ b/alonet/torch2trt/base_exporter.py @@ -402,7 +402,7 @@ def add_argparse_args(parent_parser): parser.add_argument("--verbose", action="store_true", help="Helpful when debugging") parser.add_argument("--profiling_verbosity", default=0, type=int, help="Helpful when profiling the engine (default: %(default)s)") parser.add_argument("--calibration_batch_size", type=int, default=8, help="Calibration data batch size (default: %(default)s)") - parser.add_argument("--limit_calibration_batches", type=int, default=10, help="Limits number of batches (default: %(default)s)") + parser.add_argument("--limit_calibration_batches", type=int, default=None, help="Limits number of batches (default: %(default)s)") parser.add_argument("--cache_file", type=str, default="calib.bin", help="Path to caliaration cache file (default: %(default)s)") parser.add_argument( "--calibrator", diff --git a/alonet/torch2trt/calibrator.py b/alonet/torch2trt/calibrator.py index 5bd7bfaf..fe7a0a23 100644 --- a/alonet/torch2trt/calibrator.py +++ b/alonet/torch2trt/calibrator.py @@ -65,8 +65,6 @@ class DataBatchStreamer: >>> s_dataStreamer = DataBatchStreamer(dataset=s_calib) >>> m_dataStreamer = DataBatchStreamer(dataset=m_calib) """ - FTYPES = ["torch.Tensor", "ndarray", "aloscene.Frame"] - def __init__( self, dataset=None, @@ -76,7 +74,8 @@ def __init__( ): for sample in dataset[0]: if not isinstance(sample, (torch.Tensor, np.ndarray, Frame)): - raise TypeError(f"unknown sample type, expected samples to be instance of {' or '.join(self.FTYPES)} got {sample.__class__.__name__} instead") + ftypes = ["torch.Tensor", "ndarray", "aloscene.Frame"] + raise TypeError(f"unknown sample type, expected samples to be instance of {' or '.join(ftypes)} got {sample.__class__.__name__} instead") self.batch_idx = 0 self.dataset = dataset @@ -105,7 +104,8 @@ def convert_frame(frame): elif isinstance(frame, np.ndarray): pass else: - raise TypeError(f"Unknown sample type, expected samples to be instance of {' or '.join(self.FTYPES)} got {frame.__class__.__name__}.") + ftypes = ["torch.Tensor", "ndarray", "aloscene.Frame"] + raise TypeError(f"Unknown sample type, expected samples to be instance of {' or '.join(ftypes)} got {frame.__class__.__name__}.") return frame def next_(self): @@ -127,7 +127,7 @@ def next_(self): return None def __len__(self): - return max_batch + return self.max_batch class BaseCalibrator: