Skip to content

Commit 2595b2a

Browse files
committed
computation optimization
1 parent 36c927a commit 2595b2a

File tree

1 file changed

+19
-12
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+19
-12
lines changed

py/torch_tensorrt/dynamo/conversion/impl/pool.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional, Sequence, Union
2+
from typing import Dict, Optional, Sequence, Union
33

44
import tensorrt as trt
55
import torch_tensorrt.dynamo.conversion.impl as impl
@@ -128,6 +128,8 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
128128
out_dim = output_size if isinstance(output_size, int) else output_size[0]
129129
output_list = []
130130

131+
# store {index: slice} for reducing repeated slice ops
132+
idx_slice_map: Dict[int, TRTTensor] = {}
131133
# iterate over each output dimension
132134
for i in range(out_dim):
133135
# calculate the start and end index of each pooling window
@@ -137,17 +139,22 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
137139
# slice the input tensor from start to end index, the result of which is the window waiting for pooling
138140
slices = []
139141
for j in range(start, end):
140-
slice = impl.select.select(
141-
ctx, target, source_ir, f"{name}_select_{i}_{j}", input, -1, j
142-
)
143-
slice = impl.shuffle.reshape(
144-
ctx,
145-
target,
146-
source_ir,
147-
f"{name}_reshape_{i}_{j}",
148-
slice,
149-
(*slice.shape, 1),
150-
)
142+
if j in idx_slice_map:
143+
slice = idx_slice_map[j]
144+
else:
145+
slice = impl.select.select(
146+
ctx, target, source_ir, f"{name}_select_{j}", input, -1, j
147+
)
148+
slice = impl.shuffle.reshape(
149+
ctx,
150+
target,
151+
source_ir,
152+
f"{name}_reshape_{i}_{j}",
153+
slice,
154+
(*slice.shape, 1),
155+
)
156+
idx_slice_map[j] = slice
157+
151158
slices.append(slice)
152159

153160
slices = impl.cat.cat(

0 commit comments

Comments
 (0)