
Closed the perf gap of resnet and enabled refit #3629

Status: Open. Wants to merge 2 commits into base: main.
19 changes: 15 additions & 4 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -78,8 +78,9 @@ def construct_refit_mapping(
compilation_settings=settings,
)
interpreter._construct_trt_network_def()
weight_refit_map: dict[str, torch.Tensor] = interpreter.ctx.weight_refit_map

return interpreter.ctx.weight_refit_map
return weight_refit_map


@needs_refit
@@ -90,7 +91,18 @@ def construct_refit_mapping_from_weight_name_map(
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
if sd_weight_name not in state_dict:
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
Collaborator:

We should abstract this, IMO. If there are any weight types that require constant folding in the converter, that folding should be associated with the converter. The refit system would then just iterate through all of these constant-fold operations. Ideally the converter can use the same implementation.

Collaborator (Author):

Currently BN is the only one. Do you think we should have a constant_fold function and have both refit and conversion call it?
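
For illustration only, here is a minimal sketch of what such a shared helper could look like, assuming a hypothetical fold_batch_norm function that both the batch_norm converter and the refit path would call; the name and placement are assumptions, not part of this PR:

from typing import Tuple

import torch


def fold_batch_norm(
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-7,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Fold batch-norm statistics into a per-channel scale and shift,
    # mirroring the math used in this diff (the eps default matches the refit path below).
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

Both construct_refit_mapping_from_weight_name_map and the converter could then share this single implementation instead of duplicating the folding math.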

# Batch Norm Layer
params = {}
for w in sd_weight_name:
params[w.split(".")[-1]] = state_dict[w].cuda()
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
shift = params["bias"] - params["running_mean"] * scale
# Set scale to scale or shift to shift
engine_weight_map[engine_weight_name] = eval(
engine_weight_name.split(" ")[-1].lower()
)
elif sd_weight_name not in state_dict:
# If weights is not in sd, we can leave it unchanged
continue
else:
@@ -300,7 +312,7 @@ def refit_module_weights(

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
new_gm, settings.debug, settings.torch_executed_ops
new_gm, settings.torch_executed_ops
)

if num_supported_ops == 0 or (
@@ -363,7 +375,6 @@ def refit_module_weights(

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
new_weight_module.module().to(CPU_DEVICE)
for name, new_submodule in new_partitioned_module.named_children():
# Refit each submodule
# Extract engine from the submodule
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -335,8 +335,8 @@ def to_trt_weights(
ctx: ConversionContext,
value: torch.Tensor,
name: str,
layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT"],
weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"],
weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"],
target: Optional[Union[Target, str]] = None,
source_ir: Optional[SourceIR] = None,
target_quantized_type: Optional[trt.DataType] = None,
@@ -362,8 +362,8 @@
)

# Weight Recording
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"]
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"]
assert (
layer_type_name in supported_layer_types
), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."
222 changes: 148 additions & 74 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -16,6 +16,7 @@
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_trt_weights,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -47,90 +48,163 @@ def batch_norm(

# Save the original output shape for later use
output_shape = input.shape
# We perform constant folding for batch norm when the weight, bias, running_mean, and running_var are all tensors.
# Batch norm operation can be fused into a single layer, which is more efficient than the original implementation.
# In this way, the batch norm layer will be fused with the Convolution layer and get a performance boost.
if all(
[
isinstance(weight, torch.Tensor),
isinstance(bias, torch.Tensor),
isinstance(running_mean, torch.Tensor),
isinstance(running_var, torch.Tensor),
]
):
if weight is None:
Contributor:

I think this check is redundant. If weight were None, the isinstance(weight, torch.Tensor) check above would already return False, so this branch would never be entered.
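
A standalone illustration of the point (not from the PR):

import torch

# None is not a torch.Tensor, so the isinstance guard above already filters it out.
assert not isinstance(None, torch.Tensor)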

weight = 1.0

if bias is None:
bias = 0.0

if running_mean is None:
running_mean = 0.0

if running_var is None:
running_var = 1.0
adjusted_scale = weight / torch.sqrt(running_var + eps)
adjusted_bias = bias - running_mean * adjusted_scale
power = torch.ones_like(adjusted_scale)
adjusted_scale = to_trt_weights(
ctx,
adjusted_scale,
name,
layer_type_name="SCALE",
weight_type_name="SCALE",
target=target,
source_ir=source_ir,
)
adjusted_bias = to_trt_weights(
ctx,
adjusted_bias,
name,
layer_type_name="SCALE",
weight_type_name="SHIFT",
target=target,
source_ir=source_ir,
)

# We name the weight here according to the state_dict name
weight = (
get_trt_tensor(ctx, 1.0, f"{name}_weight")
if weight is None
else get_trt_tensor(ctx, weight, f"{name}_weight")
)
bias = (
get_trt_tensor(ctx, 0.0, f"{name}_bias")
if bias is None
else get_trt_tensor(ctx, bias, f"{name}_bias")
)
running_mean = (
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_mean is None
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
)
running_var = (
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
if running_var is None
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
)
power = to_trt_weights(
ctx,
power,
name,
layer_type_name="SCALE",
weight_type_name="POWER",
target=target,
source_ir=source_ir,
)

# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
output_shape = input.shape
if len(input.shape) < 4:

# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)
new_shape = (
(input.shape[0], input.shape[1], 1, 1)
if len(input.shape) == 2
else (input.shape[0], input.shape[1], input.shape[2], 1)
)
input = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
)

# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)
layer = ctx.net.add_scale_nd(
input, trt.ScaleMode.CHANNEL, adjusted_bias, adjusted_scale, power, 1
)
set_layer_name(layer, target, name, source_ir)
output = layer.get_output(0)

# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)
else:

# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)
# We name the weight here according to the state_dict name
weight = (
get_trt_tensor(ctx, 1.0, f"{name}_weight")
if weight is None
else get_trt_tensor(ctx, weight, f"{name}_weight")
)
bias = (
get_trt_tensor(ctx, 0.0, f"{name}_bias")
if bias is None
else get_trt_tensor(ctx, bias, f"{name}_bias")
)
running_mean = (
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
if running_mean is None
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
)
running_var = (
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
if running_var is None
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
)

# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)
# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension
# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
)

scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)
# sqrt_adjusted_var = sqrt(adjusted_var)
sqrt_adjusted_var = impl.unary.sqrt(
ctx, target, source_ir, f"{name}_sqrt", adjusted_var
)

# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)
# scale = weight / sqrt_adjusted_var
scale = impl.elementwise.div(
ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
)

# scaled_running_mean = running_mean * scale
scaled_running_mean = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
)

# bias_adjusted = bias - scaled_running_mean
bias_adjusted = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
)

# Reshape scale and bias_adjusted to match input shape for broadcasting
expanded_shape = [1] * len(output_shape)
expanded_shape[1] = output_shape[1] # Set channel dimension

scale_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_scale",
scale,
tuple(expanded_shape),
)
bias_adjusted_reshape = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_bias",
bias_adjusted,
tuple(expanded_shape),
)

# Apply the scale and bias to the input
scaled_input = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
)
output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_output",
scaled_input,
bias_adjusted_reshape,
)

# For BatchNorm1d, reshape output back to original shape if necessary
if len(output_shape) < 4:
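
As a standalone sanity check of the folding math used in this PR (a sketch with arbitrary shapes and eps, not part of the diff), the per-channel scale and shift should reproduce eval-mode batch norm:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 8
x = torch.randn(4, C, 16, 16)
weight, bias = torch.randn(C), torch.randn(C)
running_mean, running_var = torch.randn(C), torch.rand(C) + 0.5
eps = 1e-5

# Folded parameters: y = scale * x + shift, applied per channel,
# matching adjusted_scale / adjusted_bias in the batch_norm converter.
scale = weight / torch.sqrt(running_var + eps)
shift = bias - running_mean * scale
y_folded = x * scale.view(1, C, 1, 1) + shift.view(1, C, 1, 1)

# Reference: PyTorch batch norm in inference mode.
y_ref = F.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)

assert torch.allclose(y_folded, y_ref, atol=1e-5)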