Closed
Labels
bug: Something isn't working
Description
Bug Description
PR #3167 adds support for weight-stripped engines. It works for PythonTorchTensorRTModule but not for TorchTensorRTModule.
I observed the issue in the test:
TensorRT/tests/py/dynamo/models/test_weight_stripped_engine.py
Lines 487 to 523 in 76bdf5e
def test_two_TRTRuntime_in_refitting(self):
    pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )
    inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
    pyt_results = pyt_model(*inputs)

    for i in range(2):
        if i == 0:
            use_python_runtime = True
        else:
            use_python_runtime = False

        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            debug=False,
            min_block_size=1,
            strip_engine_weights=True,
            refit_identical_engine_weights=False,
        )
        output = trt_gm(*inputs)
        assertions.assertEqual(output.sum(), 0, msg="results should be all zeros")

        refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
        refitted_output = refitted_trt_gm(*inputs)
        cos_sim = cosine_similarity(pyt_results, refitted_output)
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"{'PythonTorchTensorRTModule' if use_python_runtime else 'TorchTensorRTModule'} outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
The CI test reports the error:
FAILED models/test_weight_stripped_engine.py::TestWeightStrippedEngine::test_two_TRTRuntime_in_refitting - AssertionError: False is not true : TorchTensorRTModule outputs don't match with the original model. Cosine sim score: 0.0 Threshold: 0.99
I printed refitted_output when using TorchTensorRTModule. It is all zeros, so the refit appears to have had no effect on the engine.
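A score of exactly 0.0 is what an all-zero output would produce, which fits the observation above. The sketch below is a minimal, standard-library illustration and not the repository's own `cosine_similarity` helper. The names `cosine_sim` and `diagnose` and the `eps` guard are assumptions made for this example. It separates the "weights never refitted" case from an ordinary numerical mismatch.

```python
import math


def cosine_sim(a, b, eps=1e-8):
    """Cosine similarity of two flat sequences of numbers.

    eps is a hypothetical guard so a zero-norm input yields 0.0
    instead of raising ZeroDivisionError.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def diagnose(reference, candidate, threshold=0.99):
    """Classify a candidate output against a reference output."""
    # An all-zero output is the signature of an engine that still
    # holds stripped (zeroed) weights.
    if all(v == 0 for v in candidate):
        return "all-zero output: weights were likely never refitted"
    if cosine_sim(reference, candidate) > threshold:
        return "match"
    return "mismatch"
```

For real tensors, the same check could be applied to flattened outputs, for example `refitted_output.flatten().tolist()`, at the cost of a device-to-host copy.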