From 3bf4b48ba27daa23e4ffaa09baa4be1d65701b9b Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 27 Nov 2023 18:01:47 -0800 Subject: [PATCH 1/2] add index_dtype_validator --- .../dynamo/conversion/aten_ops_converters.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 0a1674ff94..d46f0c7345 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -364,7 +364,14 @@ def aten_ops_sigmoid( ) -@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) +def index_dtype_validator(node: Node) -> bool: + index = node.args[1] + return all(ind.meta["val"].dtype == torch.int32 for ind in index if ind is not None) + + +@dynamo_tensorrt_converter( + torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator +) @enforce_tensor_types( { 0: (TRTTensor,), From 52195eb6292b01a9f38f353b2629a11f2af527dd Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 4 Dec 2023 16:58:32 -0800 Subject: [PATCH 2/2] fix bugs --- py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d46f0c7345..4ad805461d 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -366,7 +366,12 @@ def aten_ops_sigmoid( def index_dtype_validator(node: Node) -> bool: index = node.args[1] - return all(ind.meta["val"].dtype == torch.int32 for ind in index if ind is not None) + for ind in index: + if ind is not None: + val = ind.meta.get("val") + if val is not None and val.dtype != torch.int32: + return False + return True @dynamo_tensorrt_converter(