From 7f3f7b8d75ebd497b6f268c2184c8985ea6fc961 Mon Sep 17 00:00:00 2001 From: Taha Azzim Date: Fri, 20 Jan 2023 15:27:48 +0100 Subject: [PATCH 1/4] feat -detr exportation- include preprocessing --- alonet/deformable_detr/deformable_detr.py | 34 +++++++++++-- alonet/deformable_detr/trt_exporter.py | 60 +++++++++-------------- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/alonet/deformable_detr/deformable_detr.py b/alonet/deformable_detr/deformable_detr.py index 6bcd9554..1ab3b76f 100644 --- a/alonet/deformable_detr/deformable_detr.py +++ b/alonet/deformable_detr/deformable_detr.py @@ -85,6 +85,7 @@ def __init__( return_intermediate_dec: bool = True, strict_load_weights: bool = True, tracing=False, + include_preprocessing=False, ): print("WARNING : you are using DeformableDETR or an unherited class. Please launch aloception-oss/alonet/deformable_detr/ops/make.sh before proceeding with training. Please refer to the README for more info") super().__init__() @@ -97,6 +98,7 @@ def __init__( self.return_dec_outputs = return_dec_outputs self.return_enc_outputs = return_enc_outputs self.return_bb_outputs = return_bb_outputs + self.include_preprocessing = include_preprocessing if activation_fn not in ["sigmoid", "softmax"]: raise Exception(f"activation_fn = {activation_fn} must be one of this two values: 'sigmoid' or 'softmax'.") @@ -192,6 +194,22 @@ def tracing(self): def tracing(self, is_tracing): self._tracing = is_tracing self.backbone.tracing = is_tracing + + @staticmethod + def in_img_preprocess(frames): + frames = frames.permute(0, 3, 1, 2) + frames = frames.div(255) + + n_shape = [1] * len(frames.shape) + n_shape[1] = 3 + + mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + std_tensor = torch.tensor(mean_std[1], device=frames.device).view(tuple(n_shape)) + mean_tensor = torch.tensor(mean_std[0], device=frames.device).view(tuple(n_shape)) + + frames = frames - mean_tensor + frames = frames / std_tensor + return frames @assert_and_export_onnx(check_mean_std=True, input_mean_std=INPUT_MEAN_STD) def forward(self, frames: aloscene.Frame, **kwargs): @@ -222,17 +240,23 @@ def forward(self, frames: aloscene.Frame, **kwargs): - :attr:`enc_outputs`: Optional, only returned when transformer encoder outputs are activated. - :attr:`dec_outputs`: Optional, only returned when transformer decoder outputs are activated. """ - - # ==== Backbone - features, pos = self.backbone(frames, **kwargs) - assert next(self.parameters()).is_cuda, "DeformableDETR cannot run on CPU (due to MSdeformable op)" if self.tracing: - frame_masks = frames[:, 3:4] + if self.include_preprocessing: + frame_masks = torch.zeros((1, 1, *frames.shape[1:3]), dtype=torch.float32) + frame_masks = frame_masks.to(frames.device) + frames = self.in_img_preprocess(frames) + else: + frame_masks = torch.zeros((1, 1, *frames.shape[-2:]), dtype=torch.float32) + frame_masks = frame_masks.to(frames.device) + frames = torch.cat([frames, frame_masks], dim=1) else: frame_masks = frames.mask.as_tensor() + # ==== Backbone + features, pos = self.backbone(frames, **kwargs) + # ==== Transformer srcs = [] masks = [] diff --git a/alonet/deformable_detr/trt_exporter.py b/alonet/deformable_detr/trt_exporter.py index db462078..bdf39ece 100644 --- a/alonet/deformable_detr/trt_exporter.py +++ b/alonet/deformable_detr/trt_exporter.py @@ -9,47 +9,21 @@ import onnx_graphsurgeon as gs from alonet.torch2trt import utils -from torch.onnx import register_custom_op_symbolic from aloscene import Frame +from alonet.torch2trt import BaseTRTExporter from alonet.torch2trt.utils import get_nodes_by_op from alonet.torch2trt.onnx_hack import _add_grid_sampler_to_opset13 from alonet.deformable_detr import DeformableDetrR50, DeformableDetrR50Refinement -from alonet.torch2trt import BaseTRTExporter, MS_DEFORM_IM2COL_PLUGIN_LIB, load_trt_custom_plugins - -CUSTOM_OP_VERSION = 9 - - -def symbolic_ms_deform_attn_forward( - g, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step -): - return g.op( - "alonet_custom::ms_deform_attn_forward", - value, - value_spatial_shapes, - value_level_start_index, - sampling_locations, - attention_weights, - im2col_step, - ) - - -register_custom_op_symbolic( - "alonet_custom::ms_deform_attn_forward", symbolic_ms_deform_attn_forward, CUSTOM_OP_VERSION -) - - -def load_trt_plugins_for_deformable_detr(): - load_trt_custom_plugins(MS_DEFORM_IM2COL_PLUGIN_LIB) class DeformableDetrTRTExporter(BaseTRTExporter): - def __init__(self, model_name="deformable-detr-r50", weights="deformable-detr-r50", *args, **kwargs): + def __init__(self, model_name="deformable-detr-r50", weights="deformable-detr-r50", include_preprocessing=False, *args, **kwargs): _add_grid_sampler_to_opset13() super().__init__(*args, **kwargs) self.weights = weights self.do_constant_folding = False - self.custom_opset = {"alonet_custom": 1} + self.include_preprocessing = include_preprocessing self.adapted_onnx_path = self.onnx_path.replace(".onnx", "_TRTadapted") + ".onnx" def get_onnx_path(self): @@ -145,11 +119,15 @@ def _onnx2engine(self, **kwargs): def prepare_sample_inputs(self): assert len(self.input_shapes) == 1, "DETR takes only 1 input" shape = self.input_shapes[0] - x = torch.rand(shape, dtype=torch.float32) - x = Frame(x, names=["C", "H", "W"]).norm_resnet() - x = Frame.batch_list([x] * self.batch_size).to(self.device) - tensor_input = (x.as_tensor(), x.mask.as_tensor()) - tensor_input = torch.cat(tensor_input, dim=1) # [b, 4, H, W] + if self.include_preprocessing: + tensor_input = torch.rand([1 * self.batch_size] + shape, dtype=torch.float32).to(self.device) + tensor_input = tensor_input * 255 + else: + x = torch.rand(shape, dtype=torch.float32) + x = Frame(x, names=["C", "H", "W"]).norm_resnet() + x = Frame.batch_list([x] * self.batch_size).to(self.device) + tensor_input = (x.as_tensor(), x.mask.as_tensor()) + tensor_input = torch.cat(tensor_input, dim=1) # [b, 4, H, W] return (tensor_input,), {"is_export_onnx": None} @@ -157,11 +135,12 @@ def prepare_sample_inputs(self): from alonet.common.pl_helpers import vb_folder - load_trt_plugins_for_deformable_detr() + # load_trt_plugins_for_deformable_detr() device = torch.device("cuda") parser = argparse.ArgumentParser() parser.add_argument("--refinement", action="store_true", help="If set, use box refinement") + parser.add_argument("--include_preprocessing", action="store_true", help="Includes image preprocessing in the graph") parser.add_argument( "--HW", type=int, nargs=2, default=[1280, 1920], help="Height and width of input image, default 1280 1920" ) @@ -171,7 +150,11 @@ def prepare_sample_inputs(self): if args.refinement: model_name = "deformable-detr-r50-refinement" - model = DeformableDetrR50Refinement(weights=model_name, tracing=True, aux_loss=False).eval() + model = DeformableDetrR50Refinement( + weights=model_name, + tracing=True, + aux_loss=False, + include_preprocessing=args.include_preprocessing).eval() else: model_name = "deformable-detr-r50" model = DeformableDetrR50(weights=model_name, tracing=True, aux_loss=False).eval() @@ -179,7 +162,10 @@ def prepare_sample_inputs(self): if args.onnx_path is None: args.onnx_path = os.path.join(vb_folder(), "weights", model_name, model_name + ".onnx") - input_shape = [3] + list(args.HW) + if args.include_preprocessing: + input_shape = list(args.HW) + [3] + else: + input_shape = [3] + list(args.HW) exporter = DeformableDetrTRTExporter( model=model, weights=model_name, input_shapes=(input_shape,), input_names=["img"], device=device, **vars(args) From e85a8eb02e694a5dc87e268e2a9cd2f25e17e66f Mon Sep 17 00:00:00 2001 From: Taha Azzim Date: Fri, 20 Jan 2023 16:01:36 +0100 Subject: [PATCH 2/4] update readme --- alonet/deformable_detr/README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/alonet/deformable_detr/README.md b/alonet/deformable_detr/README.md index ac224012..fcb9c432 100644 --- a/alonet/deformable_detr/README.md +++ b/alonet/deformable_detr/README.md @@ -61,5 +61,10 @@ Evaluation on 1000 images COCO with box refinement ## Exportation ```bash -python trt_exporter.py --refinement --HW 320 480 --verbose --ignore_adapt_graph +python trt_exporter.py --refinement --HW 320 480 --verbose +``` +or (for preprocessing included) + +```bash +python trt_exporter.py --refinement --HW 320 480 --verbose ``` \ No newline at end of file From 49ea8b534d776bb5611de6e1e158617381d6605f Mon Sep 17 00:00:00 2001 From: Taha Azzim Date: Thu, 16 Feb 2023 15:40:55 +0100 Subject: [PATCH 3/4] separate tracing assertion --- alonet/deformable_detr/deformable_detr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alonet/deformable_detr/deformable_detr.py b/alonet/deformable_detr/deformable_detr.py index 1ab3b76f..573ae963 100644 --- a/alonet/deformable_detr/deformable_detr.py +++ b/alonet/deformable_detr/deformable_detr.py @@ -240,7 +240,6 @@ def forward(self, frames: aloscene.Frame, **kwargs): - :attr:`enc_outputs`: Optional, only returned when transformer encoder outputs are activated. - :attr:`dec_outputs`: Optional, only returned when transformer decoder outputs are activated. """ - assert next(self.parameters()).is_cuda, "DeformableDETR cannot run on CPU (due to MSdeformable op)" if self.tracing: if self.include_preprocessing: @@ -252,6 +251,7 @@ def forward(self, frames: aloscene.Frame, **kwargs): frame_masks = frame_masks.to(frames.device) frames = torch.cat([frames, frame_masks], dim=1) else: + assert next(self.parameters()).is_cuda, "DeformableDETR cannot run on CPU (due to MSdeformable op)" frame_masks = frames.mask.as_tensor() # ==== Backbone From fbd9169ea81e453b123cd925e72dde7f853f5df2 Mon Sep 17 00:00:00 2001 From: Taha Azzim Date: Thu, 16 Feb 2023 15:41:17 +0100 Subject: [PATCH 4/4] remove dim assert --- alonet/detr/misc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alonet/detr/misc.py b/alonet/detr/misc.py index 14634771..f3f6837d 100644 --- a/alonet/detr/misc.py +++ b/alonet/detr/misc.py @@ -14,7 +14,6 @@ def wrapper(instance, frames: Union[torch.Tensor, Frame], is_export_onnx=False, # because torch.onnx.export accepts only torch.Tensor or None if hasattr(instance, "tracing") and instance.tracing: assert isinstance(frames, torch.Tensor) - assert frames.shape[1] == 4 # rgb 3 + mask 1 kwargs["is_tracing"] = None if is_export_onnx is None: return forward(instance, frames, is_export_onnx=None, **kwargs)