Skip to content

🐛 [Bug] vgg16_ptq doesn't run correctly. #3419

Closed
@dragoneye-alex

Description

@dragoneye-alex

Bug Description

Running the vgg16_ptq example doesn't work and fails with error: torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

To Reproduce

Copied the code from the site into a test.py. Instead of training the model directly, used the default vgg16_bn checkpoint from Pytorch (here).

Changes from original code are:

  1. The module sizes were slightly different so I updated the values in the VGG module definition.
  2. The checkpoint loading code is slightly different since it isn't nested anymore.

I don't think these changes should have affected anything.

test.py (pastebin link if easier)

import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64,
        64,
        "pool",
        128,
        128,
        "pool",
        256,
        256,
        256,
        "pool",
        512,
        512,
        512,
        "pool",
        512,
        512,
        512,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)


PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)
PARSER.add_argument(
    "--quantize-type",
    default="int8",
    type=str,
    help="quantization type, currently supported int8 or fp8 for PTQ",
)
args = PARSER.parse_args()

model = vgg16(num_classes=1000, init_weights=False)
model = model.cuda()


ckpt = torch.load(args.ckpt)
weights = ckpt


model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))

if args.quantize_type == "int8":
    quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
    quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point


# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
        from torch.export._trace import _export

        exp_program = _export(model, (input_tensor,))
        if args.quantize_type == "int8":
            enabled_precisions = {torch.int8}
        elif args.quantize_type == "fp8":
            enabled_precisions = {torch.float8_e4m3fn}
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
            debug=True,
        )
        # You can also use torch compile path to compile the model with Torch-TensorRT:
        # trt_model = torch.compile(model, backend="tensorrt")

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

I then run (vgg16_bn-6c64b313.pth is the locally downloaded checkpoint):

python test.py --ckpt vgg16_bn-6c64b313.pth

And get error:

torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

from user code:
   File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
    return self._dm_attribute_manager

Full trace

[WARNING  | root               ]: Supported flash-attn versions are >= 2.1.1, <= 2.6.3. Found flash-attn 2.7.4.post1.
[WARNING  | torch_tensorrt.dynamo.conversion.converter_utils]: TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/27/2025-05:44:05] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
Inserted 86 quantizers
PTQ Loss: 0.18583 Acc: 0.00%
Loading extension modelopt_cuda_ext...
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 263, in __call__
    self.call_reconstruct(value)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 90, in call_reconstruct
    res = value.reconstruct(self)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 358, in reconstruct
    raise NotImplementedError
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/test.py", line 199, in <module>
    exp_program = _export(model, (input_tensor,))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1990, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1255, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/test.py", line 57, in forward
    def forward(self, x):
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 786, in __getattr__
    manager = self._get_dm_attribute_manager(use_default=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
    self._return(inst)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
    self.output.compile_subgraph(
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1120, in compile_subgraph
    self.codegen_suffix(tx, stack_values, pass1)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1193, in codegen_suffix
    cg.restore_stack(stack_values, value_from_source=not tx.export)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 82, in restore_stack
    self.foreach(stack_values)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 293, in foreach
    self(i)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 265, in __call__
    unimplemented(f"reconstruct: {value}")
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

from user code:
   File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
    return self._dm_attribute_manager

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Expected behavior

Should run correctly.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0
  • PyTorch Version (e.g. 1.0): 2.6.0
  • CPU Architecture: x86
  • OS (e.g., Linux): Ubuntu 22.04.5 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): pip install torch_tensorrt
  • Are you using local sources or building from archives:
  • Python version: 3.11.10
  • CUDA version: 550.127.05
  • GPU models and configuration: Nvidia L4
  • Any other relevant information:

Metadata

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions