Closed
Description
Bug Description
When using torch 2.7.0.dev, the node `torch.ops.aten._assert_tensor_metadata.default`
appears in the graph, and it is the culprit causing graph breaks.
To Reproduce
from __future__ import annotations
import os
import torch
import torch_tensorrt
# Set the CI_BUILD environment variable for this process.
# NOTE(review): this runs after `import torch_tensorrt`; presumably the flag is
# read at compile/run time rather than at import time — confirm.
os.environ["CI_BUILD"] = "1"
class MyModule(torch.nn.Module):
    """Minimal repro module mixing arithmetic with dtype casts.

    The forward pass interleaves elementwise ops with a cast to float32
    (``.float()``) and a cast back to float16 (``.half()``).
    NOTE(review): presumably these casts are what introduce the
    ``_assert_tensor_metadata`` node described in the report — confirm.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after add -> float32 -> square -> float16 -> subtract.

        Args:
            x: Input tensor.

        Returns:
            A float16 tensor with the same shape as ``x``. The last step is
            ``x - x``, so every finite element ends up as zero.
        """
        x = x + x
        x = x.float()  # upcast to float32
        x = x * x
        x = x.half()  # downcast to float16
        x = x - x
        return x
# Repro driver: compile the module with the Torch-TensorRT dynamo frontend and
# compare the compiled output with eager execution on the same input.
with torch.inference_mode():
    # Half-precision model and input on the GPU.
    model = MyModule().eval().cuda().half()
    inputs = [torch.randn(1, 3, 5, 7, dtype=torch.half, device="cuda")]

    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        debug=True,
        min_block_size=1,
        use_explicit_typing=True,
    )

    # Raises AssertionError if the compiled and eager results differ beyond
    # the given tolerances.
    torch.testing.assert_close(
        trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3
    )
    print("assert_close passed")
Log diff: https://www.diffchecker.com/qe8cYJAS/