11from typing import Optional , cast
22import math
3+ import numpy as np
34
45from torch .fx .node import Target
56
@@ -25,12 +26,6 @@ def slice_op(
2526 stop : int ,
2627 step : int ,
2728) -> TRTTensor :
28- if not isinstance (input , TRTTensor ):
29- raise RuntimeError (
30- f"slice_tensor received input { input } that is not part "
31- "of the TensorRT region!"
32- )
33-
3429 ranks = len (input .shape ) + (1 if network .has_implicit_batch_dimension else 0 )
3530 dim = get_positive_dim (cast (int , dim ), ranks )
3631 dynamic_shape = has_dynamic_shape (input .shape )
@@ -49,6 +44,22 @@ def slice_op(
4944 if stop_int == 2 ** 63 - 1 :
5045 stop_int = input .shape [dim ]
5146 step_int = cast (int , step )
47+
48+ if isinstance (input , np .ndarray ):
49+ tensor_to_freeze = np .take (
50+ input , np .arange (start_int , stop_int , step_int ), axis = dim
51+ )
52+ # TODO: Fix naming for constant tensors
53+ frozen_trt_tensor = get_trt_tensor (network , tensor_to_freeze , name )
54+ return frozen_trt_tensor
55+
56+ if not isinstance (input , TRTTensor ):
57+ raise RuntimeError (
58+ f"slice_tensor received input { input } that is not part "
59+ "of the TensorRT region!"
60+ )
61+
62+ # TRT Input Formatting
5263 start = [0 ] * len (input .shape )
5364 start [dim ] = start_int
5465 stride = [1 ] * len (start )
0 commit comments