From 949ffe9f6addf9c24882a59491947645daf16bf0 Mon Sep 17 00:00:00 2001 From: John Welsh Date: Wed, 13 Jan 2021 22:05:56 +0000 Subject: [PATCH] added expand converter --- CHANGELOG.md | 1 + torch2trt/converters/__init__.py | 1 + torch2trt/converters/expand.py | 43 ++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) create mode 100644 torch2trt/converters/expand.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e672c02..df0fb9ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added +- Added converter for ``torch.Tensor.expand`` - Added support for custom converters for methods defined outside of ``torch`` module - Added names for TensorRT layers - Added GroupNorm plugin which internally uses PyTorch aten::group_norm diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py index 08164476..d6ae6876 100644 --- a/torch2trt/converters/__init__.py +++ b/torch2trt/converters/__init__.py @@ -25,6 +25,7 @@ from .clamp import * from .compare import * from .div import * +from .expand import * from .getitem import * from .identity import * from .instance_norm import * diff --git a/torch2trt/converters/expand.py b/torch2trt/converters/expand.py new file mode 100644 index 00000000..e0d07540 --- /dev/null +++ b/torch2trt/converters/expand.py @@ -0,0 +1,43 @@ +from torch2trt.torch2trt import * +from torch2trt.module_test import add_module_test + + +@tensorrt_converter('torch.Tensor.expand') +def convert_expand(ctx): + input = ctx.method_args[0] + sizes = ctx.method_args[1:] + output = ctx.method_return + + inshape = tuple(input.shape)[1:] # exclude batch + shape = tuple(output.shape)[1:] + ndim = len(shape) + start = tuple([0]*ndim) + stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise + + layer = ctx.network.add_slice(input._trt, start, shape, stride) + + output._trt = layer.get_output(0) + + +class ExpandModule(torch.nn.Module): + def __init__(self, *sizes): + super(ExpandModule, self).__init__() + self.sizes = sizes + + def forward(self, x): + return x.expand(*self.sizes) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1,1,3,3)]) +def test_tensor_expand_singledim(): + return ExpandModule(1, 3, 3, 3) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)]) +def test_tensor_expand_multidim(): + return ExpandModule(1, 3, 3, 3) + + +@add_module_test(torch.float32, torch.device('cuda'), [(1,1,1,3)]) +def test_tensor_expand_inferdim(): + return ExpandModule(1, 3, -1, -1) \ No newline at end of file