Skip to content

Upgrade perf_run script to support TRT 10 and fix some issues #3650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3579,3 +3579,23 @@ def aten_ops_nonzero(
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that registering a converter for OpOverloadPacket has no effect.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found for some models in fp16, for example, bert, registering a linear op can reduce latency. It seems no effect for fp32 though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you double check in your case if the linear converter affects perf in fp16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant that you should only register torch.ops.aten.linear.default (an OpOverload) and remove the second line because registering torch.ops.aten.linear (an OpOverloadPacket) was redundant.

def aten_ops_linear(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.linear.linear(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
weight=args[1],
bias=args_bounds_check(args, 2, None),
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
embedding,
full,
grid,
linear,
matmul,
nccl_ops,
normalization,
Expand Down
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.fx.types import TRTTensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch_tensorrt.dynamo.types



def linear(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
) -> TRTTensor:
# Process weight terms
if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
raise RuntimeError(
f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
)
elif isinstance(weight, (torch.Tensor, np.ndarray)):
weight = get_trt_tensor(ctx, weight, f"{name}_weight")

# Process bias terms
if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
raise RuntimeError(
f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
)
elif isinstance(bias, (torch.Tensor, np.ndarray)):
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

# add IMatrixMultiplyLayer
out = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name,
input,
weight,
input_matrix_op=trt.MatrixOperation.NONE,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)

if bias is not None:
# add bias
out = impl.elementwise.add(ctx, target, source_ir, name, out, bias)

return out
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
aten.upsample_bilinear2d.vec,
aten.upsample_trilinear3d.vec,
aten.upsample_bicubic2d.vec,
aten.linear.default,
}


Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def constant_fold(
# The constants are created on CPU to save GPU memory for TensorRT compilation.
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
for node, constant in cf.node_replacements.items():
if node.target == torch.ops.aten.embedding.default:
continue
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant, requires_grad=False)
)
Expand Down Expand Up @@ -103,7 +105,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.quantization_ops: Set[torch._ops.OpOverload] = set()
try:
# modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
import modelopt.torch.quantization as mtq
import modelopt.torch.quantization as mtq # noqa: F401

assert torch.ops.tensorrt.quantize_op.default
self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
Expand Down
4 changes: 2 additions & 2 deletions tools/perf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ This is a comprehensive Python benchmark suite to run perf runs using different
5. TensorRT


Note: Please note that for ONNX models, user can convert the ONNX model to TensorRT serialized engine and then use this package.

## Prerequisite

