Skip to content

Commit

Permalink
Merge pull request #319 from Visual-Behavior/detr_pre
Browse files Browse the repository at this point in the history
feat -detr exportation- include preprocessing
  • Loading branch information
thibo73800 authored Feb 16, 2023
2 parents c825fc5 + fbd9169 commit 469114d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 45 deletions.
7 changes: 6 additions & 1 deletion alonet/deformable_detr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,10 @@ Evaluation on 1000 images COCO with box refinement

## Exportation
```bash
python trt_exporter.py --refinement --HW 320 480 --verbose --ignore_adapt_graph
python trt_exporter.py --refinement --HW 320 480 --verbose
```
or (for preprocessing included)

```bash
python trt_exporter.py --refinement --HW 320 480 --verbose
```
36 changes: 30 additions & 6 deletions alonet/deformable_detr/deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
return_intermediate_dec: bool = True,
strict_load_weights: bool = True,
tracing=False,
include_preprocessing=False,
):
print("WARNING : you are using DeformableDETR or an unherited class. Please launch aloception-oss/alonet/deformable_detr/ops/make.sh before proceeding with training. Please refer to the README for more info")
super().__init__()
Expand All @@ -97,6 +98,7 @@ def __init__(
self.return_dec_outputs = return_dec_outputs
self.return_enc_outputs = return_enc_outputs
self.return_bb_outputs = return_bb_outputs
self.include_preprocessing = include_preprocessing

if activation_fn not in ["sigmoid", "softmax"]:
raise Exception(f"activation_fn = {activation_fn} must be one of this two values: 'sigmoid' or 'softmax'.")
Expand Down Expand Up @@ -192,6 +194,22 @@ def tracing(self):
def tracing(self, is_tracing):
self._tracing = is_tracing
self.backbone.tracing = is_tracing

@staticmethod
def in_img_preprocess(frames):
frames = frames.permute(0, 3, 1, 2)
frames = frames.div(255)

n_shape = [1] * len(frames.shape)
n_shape[1] = 3

mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
std_tensor = torch.tensor(mean_std[1], device=frames.device).view(tuple(n_shape))
mean_tensor = torch.tensor(mean_std[0], device=frames.device).view(tuple(n_shape))

frames = frames - mean_tensor
frames = frames / std_tensor
return frames

@assert_and_export_onnx(check_mean_std=True, input_mean_std=INPUT_MEAN_STD)
def forward(self, frames: aloscene.Frame, **kwargs):
Expand Down Expand Up @@ -223,16 +241,22 @@ def forward(self, frames: aloscene.Frame, **kwargs):
- :attr:`dec_outputs`: Optional, only returned when transformer decoder outputs are activated.
"""

# ==== Backbone
features, pos = self.backbone(frames, **kwargs)

assert next(self.parameters()).is_cuda, "DeformableDETR cannot run on CPU (due to MSdeformable op)"

if self.tracing:
frame_masks = frames[:, 3:4]
if self.include_preprocessing:
frame_masks = torch.zeros((1, 1, *frames.shape[1:3]), dtype=torch.float32)
frame_masks = frame_masks.to(frames.device)
frames = self.in_img_preprocess(frames)
else:
frame_masks = torch.zeros((1, 1, *frames.shape[-2:]), dtype=torch.float32)
frame_masks = frame_masks.to(frames.device)
frames = torch.cat([frames, frame_masks], dim=1)
else:
assert next(self.parameters()).is_cuda, "DeformableDETR cannot run on CPU (due to MSdeformable op)"
frame_masks = frames.mask.as_tensor()

# ==== Backbone
features, pos = self.backbone(frames, **kwargs)

# ==== Transformer
srcs = []
masks = []
Expand Down
60 changes: 23 additions & 37 deletions alonet/deformable_detr/trt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,21 @@
import onnx_graphsurgeon as gs
from alonet.torch2trt import utils

from torch.onnx import register_custom_op_symbolic

from aloscene import Frame
from alonet.torch2trt import BaseTRTExporter
from alonet.torch2trt.utils import get_nodes_by_op
from alonet.torch2trt.onnx_hack import _add_grid_sampler_to_opset13
from alonet.deformable_detr import DeformableDetrR50, DeformableDetrR50Refinement
from alonet.torch2trt import BaseTRTExporter, MS_DEFORM_IM2COL_PLUGIN_LIB, load_trt_custom_plugins

CUSTOM_OP_VERSION = 9


def symbolic_ms_deform_attn_forward(
g, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
):
return g.op(
"alonet_custom::ms_deform_attn_forward",
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
im2col_step,
)


register_custom_op_symbolic(
"alonet_custom::ms_deform_attn_forward", symbolic_ms_deform_attn_forward, CUSTOM_OP_VERSION
)


def load_trt_plugins_for_deformable_detr():
load_trt_custom_plugins(MS_DEFORM_IM2COL_PLUGIN_LIB)


class DeformableDetrTRTExporter(BaseTRTExporter):
def __init__(self, model_name="deformable-detr-r50", weights="deformable-detr-r50", *args, **kwargs):
def __init__(self, model_name="deformable-detr-r50", weights="deformable-detr-r50", include_preprocessing=False, *args, **kwargs):
_add_grid_sampler_to_opset13()
super().__init__(*args, **kwargs)
self.weights = weights
self.do_constant_folding = False
self.custom_opset = {"alonet_custom": 1}
self.include_preprocessing = include_preprocessing
self.adapted_onnx_path = self.onnx_path.replace(".onnx", "_TRTadapted") + ".onnx"

def get_onnx_path(self):
Expand Down Expand Up @@ -145,23 +119,28 @@ def _onnx2engine(self, **kwargs):
def prepare_sample_inputs(self):
assert len(self.input_shapes) == 1, "DETR takes only 1 input"
shape = self.input_shapes[0]
x = torch.rand(shape, dtype=torch.float32)
x = Frame(x, names=["C", "H", "W"]).norm_resnet()
x = Frame.batch_list([x] * self.batch_size).to(self.device)
tensor_input = (x.as_tensor(), x.mask.as_tensor())
tensor_input = torch.cat(tensor_input, dim=1) # [b, 4, H, W]
if self.include_preprocessing:
tensor_input = torch.rand([1 * self.batch_size] + shape, dtype=torch.float32).to(self.device)
tensor_input = tensor_input * 255
else:
x = torch.rand(shape, dtype=torch.float32)
x = Frame(x, names=["C", "H", "W"]).norm_resnet()
x = Frame.batch_list([x] * self.batch_size).to(self.device)
tensor_input = (x.as_tensor(), x.mask.as_tensor())
tensor_input = torch.cat(tensor_input, dim=1) # [b, 4, H, W]
return (tensor_input,), {"is_export_onnx": None}


if __name__ == "__main__":
from alonet.common.pl_helpers import vb_folder


load_trt_plugins_for_deformable_detr()
# load_trt_plugins_for_deformable_detr()
device = torch.device("cuda")

parser = argparse.ArgumentParser()
parser.add_argument("--refinement", action="store_true", help="If set, use box refinement")
parser.add_argument("--include_preprocessing", action="store_true", help="Includes image preprocessing in the graph")
parser.add_argument(
"--HW", type=int, nargs=2, default=[1280, 1920], help="Height and width of input image, default 1280 1920"
)
Expand All @@ -171,15 +150,22 @@ def prepare_sample_inputs(self):

if args.refinement:
model_name = "deformable-detr-r50-refinement"
model = DeformableDetrR50Refinement(weights=model_name, tracing=True, aux_loss=False).eval()
model = DeformableDetrR50Refinement(
weights=model_name,
tracing=True,
aux_loss=False,
include_preprocessing=args.include_preprocessing).eval()
else:
model_name = "deformable-detr-r50"
model = DeformableDetrR50(weights=model_name, tracing=True, aux_loss=False).eval()

if args.onnx_path is None:
args.onnx_path = os.path.join(vb_folder(), "weights", model_name, model_name + ".onnx")

input_shape = [3] + list(args.HW)
if args.include_preprocessing:
input_shape = list(args.HW) + [3]
else:
input_shape = [3] + list(args.HW)

exporter = DeformableDetrTRTExporter(
model=model, weights=model_name, input_shapes=(input_shape,), input_names=["img"], device=device, **vars(args)
Expand Down
1 change: 0 additions & 1 deletion alonet/detr/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def wrapper(instance, frames: Union[torch.Tensor, Frame], is_export_onnx=False,
# because torch.onnx.export accepts only torch.Tensor or None
if hasattr(instance, "tracing") and instance.tracing:
assert isinstance(frames, torch.Tensor)
assert frames.shape[1] == 4 # rgb 3 + mask 1
kwargs["is_tracing"] = None
if is_export_onnx is None:
return forward(instance, frames, is_export_onnx=None, **kwargs)
Expand Down

0 comments on commit 469114d

Please sign in to comment.