🐛 [Bug] Graph breaks in torch 2.7.0.dev but not in torch 2.6.0 #3375

Closed
@HolyWu

Description

Bug Description

When exporting with torch 2.7.0.dev, the node torch.ops.aten._assert_tensor_metadata.default appears in the graph and is the culprit causing graph breaks. The node does not appear with torch 2.6.0, where the same model compiles without breaks.
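As a hypothetical workaround sketch (not necessarily Torch-TensorRT's actual fix), aten._assert_tensor_metadata is a pure runtime metadata check with no data-flow effect, so a lowering pass could simply erase such nodes from the FX graph before conversion. The WithAssert toy module below is an assumption standing in for a graph where the exporter inserted the assert around a dtype cast:

```python
import torch
from torch.fx import GraphModule, symbolic_trace


class WithAssert(torch.nn.Module):
    # Toy stand-in for an exported graph containing an
    # _assert_tensor_metadata node (hypothetical, for illustration).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        torch.ops.aten._assert_tensor_metadata.default(x, dtype=torch.float32)
        return x * x


def remove_assert_tensor_metadata(gm: GraphModule) -> GraphModule:
    """Erase every aten._assert_tensor_metadata call from the graph.

    The op returns nothing and has no users, so erase_node is safe.
    """
    for node in list(gm.graph.nodes):
        if (
            node.op == "call_function"
            and node.target is torch.ops.aten._assert_tensor_metadata.default
        ):
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = remove_assert_tensor_metadata(symbolic_trace(WithAssert()))
```

After the pass, the recompiled module computes the same result as the original forward, just without the assert node in its graph.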

To Reproduce

from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + x
        x = x.float()
        x = x * x
        x = x.half()
        x = x - x
        return x


with torch.inference_mode():
    model = MyModule().eval().cuda().half()

    inputs = [torch.randn(1, 3, 5, 7, dtype=torch.half, device="cuda")]

    trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1, use_explicit_typing=True)

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
    print("assert_close passed")

Log diff: https://www.diffchecker.com/qe8cYJAS/

Metadata

Labels: bug (Something isn't working)