Benchmark scripts depends on following Python packages in addition to requirements.txt packages
Expand Down Expand Up @@ -47,13 +45,15 @@ Here are the list of `CompileSpec` options that can be provided directly to comp
* `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module).
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
* `--onnx` : ONNX model file which helps bypass the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX will be directly converted to TRT engine
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
* `--batch_size` : Batch size
* `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
* `--device` : Device ID
* `--truncate` : Truncate long and double weights in the network in Torch-TensorRT
* `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine.
* `--report` : Path of the output file where performance summary is written.
* `--optimization_level` : Builder optimization level for TensorRT (from 1 to 5, 5 is the highest optimization).

Eg:

Expand Down
158 changes: 108 additions & 50 deletions tools/perf/perf_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
compile_settings = {
"inputs": input_tensors,
"enabled_precisions": {precision_to_dtype(precision)},
"truncate_long_and_double": params.get("truncate", False),
"use_python_runtime": params.get("use_python_runtime", False),
"truncate_double": params.get("truncate", False),
}

if precision == "int8":
Expand Down Expand Up @@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
ir="dynamo",
enabled_precisions={precision_to_dtype(precision)},
min_block_size=params.get("min_block_size", 1),
debug=False,
truncate_long_and_double=params.get("truncate", False),
truncate_double=params.get("truncate", False),
immutable_weights=params.get("immutable_weights", True),
strip_engine_weights=params.get("strip_engine_weights", False),
refit_identical_engine_weights=params.get(
Expand All @@ -284,6 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
cache_built_engines=params.get("cache_built_engines", False),
reuse_cached_engines=params.get("reuse_cached_engines", False),
use_python_runtime=params.get("use_python_runtime", False),
optimization_level=params.get("optimization_level", 3),
)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
Expand Down Expand Up @@ -437,61 +436,104 @@ def run_tensorrt(
precision,
batch_size=1,
):
# Export an ONNX model and convert to TRT
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse_from_file("./tmp.onnx")
if not success:
raise ValueError("ONNX conversion failed")

config = builder.create_builder_config()
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
start_compile = timeit.default_timer()
serialized_engine = builder.build_serialized_network(network, config)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
compile_time_s = 0
if params["is_trt_engine"]:
serialized_engine = model
else:
if params["onnx"]:
onnx_path = params["onnx"]
else:
onnx_path = "./onnx-trt.onnx"
torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse_from_file(onnx_path)
if not success:
raise ValueError("ONNX conversion failed")

config = builder.create_builder_config()
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
config.builder_optimization_level = params.get("optimization_level", 3)
start_compile = timeit.default_timer()
serialized_engine = builder.build_serialized_network(network, config)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
# Deserialize the TensorRT engine
with trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(serialized_engine)

print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
iters = params.get("iterations", 20)

# Compiling the bindings
bindings = engine.num_bindings * [None]
k = 0
for idx, _ in enumerate(bindings):
dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
shape = tuple(engine.get_binding_shape(idx))
device = torch_device_from_trt(engine.get_location(idx))
if not engine.binding_is_input(idx):
# Output bindings
output = torch.empty(size=shape, dtype=dtype, device=device)
bindings[idx] = output.data_ptr()
else:
# Input bindings
bindings[idx] = input_tensors[k].data_ptr()
k += 1
start_time = timeit.default_timer()
# Get I/O tensor information using TensorRT 10 API
input_names = []
output_names = []
input_dtypes = []
output_dtypes = []
input_shapes = []
output_shapes = []

for i in range(engine.num_io_tensors):
tensor_name = engine.get_tensor_name(i)
tensor_mode = engine.get_tensor_mode(tensor_name)
tensor_dtype = engine.get_tensor_dtype(tensor_name)
tensor_shape = engine.get_tensor_shape(tensor_name)

if tensor_mode == trt.TensorIOMode.INPUT:
input_names.append(tensor_name)
input_dtypes.append(torch_dtype_from_trt(tensor_dtype))
input_shapes.append(tuple(tensor_shape))
else: # trt.TensorIOMode.OUTPUT
output_names.append(tensor_name)
output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
output_shapes.append(tuple(tensor_shape))

# Create output tensors
output_tensors = []
for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
output = torch.empty(size=shape, dtype=dtype, device="cuda")
output_tensors.append(output)

timings = []
with engine.create_execution_context() as context:
# Set input tensor addresses
for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
context.set_tensor_address(input_name, input_tensor.data_ptr())

# Set output tensor addresses
for output_name, output_tensor in zip(output_names, output_tensors):
context.set_tensor_address(output_name, output_tensor.data_ptr())

# Create a dedicated stream for TensorRT execution
dedicated_stream = torch.cuda.Stream()
current_stream = torch.cuda.current_stream()

# Warm up
for i in range(WARMUP_ITER):
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
# Wait for current stream to finish
dedicated_stream.wait_stream(current_stream)
context.execute_async_v3(dedicated_stream.cuda_stream)
# Wait for TensorRT stream to finish
current_stream.wait_stream(dedicated_stream)
torch.cuda.synchronize()

# Performance measurement
for i in range(iters):
start_time = timeit.default_timer()
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
# Wait for current stream to finish
dedicated_stream.wait_stream(current_stream)
context.execute_async_v3(dedicated_stream.cuda_stream)
# Wait for TensorRT stream to finish
current_stream.wait_stream(dedicated_stream)
torch.cuda.synchronize()
end_time = timeit.default_timer()
meas_time = end_time - start_time
timings.append(meas_time)
infer_time = end_time - start_time
timings.append(infer_time)

recordStats("TensorRT", timings, precision, batch_size, compile_time_s)

Expand All @@ -504,7 +546,6 @@ def run(
params,
precision,
batch_size=1,
is_trt_engine=False,
model_torch=None,
):
for backend in backends:
Expand All @@ -523,7 +564,7 @@ def run(
print("int8 precision expects calibration cache file for inference")
return False

if (model is None) and (backend in ("tensorrt", "ts_trt", "all")):
if (model is None) and (backend in ("ts_trt", "all")):
warnings.warn(
f"Requested backend {backend} without specifying a TorchScript Model, "
+ "skipping this backend"
Expand All @@ -547,11 +588,10 @@ def run(
batch_size,
)
run_tensorrt(
model,
model_torch,
input_tensors,
params,
precision,
is_trt_engine,
batch_size,
)
run_dynamo(model_torch, input_tensors, params, precision, batch_size)
Expand Down Expand Up @@ -604,6 +644,12 @@ def run(
default="",
help="Name of torch model file",
)
arg_parser.add_argument(
"--onnx",
type=str,
default="",
help="ONNX model file which helps bypass the step of exporting ONNX from torchscript model. If this argument is provided, the ONNX will be directly converted to TRT engine",
)
arg_parser.add_argument(
"--inputs",
type=str,
Expand Down Expand Up @@ -643,6 +689,12 @@ def run(
action="store_true",
help="Truncate long and double weights in the network in Torch-TensorRT",
)
arg_parser.add_argument(
"--optimization_level",
type=int,
default=3,
help="Builder optimization level for TensorRT",
)
arg_parser.add_argument(
"--is_trt_engine",
action="store_true",
Expand Down Expand Up @@ -702,8 +754,13 @@ def run(

# Load TorchScript model, if provided
if os.path.exists(model_name):
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()
if params["is_trt_engine"]:
with open(model_name, "rb") as f:
model = f.read()
print("Loading user provided trt engine: ", model_name)
else:
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()

# Load PyTorch Model, if provided
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
Expand All @@ -719,7 +776,9 @@ def run(
)

backends = parse_backends(params["backends"])
if ("dynamo" in backends or "torch_compile" in backends) and (model_torch is None):
if any(
backend in ["dynamo", "torch_compile", "tensorrt"] for backend in backends
) and (model_torch is None):
raise ValueError(
"No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
)
Expand All @@ -746,7 +805,6 @@ def run(
params,
precision,
batch_size,
is_trt_engine,
model_torch=model_torch,
)

Expand Down
Loading
Loading