Bug Description
Running the vgg16_ptq example fails with the error: torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
To Reproduce
Copied the code from the site into a test.py. Instead of training the model directly, I used the default vgg16_bn checkpoint from PyTorch (here).
Changes from the original code are:
- The module sizes were slightly different, so I updated the values in the VGG module definition.
- The checkpoint loading code is slightly different since the checkpoint isn't nested anymore (rough sketch of the difference below).
I don't think these changes should affect anything.
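Roughly, the checkpoint-loading difference looks like this (the nested key name used by the original example is from memory, so it may not be exact):

# Original example: checkpoint written by the training script, with the state dict
# nested under a key (something like "model_state_dict" -- from memory, may differ):
# ckpt = torch.load(args.ckpt)
# model.load_state_dict(ckpt["model_state_dict"])

# This test.py: the torchvision vgg16_bn file is the state dict itself:
ckpt = torch.load(args.ckpt)
model.load_state_dict(ckpt)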
test.py
(pastebin link if easier)
import argparse
import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode
class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64,
        64,
        "pool",
        128,
        128,
        "pool",
        256,
        256,
        256,
        "pool",
        512,
        512,
        512,
        "pool",
        512,
        512,
        512,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)
PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)
PARSER.add_argument(
    "--quantize-type",
    default="int8",
    type=str,
    help="quantization type, currently supported int8 or fp8 for PTQ",
)
args = PARSER.parse_args()
model = vgg16(num_classes=1000, init_weights=False)
model = model.cuda()
ckpt = torch.load(args.ckpt)
weights = ckpt
model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()
training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
data = iter(training_dataloader)
images, _ = next(data)
crit = nn.CrossEntropyLoss()
def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()
    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))
if args.quantize_type == "int8":
    quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
    quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point
# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`
with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
        from torch.export._trace import _export

        exp_program = _export(model, (input_tensor,))
        if args.quantize_type == "int8":
            enabled_precisions = {torch.int8}
        elif args.quantize_type == "fp8":
            enabled_precisions = {torch.float8_e4m3fn}
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
            debug=True,
        )
        # You can also use torch compile path to compile the model with Torch-TensorRT:
        # trt_model = torch.compile(model, backend="tensorrt")

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
I then run (vgg16_bn-6c64b313.pth is the locally downloaded checkpoint):
python test.py --ckpt vgg16_bn-6c64b313.pth
and get the error:
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
from user code:
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
return self._dm_attribute_manager
Full trace
[WARNING | root ]: Supported flash-attn versions are >= 2.1.1, <= 2.6.3. Found flash-attn 2.7.4.post1.
[WARNING | torch_tensorrt.dynamo.conversion.converter_utils]: TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/27/2025-05:44:05] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
Inserted 86 quantizers
PTQ Loss: 0.18583 Acc: 0.00%
Loading extension modelopt_cuda_ext...
Traceback (most recent call last):
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 263, in __call__
self.call_reconstruct(value)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 90, in call_reconstruct
res = value.reconstruct(self)
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 358, in reconstruct
raise NotImplementedError
NotImplementedError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ubuntu/test.py", line 199, in <module>
exp_program = _export(model, (input_tensor,))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1990, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1255, in _strict_export
return _strict_export_lower_to_aten_ir(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/test.py", line 57, in forward
def forward(self, x):
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 786, in __getattr__
manager = self._get_dm_attribute_manager(use_default=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
self._return(inst)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
self.output.compile_subgraph(
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1120, in compile_subgraph
self.codegen_suffix(tx, stack_values, pass1)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1193, in codegen_suffix
cg.restore_stack(stack_values, value_from_source=not tx.export)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 82, in restore_stack
self.foreach(stack_values)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 293, in foreach
self(i)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 265, in __call__
unimplemented(f"reconstruct: {value}")
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
from user code:
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
return self._dm_attribute_manager
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Expected behavior
Should run correctly.
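Additional context: one thing I have not tried is the public export API in non-strict mode, which skips the torch._dynamo.export() trace where this reconstruct error is raised; I don't know whether it also avoids the FunctionalTensor error noted in the script. Rough, unverified sketch:

# Untested alternative to the private torch.export._trace._export call in test.py;
# strict=False uses the non-strict tracer instead of TorchDynamo.
with torch.no_grad():
    with export_torch_mode():
        exp_program = torch.export.export(model, (input_tensor,), strict=False)
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
        )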
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.6.0
- PyTorch Version (e.g. 1.0): 2.6.0
- CPU Architecture: x86
- OS (e.g., Linux): Ubuntu 22.04.5 LTS
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source): pip install torch_tensorrt
- Are you using local sources or building from archives:
- Python version: 3.11.10
- CUDA version: 550.127.05
- GPU models and configuration: Nvidia L4
- Any other relevant information: