-
Notifications
You must be signed in to change notification settings - Fork 371
feat: support 1D, 2D, and 3D avg and max pooling dynamo converters #2317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
139ad9b
8e89105
1ddd88f
767e8fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
| matmul, | ||
| normalization, | ||
| permutation, | ||
| pool, | ||
| reduce, | ||
| select, | ||
| shape, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| from typing import Optional, Sequence, Union | ||
|
|
||
| # @manual=//deeplearning/trt/python:py_tensorrt | ||
|
||
| import tensorrt as trt | ||
| from torch.fx.node import Target | ||
| from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple | ||
| from torch_tensorrt.fx.converters.converter_utils import ( | ||
| SourceIR, | ||
|
||
| has_dynamic_shape, | ||
| set_layer_name, | ||
| ) | ||
| from torch_tensorrt.fx.types import TRTNetwork, TRTTensor | ||
|
|
||
|
|
||
| def avg_poolNd( | ||
| network: TRTNetwork, | ||
| target: Union[Target, str], | ||
| source_ir: Optional[SourceIR], | ||
| name: str, | ||
| input: TRTTensor, | ||
| kernel_size: Sequence[int], | ||
| stride: Union[int, Sequence[int]], | ||
| padding: Union[int, Sequence[int]] = 0, | ||
| ceil_mode: bool = False, | ||
| count_include_pad: bool = True, | ||
| divisor_override: Optional[int] = None, | ||
| ) -> TRTTensor: | ||
| if has_dynamic_shape(input.shape): | ||
| assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling." | ||
|
|
||
| if ceil_mode is not False: | ||
| raise RuntimeError("ceil_mode is not yet supported!") | ||
|
|
||
| if divisor_override is not None: | ||
| raise RuntimeError("divisor_override is not yet supported!") | ||
|
|
||
| dim = len(kernel_size) | ||
|
|
||
| kernel_size = extend_attr_to_tuple(kernel_size, dim) | ||
|
|
||
| if stride is None: | ||
| stride = kernel_size | ||
| else: | ||
|
||
| stride = extend_attr_to_tuple(stride, dim) | ||
|
|
||
| padding = extend_attr_to_tuple(padding, dim) | ||
|
|
||
| # add average pooling layer | ||
| pool_layer = network.add_pooling_nd( | ||
| input=input, | ||
| type=trt.PoolingType.AVERAGE, | ||
| window_size=kernel_size, | ||
| ) | ||
|
|
||
| pool_layer.stride_nd = stride | ||
| pool_layer.padding_nd = padding | ||
| pool_layer.average_count_excludes_padding = not count_include_pad | ||
|
|
||
| set_layer_name(pool_layer, target, name, source_ir) | ||
| return pool_layer.get_output(0) | ||
|
|
||
|
|
||
| def max_poolNd( | ||
| network: TRTNetwork, | ||
| target: Union[Target, str], | ||
| source_ir: Optional[SourceIR], | ||
| name: str, | ||
| input: TRTTensor, | ||
| kernel_size: Sequence[int], | ||
| stride: Union[int, Sequence[int]], | ||
| padding: Union[int, Sequence[int]] = 0, | ||
| dilation: Union[int, Sequence[int]] = 1, | ||
| ceil_mode: bool = False, | ||
| ) -> TRTTensor: | ||
| if has_dynamic_shape(input.shape): | ||
| assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling." | ||
|
|
||
| if dilation != 1: | ||
| raise RuntimeError("dilation is not yet supported!") | ||
|
|
||
| if ceil_mode is not False: | ||
| raise RuntimeError("ceil_mode is not yet supported!") | ||
|
|
||
| dim = len(kernel_size) | ||
|
|
||
| kernel_size = extend_attr_to_tuple(kernel_size, dim) | ||
|
|
||
| if stride is None: | ||
| stride = kernel_size | ||
| else: | ||
|
||
| stride = extend_attr_to_tuple(stride, dim) | ||
|
|
||
| padding = extend_attr_to_tuple(padding, dim) | ||
|
|
||
| # add max pooling layer | ||
| pool_layer = network.add_pooling_nd( | ||
| input=input, | ||
| type=trt.PoolingType.MAX, | ||
| window_size=kernel_size, | ||
| ) | ||
|
|
||
| pool_layer.stride_nd = stride | ||
| pool_layer.padding_nd = padding | ||
|
|
||
| set_layer_name(pool_layer, target, name, source_ir) | ||
| return pool_layer.get_output(0) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could
max_pool1dsupport be added here as well? SchemaUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can add
torch.ops.aten.max_pool1d.defaultbut it won't be used. Even fortorch.nn.AvgPool1d, it still callstorch.ops.aten.avg_pool2d.default, as you can see in the test file: https://github.com/pytorch/TensorRT/pull/2317/files#diff-9fce39bc42c66d2866c41665779cab7da0a4d3fe54576925e2b66c17a1cf1ebfR20-R43But anyways, the 1d schema looks same as others, so I added here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for that - I plan to add a lowering pass which will lead us to that converter, so it will still be helpful.