
🐛 [Bug] Part of the weights are placed to CPU during compilation #3450


Open
cehongwang opened this issue Mar 25, 2025 · 3 comments
Labels
bug Something isn't working

Comments

cehongwang (Collaborator) commented Mar 25, 2025:

Bug Description

When compiling BERT, a device mismatch error occurs. It appears to be caused by weights being moved to the CPU during compilation.

To Reproduce

Steps to reproduce the behavior:

Run this script:

```python
import torch
import torch_tensorrt as torchtrt
from transformers import BertModel

inputs = [
    torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"),
]
model = BertModel.from_pretrained("bert-base-uncased").eval().to("cuda")
enabled_precisions = {torch.float}
debug = True
min_block_size = 1
use_python_runtime = False

exp_program = torch.export.export(model, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=use_python_runtime,
    enabled_precisions=enabled_precisions,
    debug=debug,
    min_block_size=min_block_size,
    immutable_weights=False,
)
```


Expected behavior

Compilation should complete without a device mismatch, with all weights remaining on the GPU.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): main branch
  • PyTorch Version (e.g. 1.0): nightly
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context

cehongwang added the bug label on Mar 25, 2025
cehongwang (Collaborator, Author) commented:


If you add a device check inside the torchtrt.dynamo.compile function, the result is:

```
{device(type='cuda', index=0)}
{device(type='cuda', index=0), device(type='cpu')}
```
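Such a device check could be sketched as a small helper that collects the distinct devices across a module's parameters; the helper name and its structure are illustrative assumptions, not part of Torch-TensorRT:

```python
def collect_devices(named_tensors):
    """Return the set of distinct devices seen across (name, tensor) pairs.

    Works on anything whose items expose a ``.device`` attribute, e.g. the
    iterator returned by ``named_parameters()`` on a torch module.
    """
    return {t.device for _, t in named_tensors}


# With the exported program above one could run, e.g.:
#   devices = collect_devices(exp_program.module().named_parameters())
# Seeing a CPU device alongside cuda:0 in the returned set would
# reproduce the mismatch reported here.
```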

HolyWu (Contributor) commented Mar 25, 2025:

It seems the culprit is immutable_weights=False, since immutable_weights=True (the default) compiles fine.

cehongwang (Collaborator, Author) commented:

It's because when immutable_weights=False, weights on the CPU raise an error. When immutable_weights=True, the weights are still on the CPU, but no error is raised.
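Until the root cause is fixed, a possible workaround is to move any stray weights back to the GPU after export and before compilation. This is only a sketch under the assumption that the exported module behaves like a regular nn.Module; the function name is hypothetical:

```python
def ensure_on_device(module, target="cuda:0"):
    """Move a module's parameters to ``target`` in place and return the
    names of the parameters that were previously on some other device."""
    stray = [name for name, p in module.named_parameters()
             if str(p.device) != target]
    module.to(target)  # nn.Module.to moves parameters/buffers in place
    return stray


# Hypothetical usage before calling torchtrt.dynamo.compile:
#   stray = ensure_on_device(exp_program.module())
#   if stray:
#       print("moved back to GPU:", stray)
```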
