1
1
import math
2
- from typing import Optional , Sequence , Union
2
+ from typing import Dict , Optional , Sequence , Union
3
3
4
4
import tensorrt as trt
5
5
import torch_tensorrt .dynamo .conversion .impl as impl
@@ -128,6 +128,8 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
128
128
out_dim = output_size if isinstance (output_size , int ) else output_size [0 ]
129
129
output_list = []
130
130
131
+ # cache {input index: slice} so an index that appears in more than one pooling window is sliced only once
132
+ idx_slice_map : Dict [int , TRTTensor ] = {}
131
133
# iterate over each output dimension
132
134
for i in range (out_dim ):
133
135
# calculate the start and end index of each pooling window
@@ -137,17 +139,22 @@ def end_index(idx: int, out_dim: int, in_dim: int) -> int:
137
139
# slice the input tensor at each index from start (inclusive) to end (exclusive); together the slices form the window to be pooled
138
140
slices = []
139
141
for j in range (start , end ):
140
- slice = impl .select .select (
141
- ctx , target , source_ir , f"{ name } _select_{ i } _{ j } " , input , - 1 , j
142
- )
143
- slice = impl .shuffle .reshape (
144
- ctx ,
145
- target ,
146
- source_ir ,
147
- f"{ name } _reshape_{ i } _{ j } " ,
148
- slice ,
149
- (* slice .shape , 1 ),
150
- )
142
+ if j in idx_slice_map :
143
+ slice = idx_slice_map [j ]
144
+ else :
145
+ slice = impl .select .select (
146
+ ctx , target , source_ir , f"{ name } _select_{ j } " , input , - 1 , j
147
+ )
148
+ slice = impl .shuffle .reshape (
149
+ ctx ,
150
+ target ,
151
+ source_ir ,
152
+ f"{ name } _reshape_{ i } _{ j } " ,
153
+ slice ,
154
+ (* slice .shape , 1 ),
155
+ )
156
+ idx_slice_map [j ] = slice
157
+
151
158
slices .append (slice )
152
159
153
160
slices = impl .cat .cat (
0 commit comments