🐛 [Bug] aten.expand fails when rank disagrees with tensor shape #2183

Closed
gs-olive opened this issue Aug 8, 2023 · 0 comments · Fixed by #2234
Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters)


gs-olive commented Aug 8, 2023

Bug Description

When the requested expand rank and the input tensor rank disagree, the torch.ops.aten.expand operator fails due to this portion of the converter code:

@tensorrt_converter(acc_ops.expand)
def acc_ops_expand_tensor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    shape = list(kwargs["sizes"])
    input_val = get_trt_tensor(network, input_t, f"{name}_input")
    if network.has_implicit_batch_dimension:
        shape = shape[1:]
    ranks = len(input_val.shape)
    # TRT does not support different dimension size
    assert len(shape) == ranks
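One way to reconcile the ranks before the assertion (a sketch only, not the actual patch from #2234) is to left-pad the input shape with singleton dimensions, mirroring torch.Tensor.expand semantics where new leading dimensions are prepended and -1 keeps the existing size. The helper names below are hypothetical:

```python
from typing import List, Sequence


def pad_shape_to_rank(input_shape: Sequence[int], target_rank: int) -> List[int]:
    # Prepend singleton dimensions so the input rank matches the
    # requested expand rank, as torch.Tensor.expand does.
    diff = target_rank - len(input_shape)
    assert diff >= 0, "expand cannot reduce the rank of a tensor"
    return [1] * diff + list(input_shape)


def resolve_expand_sizes(input_shape: Sequence[int], sizes: Sequence[int]) -> List[int]:
    # Replace each -1 entry with the corresponding (padded) input dimension.
    padded = pad_shape_to_rank(input_shape, len(sizes))
    return [dim if size == -1 else size for dim, size in zip(padded, sizes)]


# Mirrors the torch example from the report:
# a (2, 2) tensor expanded to [5, 5, 5, 5, -1, -1] yields (5, 5, 5, 5, 2, 2).
print(resolve_expand_sizes([2, 2], [5, 5, 5, 5, -1, -1]))
```

After such padding, the converter's rank check would hold by construction, and the remaining expansion could proceed as before.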

This does not agree with Torch behavior: calling .expand on a Tensor does not require the expanded size to have the same rank as the original Tensor. See the torch.Tensor.expand documentation.

import torch
x = torch.ones(2, 2)
y = x.expand([5, 5, 5, 5, -1, -1])
print(y.shape)
>>> torch.Size([5, 5, 5, 5, 2, 2])

This is the error message in the converter:

File "~/TensorRT/py/torch_tensorrt/fx/converters/acc_ops_converters.py", line 2475, in acc_ops_expand_tensor
     assert len(shape) == ranks
 torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
 AssertionError: While executing %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_3, [1, 512, 512]), kwargs =...

To Reproduce

See the code snippet above for the behavior expected from the converter.

Expected behavior

The converter should succeed in this case.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 8c62fca
  • PyTorch Version (e.g. 1.0): 2.1.0.dev20230803+cu121