🐛 [Bug] [weight-stripped engine] doesn't work for TorchTensorRTModule #3217

Closed
@zewenli98

Description

Bug Description

PR #3167 adds support for weight-stripped engines. This works for PythonTorchTensorRTModule but not for TorchTensorRTModule.

I observed the issue in the test:

def test_two_TRTRuntime_in_refitting(self):
    pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )
    inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
    pyt_results = pyt_model(*inputs)
    for i in range(2):
        if i == 0:
            use_python_runtime = True
        else:
            use_python_runtime = False
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            debug=False,
            min_block_size=1,
            strip_engine_weights=True,
            refit_identical_engine_weights=False,
        )
        output = trt_gm(*inputs)
        assertions.assertEqual(output.sum(), 0, msg="results should be all zeros")
        refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
        refitted_output = refitted_trt_gm(*inputs)
        cos_sim = cosine_similarity(pyt_results, refitted_output)
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"{'PythonTorchTensorRTModule' if use_python_runtime else 'TorchTensorRTModule'} outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
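For reference, the behavior the test expects from each runtime can be modeled without TensorRT at all: a weight-stripped engine keeps its structure but produces zeros (the `output.sum() == 0` check), and after refitting it should produce the same results as the original model. The sketch below is a hypothetical, framework-free toy (`ToyStrippedLinear` is not a Torch-TensorRT class) that only illustrates this expected before/after behavior; it says nothing about how either runtime actually stores or refits weights.

```python
from typing import List


class ToyStrippedLinear:
    """Hypothetical toy layer: shape is known, but weights start stripped (all zeros)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        # Weight-stripped: same structure as the real layer, values are zero.
        self.weight = [[0.0] * n_in for _ in range(n_out)]

    def refit(self, weight: List[List[float]]) -> None:
        # Copy reference weights into the existing rows in place, so the
        # forward pass below reads the refitted values.
        for row, src in zip(self.weight, weight):
            row[:] = src

    def __call__(self, x: List[float]) -> List[float]:
        # Plain matrix-vector product.
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]


reference_weight = [[1.0, 2.0], [3.0, 4.0]]
layer = ToyStrippedLinear(2, 2)
x = [1.0, 1.0]
print(layer(x))  # before refit: [0.0, 0.0]
layer.refit(reference_weight)
print(layer(x))  # after refit: [3.0, 7.0]
```

In these terms, TorchTensorRTModule passes the "before refit" check but still returns the "before refit" result after `refit_module_weights`.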

The CI run fails with the following error:

FAILED models/test_weight_stripped_engine.py::TestWeightStrippedEngine::test_two_TRTRuntime_in_refitting - AssertionError: False is not true : TorchTensorRTModule outputs don't match with the original model. Cosine sim score: 0.0 Threshold: 0.99

When I printed refitted_output for TorchTensorRTModule, it was all zeros, which suggests that refitting did not take effect for this runtime.
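The 0.0 score in the CI log is consistent with an all-zero output: the dot product with a zero vector is 0, so cosine similarity (with the usual epsilon guard on the norms) comes out as exactly 0.0 rather than a small mismatch. The sketch below is a minimal pure-Python version written for illustration; it assumes the common definition with each norm clamped to `eps`, and the `cosine_similarity` helper used by the test suite may differ in details.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """Cosine similarity of two flattened vectors.

    Each norm is clamped to eps, so an all-zero input yields 0.0
    instead of raising ZeroDivisionError.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = max(math.sqrt(sum(x * x for x in a)), eps)
    norm_b = max(math.sqrt(sum(y * y for y in b)), eps)
    return dot / (norm_a * norm_b)


reference = [0.5, -1.0, 2.0]
zeros = [0.0, 0.0, 0.0]
print(cosine_similarity(reference, zeros))      # 0.0, like the failing CI run
print(cosine_similarity(reference, reference))  # close to 1.0
```

So a score of exactly 0.0 points to "refit produced no weights" rather than "refit produced slightly wrong weights".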

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests