diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0a1674ff94..4ad805461d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -364,7 +364,19 @@ def aten_ops_sigmoid( ) -@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) +def index_dtype_validator(node: Node) -> bool: + index = node.args[1] + for ind in index: + if ind is not None: + val = ind.meta.get("val") + if val is not None and val.dtype != torch.int32: + return False + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator +) @enforce_tensor_types( { 0: (TRTTensor,),