Status: Closed
Labels: bug (Something isn't working), component: converters (Issues re: Specific op converters)
Bug Description
When the requested size has a different rank than the input tensor, the converter for torch.ops.aten.expand fails because of this portion of the code:
TensorRT/py/torch_tensorrt/fx/converters/acc_ops_converters.py
Lines 2457 to 2475 in 8c62fca
@tensorrt_converter(acc_ops.expand)
def acc_ops_expand_tensor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    input_t = kwargs["input"]
    shape = list(kwargs["sizes"])
    input_val = get_trt_tensor(network, input_t, f"{name}_input")
    if network.has_implicit_batch_dimension:
        shape = shape[1:]
    ranks = len(input_val.shape)
    # TRT does not support different dimension size
    assert len(shape) == ranks
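One way the assertion could be relaxed, sketched here in plain Python under the assumption that the converter can reshape its input before expanding (for example with a TensorRT shuffle layer), is to prepend size-1 dimensions to the input shape until its rank matches the requested size. The helper name pad_input_shape is hypothetical and is not part of Torch-TensorRT.

```python
from typing import Sequence, Tuple


def pad_input_shape(input_shape: Sequence[int], target_rank: int) -> Tuple[int, ...]:
    """Prepend size-1 dimensions so the shape has target_rank dimensions.

    Hypothetical helper: PyTorch's expand treats missing leading dimensions
    as size 1, so reshaping the input to this padded shape would let the
    existing equal-rank logic in the converter proceed.
    """
    rank = len(input_shape)
    if target_rank < rank:
        # PyTorch also rejects a size list shorter than the tensor's rank.
        raise ValueError(
            f"expand size has {target_rank} dims but tensor has {rank}"
        )
    return (1,) * (target_rank - rank) + tuple(input_shape)


print(pad_input_shape((2, 2), 6))  # (1, 1, 1, 1, 2, 2)
```

In the converter, the padded shape would take the place of the equal-rank assertion: reshape input_val to pad_input_shape(input_val.shape, len(shape)), assuming its shape can be read as a sequence of ints, and then continue with the existing equal-rank logic.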
This does not match PyTorch behavior: calling .expand on a tensor does not require the expanded size to have the same rank as the original tensor, because new dimensions can be added at the front. See the documentation for torch.Tensor.expand.
import torch

x = torch.ones(2, 2)
y = x.expand([5, 5, 5, 5, -1, -1])
print(y.shape)
# torch.Size([5, 5, 5, 5, 2, 2])
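For reference, the output-shape rule PyTorch applies in this example can be sketched in plain Python. This is a simplified model written for this report (the function name expand_output_shape is hypothetical), not PyTorch's implementation: missing leading dimensions are treated as size 1, -1 keeps an existing dimension, and any other aligned size must either equal the input dimension or expand a size-1 dimension.

```python
from typing import Sequence, Tuple


def expand_output_shape(input_shape: Sequence[int], sizes: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape Tensor.expand(sizes) would produce (simplified model)."""
    extra = len(sizes) - len(input_shape)
    if extra < 0:
        raise ValueError("expand needs at least as many sizes as tensor dims")
    out = []
    for i, size in enumerate(sizes):
        if i < extra:
            # New leading dimension: -1 has no existing size to keep.
            if size == -1:
                raise ValueError(f"-1 is not allowed for new dimension {i}")
            out.append(size)
            continue
        dim = input_shape[i - extra]
        if size == -1:
            out.append(dim)  # keep the existing dimension
        elif dim == size or dim == 1:
            out.append(size)  # matching size, or broadcast a size-1 dim
        else:
            raise ValueError(f"cannot expand dim {i} of size {dim} to {size}")
    return tuple(out)


print(expand_output_shape((2, 2), [5, 5, 5, 5, -1, -1]))  # (5, 5, 5, 5, 2, 2)
```

A model like this could serve as a reference when testing the converter against inputs whose rank is lower than the requested size.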
The converter fails with this error:
File "~/TensorRT/py/torch_tensorrt/fx/converters/acc_ops_converters.py", line 2475, in acc_ops_expand_tensor
assert len(shape) == ranks
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
AssertionError: While executing %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%permute_3, [1, 512, 512]), kwargs =...
To Reproduce
See the code snippet above for the behavior the converter should reproduce.
Expected behavior
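The converter should succeed when the requested size has more dimensions than the input tensor, producing the same output shape as PyTorch.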
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 8c62fca
- PyTorch Version (e.g. 1.0): 2.1.0.dev20230803+cu121