Skip to content

cherry pick 3680: fix refit test bug #3687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def construct_refit_mapping_from_weight_name_map(
params[w.split(".")[-1]] = state_dict[w].cuda()
# Batch norm constant folding

scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
scale, shift = batch_norm_constant_folding(**params, eps=1e-5)
# Set scale to scale or shift to shift
engine_weight_map[engine_weight_name] = eval(
engine_weight_name.split(" ")[-1].lower()
Expand Down
64 changes: 64 additions & 0 deletions tests/py/dynamo/models/test_model_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,70 @@ def test_mapping():
torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
)
@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.refit,
"Refit feature is not supported in Python 3.13 or higher",
)
@unittest.skipIf(
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
@pytest.mark.unit
def test_conv_refit_with_weightmap():
class net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)

def forward(self, x):
return self.conv(x)

model = net().eval().to("cuda")
model2 = net().eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
min_block_size = 1
use_python_runtime = True

exp_program = torch.export.export(model, tuple(inputs))
exp_program2 = torch.export.export(model2, tuple(inputs))

trt_gm = torchtrt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
min_block_size=min_block_size,
immutable_weights=False,
)

new_trt_gm = refit_module_weights(
compiled_module=trt_gm,
new_weight_module=exp_program2,
arg_inputs=inputs,
use_weight_map_cache=True,
verify_output=True,
)

# Check the output
model2.to("cuda")
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
*inputs
)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assertions.assertTrue(
torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),
"Refit Result is not correct. Refit failed",
)
# Clean up model env

torch._dynamo.reset()


@unittest.skipIf(
not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime,
"TorchScript Frontend is not available",
Expand Down
4 changes: 1 addition & 3 deletions tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,7 @@ def test_resnet18_modify_attribute():
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)
mutable_module(*inputs)

mutable_module.conv1.weight = nn.Parameter(
torch.rand_like(mutable_module.conv1.weight)
)
mutable_module.fc.weight = nn.Parameter(torch.rand_like(mutable_module.fc.weight))
assertions.assertEqual(
mutable_module.refit_state.get_state(),
RefitFlag.UNKNOWN,
Expand Down
Loading