From f1098f2520f30a6082597c93895c350905c8245d Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 20 Mar 2023 14:45:35 -0700 Subject: [PATCH 01/15] feat: Add sample torch.compile backend for tensorrt aten path - Add backend adapted from previous `fx2trt_compiler` provided by Dynamo - Currently, the TRTSplitter needs work to fully support the `aten` path - Additionally, the existing `aten` pass was reworked to exclude the `torch._dynamo.export` call, which may be necessary here --- .../fx/tracer/dispatch_tracer/aten_tracer.py | 8 +- .../tensorrt_dynamo_backend.py | 107 ++++++++++++++++++ 2 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index e60c8f8d13..356ddc978e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -130,7 +130,7 @@ def trace(f, args, *rest): @req_torch_version("2.dev") -def opt_trace(f, args, *rest): +def opt_trace(f, args, perform_trace=True, *rest): """ Optimized trace with necessary passes which re-compose some ops or replace some ops These passes should be general and functional purpose @@ -148,7 +148,11 @@ def opt_trace(f, args, *rest): replace_inplace_ops, # remove it once functionalization is enabled ] - fx_module, _ = trace(f, args) + if perform_trace: + fx_module, _ = trace(f, args) + else: + fx_module = f + print(fx_module.graph) for passes in passes_list: pr: PassResult = passes(fx_module) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py new file mode 100644 index 0000000000..bb6e68b0b5 --- /dev/null +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -0,0 +1,107 @@ +import torch +import traceback +import torch._dynamo as td + +from torch_tensorrt.fx.fx2trt import ( + InputTensorSpec, + TRTInterpreter, +) +import tensorrt as trt +from torch_tensorrt.fx.tools.trt_splitter import ( + TRTSplitter, + TRTSplitterSetting, +) +from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer +from torch_tensorrt.fx.trt_module import TRTModule +from torch_tensorrt.fx.utils import LowerPrecision + +from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler + +MAX_SPLITS_THRESHOLD = 10 + + +def tensorrt_backend(gm, sample_inputs): + # Invoke AOTAutograd to compile model + return aot_module_simplified( + gm, + sample_inputs, + fw_compiler=make_boxed_compiler(fx2trt_compiler), + ) + + +def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): + model = gm + inputs = example_inputs + + # Perform lowering pass on model + model = aten_tracer.opt_trace(model, inputs, perform_trace=False) + + # Split out unsupported ops --> Needs rewrite/revision for ATEN + splitter_setting = TRTSplitterSetting() + splitter_setting.use_implicit_batch_dim = False + splitter = TRTSplitter(model, inputs, settings=splitter_setting) + + splitter.node_support_preview() + split_mod = splitter() + num_piece = 0 + + for name, _ in split_mod.named_children(): + print(f"Graph is split into {name}") + num_pieces += 1 + + # Select threshold above which segmentation is not beneficial and run graph in Torch + if num_pieces > MAX_SPLITS_THRESHOLD: + raise AssertionError( + f"The graph module is 
split into {num_piece} which is large than the \ + threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." + ) + + precision = LowerPrecision.FP32 + + def get_submod_inputs(mod, submod, inputs): + acc_inputs = None + + def get_input(self, inputs): + nonlocal acc_inputs + acc_inputs = inputs + + handle = submod.register_forward_pre_hook(get_input) + mod(*inputs) + handle.remove() + return acc_inputs + + for name, _ in split_mod.named_children(): + if "_run_on_acc" in name: + submod = getattr(split_mod, name) + acc_inputs = get_submod_inputs(split_mod, submod, inputs) + + interp = TRTInterpreter( + submod, + InputTensorSpec.from_tensors(acc_inputs), + explicit_batch_dimension=True, + logger_level=trt.Logger.VERBOSE, + ) + r = interp.run( + max_workspace_size=20 << 30, + lower_precision=precision, + profiling_verbosity=trt.ProfilingVerbosity.VERBOSE, + ) + + trt_mod = TRTModule(*r) + + setattr(split_mod, name, trt_mod) + + return split_mod + + +@td.register_backend +def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): + try: + trt_compiled = fx2trt(gm, example_inputs) + return trt_compiled + except Exception: + traceback.print_exc() + print( + "FX2TRT conversion failed on the subgraph. See trace above. Returning GraphModule forward instead" + ) + return gm.forward From 243bf9bc340e27837a33c3d6fc3c0998381aff0a Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 21 Mar 2023 16:17:51 -0700 Subject: [PATCH 02/15] Add decompositions to aot call --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index bb6e68b0b5..a76162b93b 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -17,6 +17,9 @@ from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler +from torch._inductor.decomposition import decompositions + +DECOMPOSITIONS = decompositions.copy() MAX_SPLITS_THRESHOLD = 10 @@ -26,6 +29,7 @@ def tensorrt_backend(gm, sample_inputs): gm, sample_inputs, fw_compiler=make_boxed_compiler(fx2trt_compiler), + decompositions=DECOMPOSITIONS, ) From 76fd3c8207bdf017af294f1883863a755045b1a8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:31:22 -0700 Subject: [PATCH 03/15] Mark FX2TRT converter as fake tensor unsupported --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index a76162b93b..20cea4ffd5 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -15,6 +15,8 @@ from torch_tensorrt.fx.trt_module import TRTModule from torch_tensorrt.fx.utils import LowerPrecision +from torch._dynamo.backends.common import fake_tensor_unsupported + from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler from torch._inductor.decomposition import decompositions @@ -99,6 +101,7 @@ def get_input(self, inputs): @td.register_backend +@fake_tensor_unsupported def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs): try: 
trt_compiled = fx2trt(gm, example_inputs) From 6a8102c14f3c0fa7a200222979888e9d213d0d84 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 28 Mar 2023 18:52:12 -0700 Subject: [PATCH 04/15] Minor naming bugfix --- .../fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 20cea4ffd5..55c5e2df33 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -49,7 +49,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): splitter.node_support_preview() split_mod = splitter() - num_piece = 0 + num_pieces = 0 for name, _ in split_mod.named_children(): print(f"Graph is split into {name}") @@ -58,7 +58,7 @@ def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs): # Select threshold above which segmentation is not beneficial and run graph in Torch if num_pieces > MAX_SPLITS_THRESHOLD: raise AssertionError( - f"The graph module is split into {num_piece} which is large than the \ + f"The graph module is split into {num_pieces} which is large than the \ threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module." ) From e97ed50eeb17b661cb7da060b5dd24bc32d9bb43 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 7 Apr 2023 11:12:12 -0700 Subject: [PATCH 05/15] Implementing aten::chunk, aten::layer_norm, aten::softmax, aten::where, aten::rsub, aten::rsqrt --- .../fx/converters/acc_ops_converters.py | 220 +------------- .../fx/converters/aten_ops_converters.py | 113 ++++++++ py/torch_tensorrt/fx/converters/operator.py | 269 +++++++++++++++++- .../converters/aten_op/test_chunk_aten.py | 58 ++++ .../aten_op/test_layer_norm_aten.py | 45 +++ .../converters/aten_op/test_rsqrt_aten.py | 0 .../test/converters/aten_op/test_rsub_aten.py | 0 .../converters/aten_op/test_softmax_aten.py | 44 +++ .../converters/aten_op/test_squeeze_aten.py | 67 +++++ .../converters/aten_op/test_where_aten.py | 56 ++++ 10 files changed, 662 insertions(+), 210 deletions(-) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index e556e81bb5..a321bb8dfe 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -678,7 +678,13 @@ def acc_ops_batch_norm( @tensorrt_converter(acc_ops.layer_norm) -def acc_ops_layer_norm(network, target, args, kwargs, name): +def acc_ops_layer_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: return add_layer_norm(network, target, kwargs, name) @@ -690,37 +696,7 @@ def acc_ops_softmax( kwargs: 
Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"softmax received input {input_val} that is not part " - "of the TensorRT region!" - ) - - # Used to get dim when dim is None. Copied from PyTorch softmax implementation. - def get_softmax_dim(ndim: int) -> int: - if ndim == 0 or ndim == 1 or ndim == 3: - ret = 0 - else: - ret = 1 - return ret - - if kwargs["dim"] is None: - dim = get_softmax_dim(input_ranks) - else: - dim = cast(int, kwargs["dim"]) - - dim = get_positive_dim(dim, input_ranks) - if network.has_implicit_batch_dimension: - assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." - dim -= 1 - - layer = network.add_softmax(input_val) - layer.axes = 1 << dim - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_softmax(network, target, kwargs, name) @tensorrt_converter(acc_ops.tile) @@ -956,9 +932,7 @@ def acc_ops_sqrt( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - operation_type = trt.UnaryOperation.SQRT - return add_unary_layer(network, input_val, operation_type, target, name) + return add_sqrt(network, target, kwargs, name) @tensorrt_converter(acc_ops.reciprocal) @@ -1619,40 +1593,7 @@ def acc_ops_squeeze( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"squeeze received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) - # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic - # dim, which is a very rare case. For now we just claim not supporting dim=None. - assert dim is not None, "We don't support dim=None right now for squeeze." - - dim = get_positive_dim( - dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - ) - if network.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 - - assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." - assert ( - len(get_dynamic_dims(input_val.shape)) <= 1 - ), "Currently more than one dynamic dim for input to squeeze is not supported." - - output_shape = [] - for i, s in enumerate(input_val.shape): - if i == dim and s == 1: - continue - output_shape.append(s) - layer = network.add_shuffle(input_val) - layer.reshape_dims = tuple(output_shape) - set_layer_name(layer, target, name) - return layer.get_output(0) + return add_squeeze(network, target, kwargs, name) @tensorrt_converter(acc_ops.add) @@ -2022,89 +1963,7 @@ def acc_ops_where( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - - condition_t = kwargs["condition"] - x_t = kwargs["x"] - y_t = kwargs["y"] - - if type(x_t) != TRTTensor: - assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" - - if type(y_t) != TRTTensor: - assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" 
- - # get output shape - - x_shape = list(x_t.shape) - y_shape = list(y_t.shape) - condition_shape = list(condition_t.shape) - output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) - - # expand shape - if type(condition_t) != TRTTensor: - assert condition_t.dtype == torch.bool, "condition dtype is not bool" - if condition_shape != output_shape: - condition_t.expand(output_shape) - condition_t = condition_t.to(torch.int32) - condition_const = get_trt_tensor(network, condition_t, f"{name}_condition") - condition_layer = network.add_identity(condition_const) - condition_layer.set_output_type(0, trt.bool) - set_layer_name(condition_layer, target, f"{name}_condition") - condition_val = condition_layer.get_output(0) - else: - assert condition_t.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != output_shape: - condition_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": condition_t, "sizes": output_shape}, - name=f"{name}_expand", - ) - else: - condition_val = condition_t - - if type(x_t) != TRTTensor: - if x_shape != output_shape: - # special case where 1 element in x_t - if len(x_t.shape) == 0: - x_t = x_t.unsqueeze(0) - x_t = x_t.expand(output_shape) - x_val = get_trt_tensor(network, x_t, f"{name}_x") - else: - x_val = x_t - if x_shape != output_shape: - x_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": x_val, "sizes": output_shape}, - name=f"{name}_x_expand", - ) - - if type(y_t) != TRTTensor: - if y_shape != output_shape: - # special case where 1 element in y_t - if len(y_t.shape) == 0: - y_t = y_t.unsqueeze(0) - y_t = y_t.expand(output_shape) - y_val = get_trt_tensor(network, y_t, f"{name}_y") - else: - y_val = y_t - if y_shape != output_shape: - y_val = acc_ops_expand_tensor( - network, - target, - None, - {"input": y_val, "sizes": output_shape}, - name=f"{name}_y_expand", - ) - - select_layer = network.add_select(condition_val, x_val, y_val) - - set_layer_name(select_layer, target, f"{name}_select") - - return select_layer.get_output(0) + return add_where(network, target, kwargs, name) @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True) @@ -2721,62 +2580,7 @@ def acc_ops_chunk( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - input_val = kwargs["input"] - chunks = cast(int, kwargs["chunks"]) - dim = cast(int, kwargs["dim"]) - input_dim_size = len(input_val.shape) # type: ignore[union-attr] - - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"chunk received input {input_val} that is not part " - "of the TensorRT region!" - ) - - dynamic_shape = has_dynamic_shape(input_val.shape) - if network.has_implicit_batch_dimension: - input_dim_size += 1 - dim = get_positive_dim(dim, input_dim_size) - assert dim != 0, "Can't chunk on batch dim when it's implicit!" - dim -= 1 - else: - if dynamic_shape: - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
-        dim = get_positive_dim(dim, input_dim_size)
-
-    if chunks > input_val.shape[dim]:
-        warnings.warn(
-            f"Asked for {chunks} chunks along dimention "
-            f"{dim} on tensor with size {input_val.shape}, chunks "
-            f"will default to {input_val.shape[dim]}",
-            RuntimeWarning,
-        )
-        chunks = input_val.shape[dim]
-
-    start = [0] * len(input_val.shape)
-    stride = [1] * len(start)
-    offset = 0
-    split_size = (input_val.shape[dim] + chunks - 1) // chunks
-
-    max_offset = input_val.shape[dim]
-    # add slice layers
-    output = []
-    for i in range(chunks):
-        shape = list(input_val.shape)
-        shape[dim] = min(split_size, max_offset - offset)
-        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
-        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
-        if dynamic_shape:
-            layer.set_input(2, shape)
-        offset += split_size
-        set_layer_name(layer, target, f"{name}_{i}")
-        output.append(layer.get_output(0))
-    return output
+    return add_chunk(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 1dbfa14076..d47f30a790 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -620,3 +620,116 @@ def aten_ops_matmul(
         "other": args[1],
     }
     return add_matmul(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "normalized_shape": args[1],
+        "weight": args[2],
+        "bias": args[3],
+        "eps": args[4],
+    }
+    return add_layer_norm(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_softmax(network, target, kwargs_new, name)
+
+
+# FIXME: need to look at case where dim is tuple
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_squeeze(network, target, kwargs_new, name)
+
+
+# FIXME: need to confirm lower basic passes
+# @tensorrt_converter(torch.ops.aten.chunk)
+# def aten_ops_chunk(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     kwargs_new = {
+#         "input": args[0],
+#         "chunks": args[1],
+#         "dim": args[2],
+#     }
+#     return add_chunk(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "condition": args[0],
+        "x": args[1],
+        "y": args[2],
+    }
+    return add_where(network, target, kwargs_new, name)
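
For reference, the broadcasting contract the `aten_ops_where` converter above has to reproduce can be checked in plain PyTorch. A minimal sanity-check sketch (shapes are illustrative only):

import torch

# torch.where broadcasts condition, x, and y to one common shape; the TRT
# converter mirrors this with expand layers feeding a select layer.
condition = torch.randn(2, 2) < 0   # bool tensor, shape (2, 2)
x = torch.randn(2, 1)               # broadcasts along the last dim
y = torch.randn(1, 2, 2)            # broadcasts along the leading dim
out = torch.where(condition, x, y)
assert out.shape == torch.broadcast_shapes(condition.shape, x.shape, y.shape)
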
+
+
+@tensorrt_converter(torch.ops.aten.rsub)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+        "alpha": args[2],
+    }
+    return add_rsub(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_rsqrt(network, target, kwargs_new, name)
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 5955e598f5..8d45278548 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -580,7 +580,7 @@ def layer_norm(
         set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
 
         # X-E[x]
-        sub_trt = operator.add_binary_elementwise_layer(
+        sub_trt = add_binary_elementwise_layer(
             network,
             input_val,
             mean_expected_layer.get_output(0),
@@ -594,7 +594,7 @@ def layer_norm(
             trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
         )
         pow_tensor.name = f"{name}_power"
-        pow_var = operator.add_binary_elementwise_layer(
+        pow_var = add_binary_elementwise_layer(
             network,
             sub_trt,
             pow_tensor.get_output(0),
@@ -739,6 +739,7 @@ def add_layer_norm(network, target, kwargs, name):
         _LOGGER.error(
             "Unable to find layer norm plugin, fall back to TensorRT implementation."
         )
+        args = []
         return layer_norm(network, target, args, kwargs, name)
     layer = network.add_plugin_v2([input_val], plugin)
     layer.name = name
@@ -1254,3 +1255,267 @@ def add_matmul(network, target, kwargs, name):
     )
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def add_softmax(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if kwargs["dim"] is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, kwargs["dim"])
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input_val)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_squeeze(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert dim is not None, "We don't support dim=None right now for squeeze."
+ + dim = get_positive_dim( + dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 + + assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." + + output_shape = [] + for i, s in enumerate(input_val.shape): + if i == dim and s == 1: + continue + output_shape.append(s) + layer = network.add_shuffle(input_val) + layer.reshape_dims = tuple(output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_chunk(network, target, kwargs, name): + input_val = kwargs["input"] + chunks = cast(int, kwargs["chunks"]) + dim = cast(int, kwargs["dim"]) + input_dim_size = len(input_val.shape) # type: ignore[union-attr] + + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"chunk received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dynamic_shape = has_dynamic_shape(input_val.shape) + if network.has_implicit_batch_dimension: + input_dim_size += 1 + dim = get_positive_dim(dim, input_dim_size) + assert dim != 0, "Can't chunk on batch dim when it's implicit!" + dim -= 1 + else: + if dynamic_shape: + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + dim = get_positive_dim(dim, input_dim_size) + + if chunks > input_val.shape[dim]: + warnings.warn( + f"Asked for {chunks} chunks along dimention " + f"{dim} on tensor with size {input_val.shape}, chunks " + f"will default to {input_val.shape[dim]}", + RuntimeWarning, + ) + chunks = input_val.shape[dim] + + start = [0] * len(input_val.shape) + stride = [1] * len(start) + offset = 0 + split_size = (input_val.shape[dim] + chunks - 1) // chunks + + max_offset = input_val.shape[dim] + # add slice layers + output = [] + for i in range(chunks): + shape = list(input_val.shape) + shape[dim] = min(split_size, max_offset - offset) + if dynamic_shape: + shape = get_shape_with_dynamic_shape( + network, shape, input_val, target, f"{name}_{i}" + ) + start[dim] = offset + layer = network.add_slice( + input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride + ) + if dynamic_shape: + layer.set_input(2, shape) + offset += split_size + set_layer_name(layer, target, f"{name}_{i}") + output.append(layer.get_output(0)) + return output + + +def add_where(network, target, kwargs, name): + condition_t = kwargs["condition"] + x_t = kwargs["x"] + y_t = kwargs["y"] + + if type(x_t) != TRTTensor: + assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" + + if type(y_t) != TRTTensor: + assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" 
+ + # get output shape + + x_shape = list(x_t.shape) + y_shape = list(y_t.shape) + condition_shape = list(condition_t.shape) + output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) + + # expand shape + if type(condition_t) != TRTTensor: + assert condition_t.dtype == torch.bool, "condition dtype is not bool" + if condition_shape != output_shape: + condition_t.expand(output_shape) + condition_t = condition_t.to(torch.int32) + condition_const = get_trt_tensor(network, condition_t, f"{name}_condition") + condition_layer = network.add_identity(condition_const) + condition_layer.set_output_type(0, trt.bool) + set_layer_name(condition_layer, target, f"{name}_condition") + condition_val = condition_layer.get_output(0) + else: + assert condition_t.dtype == trt.bool, "mask dtype is not bool!" + if condition_shape != output_shape: + condition_val = add_expand( + network, + target, + None, + {"input": condition_t, "sizes": output_shape}, + name=f"{name}_expand", + ) + else: + condition_val = condition_t + + if type(x_t) != TRTTensor: + if x_shape != output_shape: + # special case where 1 element in x_t + if len(x_t.shape) == 0: + x_t = x_t.unsqueeze(0) + x_t = x_t.expand(output_shape) + x_val = get_trt_tensor(network, x_t, f"{name}_x") + else: + x_val = x_t + if x_shape != output_shape: + x_val = add_expand( + network, + target, + None, + {"input": x_val, "sizes": output_shape}, + name=f"{name}_x_expand", + ) + + if type(y_t) != TRTTensor: + if y_shape != output_shape: + # special case where 1 element in y_t + if len(y_t.shape) == 0: + y_t = y_t.unsqueeze(0) + y_t = y_t.expand(output_shape) + y_val = get_trt_tensor(network, y_t, f"{name}_y") + else: + y_val = y_t + if y_shape != output_shape: + y_val = add_expand( + network, + target, + None, + {"input": y_val, "sizes": output_shape}, + name=f"{name}_y_expand", + ) + + select_layer = network.add_select(condition_val, x_val, y_val) + + set_layer_name(select_layer, target, f"{name}_select") + + return select_layer.get_output(0) + + +def add_scale(network, target, kwargs, name): + other = kwargs["other"] + scale = kwargs["scale"] + if isinstance(other, TRTTensor): + other_dtype = torch_dtype_from_trt(other.dtype) + is_other_trt_tensor = True + + if not is_other_trt_tensor: + warnings.warn( + f"The value to be scaled is constant" + "In this case, please consider constant fold the model first." 
+ ) + return other * scale + layer = network.add_scale(other, trt.ScaleMode.UNIFORM, 0, scale, 1) + set_layer_name(layer, target, name) + return layer.get_output(0) + + +def add_rsub(network, target, kwargs, name): + scaled_tensor = add_scale(network, target, kwargs, name) + input = kwargs["input"] + return add_binary_elementwise_layer( + network, + kwargs["input"], + scaled_tensor, + trt.ElementWiseOperation.SUB, + target, + name, + ) + + +def add_sqrt(network, target, kwargs, name): + input_val = kwargs["input"] + operation_type = trt.UnaryOperation.SQRT + return add_unary_layer(network, input_val, operation_type, target, name) + + +def add_rsqrt(network, target, kwargs, name): + sqrt_trt = add_sqrt(network, target, kwargs, name) + div_trt = add_binary_elementwise_layer( + network, + 1, + sqrt_trt, + trt.ElementWiseOperation.DIV, + target, + f"{name}_div_trt", + ) diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py new file mode 100644 index 0000000000..8fae6da293 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py @@ -0,0 +1,58 @@ +import unittest + +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_chunk_dim", 6, 0), + ] + ) + def test_chunk(self, _, chunk, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.chunk(input, chunk, dim) + return out + + input = [torch.randn(11)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.chunk}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_chunk_dim", 6, 0), + ] + ) + def test_chunk(self, _, chunk, dim): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.chunk(input, chunk, dim) + return out + + input = [torch.randn(12)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.chunk}, + test_explicit_precision=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py new file mode 100644 index 0000000000..cf97e828d0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestLayerNormConverter(DispatchTestCase): + def test_layer_norm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + +def test_layernorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + input_specs = [ + 
InputTensorSpec( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[(1, 3, 1, 1)], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py new file mode 100644 index 0000000000..31e293fc91 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py @@ -0,0 +1,44 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSoftMaxConverter(DispatchTestCase): + def test_softmax(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(1) + + def forward(self, x): + return self.softmax(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default} + ) + + def test_softmax_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(2) + + def forward(self, x): + return self.softmax(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py new file mode 100644 index 0000000000..5dd15a89e7 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py @@ -0,0 +1,67 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (0), (2, 1)), + ("3d_one_dim", (0), (2, 2, 1)), + # ("3d_two_dim", (0, 1), (2, 2, 1)), + # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)), + ] + ) + def test_squeeze(self, _, dim, init_size): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + inputs = [torch.randn(*init_size)] + expected_op = {} + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + self.run_test( + Squeeze(), + inputs, + expected_ops=expected_op, + ) + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), + ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), + # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), + ] + ) + def test_squeeze(self, _, dim, init_size, shape_range): + class Squeeze(nn.Module): + def 
forward(self, x): + return torch.squeeze(x, dim) + + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + input_specs = [ + InputTensorSpec( + shape=init_size, + dtype=torch.float32, + shape_ranges=shape_range, + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), + input_specs, + expected_ops=expected_op, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py new file mode 100644 index 0000000000..6c050eee2f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)), + ] + ) + def test_(self, _, condition, x_size, y_size): + class Where(nn.Module): + def forward(self, x): + return torch.where(x, dim) + + inputX = [torch.randn(*x_size)] + inputOther = [torch.randn(*y_size)] + expected_op = {} + self.run_test( + Where(), + inputs, + expected_ops=torch.ops.aten.where.self, + ) + + +# class TestWhereConverter(DispatchTestCase): +# @parameterized.expand( +# [ +# ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), +# ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), +# ] +# ) +# def test_where(self, _, dim, init_size, shape_range): +# class Squeeze(nn.Module): +# def forward(self, x): +# return torch.squeeze(x, dim) + +# input_specs = [ +# InputTensorSpec( +# shape=init_size, +# dtype=torch.float32, +# shape_ranges=shape_range, +# ), +# ] +# self.run_test_with_dynamic_shape( +# Squeeze(), +# input_specs, +# expected_ops=torch.ops.aten.where.self, +# ) From c5a4744867e8637a58042972bfee133372dcfbb1 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 10 Apr 2023 09:13:14 -0700 Subject: [PATCH 06/15] Transformer operator changes --- .../fx/converters/converter_utils.py | 33 ++++++++++ py/torch_tensorrt/fx/converters/operator.py | 64 +++++++++++++------ .../fx/passes/lower_basic_pass_aten.py | 1 + .../converters/aten_op/test_rsqrt_aten.py | 29 +++++++++ .../test/converters/aten_op/test_rsub_aten.py | 29 +++++++++ .../converters/aten_op/test_squeeze_aten.py | 4 +- .../converters/aten_op/test_where_aten.py | 57 +++++++++-------- .../tensorrt_dynamo_backend.py | 2 +- 8 files changed, 171 insertions(+), 48 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 551c18652d..9d405767ea 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -288,6 +288,39 @@ def prepend_ones( return layer.get_output(0) +def broadcastable( + a: TRTTensor, + b: TRTTensor, +) -> bool: + "Check if two tensors are broadcastable according to torch rules" + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + print("a shape is", a_shape) + print("b shape is", b_shape) + # check from the 
trailing + diff = len(a_shape) - len(b_shape) + if diff == 0: + return True + if diff > 0: + max = len(a_shape) + min = len(b_shape) + greater_tensor = a_shape + lesser_tensor = b_shape + elif diff < 0: + max = len(b_shape) + min = len(a_shape) + greater_tensor = b_shape + lesser_tensor = a_shape + j = min - 1 + for i in range(max - 1, diff - 1, -1): + if not ( + greater_tensor[i] != lesser_tensor[j] + and (greater_tensor[i] == 1 or lesser_tensor[i] == 1) + ): + return False + return True + + def broadcast( network: TRTNetwork, a: TRTTensor, diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py index 8d45278548..4449b7146e 100644 --- a/py/torch_tensorrt/fx/converters/operator.py +++ b/py/torch_tensorrt/fx/converters/operator.py @@ -15,6 +15,7 @@ from .converter_utils import set_layer_name from .converter_utils import get_trt_tensor from .converter_utils import broadcast +from .converter_utils import broadcastable from .converter_utils import squeeze_left from .converter_utils import dtype_uniform from .converter_utils import get_trt_plugin @@ -1119,7 +1120,6 @@ def add_expand(network, target, kwargs, name): # TRT does not support different dimension size assert len(shape) == ranks shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)] - inshape = tuple(input_val.shape) shape = tuple(shape) start = tuple([0] * ranks) @@ -1299,27 +1299,36 @@ def add_squeeze(network, target, kwargs, name): f"squeeze received input {input_val} that is not part " "of the TensorRT region!" ) + dims = [] + if "dim" in kwargs: + if isinstance(kwargs["dim"], int): + dims.append(cast(Optional[int], kwargs["dim"])) + else: + for dim in kwargs["dim"]: + dims.append(cast(Optional[int], dim)) - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) + # dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic # dim, which is a very rare case. For now we just claim not supporting dim=None. - assert dim is not None, "We don't support dim=None right now for squeeze." + assert not (len(dims) == 0), "We don't support dim=None right now for squeeze." - dim = get_positive_dim( - dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - ) - if network.has_implicit_batch_dimension: - assert dim != 0, "We don't support squeeze batch dim when it's implicit." - dim -= 1 + for dim in dims: + dim = get_positive_dim( + dim, + len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0), + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 - assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." - assert ( - len(get_dynamic_dims(input_val.shape)) <= 1 - ), "Currently more than one dynamic dim for input to squeeze is not supported." + assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." 
output_shape = [] for i, s in enumerate(input_val.shape): - if i == dim and s == 1: + if (i in dims) and s == 1: continue output_shape.append(s) layer = network.add_shuffle(input_val) @@ -1392,14 +1401,32 @@ def add_where(network, target, kwargs, name): x_t = kwargs["x"] y_t = kwargs["y"] + x_t_dim = len(tuple(x_t.shape)) + y_t_dim = len(tuple(y_t.shape)) + condition_t_dim = len(tuple(condition_t.shape)) + if type(x_t) != TRTTensor: assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!" if type(y_t) != TRTTensor: assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!" + if not (broadcastable(x_t, y_t)): + assert f"The two torch tensors should be broadcastable" + # get output shape + # purpose of this is to bring x_t and y_t rank same as + # output_shape to input it to the add_expand operation + # condition_t will have dimension of either x_t or y_t + x_t, y_t = broadcast(network, x_t, y_t, f"{name}_x", f"{name}_y") + if len(tuple(condition_t.shape)) != len(tuple(x_t.shape)): + condition_t, x_t = broadcast( + network, condition_t, x_t, f"{name}_condition", f"{name}_x" + ) + print("x_t shape", x_t.shape) + print("y_t shape", y_t.shape) + print("condition_t shape", condition_t.shape) x_shape = list(x_t.shape) y_shape = list(y_t.shape) condition_shape = list(condition_t.shape) @@ -1418,11 +1445,10 @@ def add_where(network, target, kwargs, name): condition_val = condition_layer.get_output(0) else: assert condition_t.dtype == trt.bool, "mask dtype is not bool!" - if condition_shape != output_shape: + if condition_shape != condition_t_dim: condition_val = add_expand( network, target, - None, {"input": condition_t, "sizes": output_shape}, name=f"{name}_expand", ) @@ -1430,7 +1456,7 @@ def add_where(network, target, kwargs, name): condition_val = condition_t if type(x_t) != TRTTensor: - if x_shape != output_shape: + if x_shape != x_t_dim: # special case where 1 element in x_t if len(x_t.shape) == 0: x_t = x_t.unsqueeze(0) @@ -1442,7 +1468,6 @@ def add_where(network, target, kwargs, name): x_val = add_expand( network, target, - None, {"input": x_val, "sizes": output_shape}, name=f"{name}_x_expand", ) @@ -1456,11 +1481,10 @@ def add_where(network, target, kwargs, name): y_val = get_trt_tensor(network, y_t, f"{name}_y") else: y_val = y_t - if y_shape != output_shape: + if y_shape != y_t_dim: y_val = add_expand( network, target, - None, {"input": y_val, "sizes": output_shape}, name=f"{name}_y_expand", ) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 00063c3e21..30aeee6944 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -258,6 +258,7 @@ def remove_ops( for n in module.graph.nodes: if n.op == "call_function" and n.target in ( torch.ops.aten._unsafe_view.default, + torch.ops.aten.view.default, ): modified = True node = n diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py index e69de29bb2..da3aa30cb7 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSubConverter(DispatchTestCase): + 
@parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input, input, alpha) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops=torch.ops.aten.rsqrt, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py index e69de29bb2..9be23fc419 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSubConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsub(self, _, x, alpha): + class rsub(nn.Module): + def forward(self, input): + return torch.rsub(input, input, alpha) + + inputs = [torch.randn(x)] + self.run_test( + rsub(), + inputs, + expected_ops=torch.ops.aten.rsub, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py index 5dd15a89e7..5c655422de 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_squeeze_aten.py @@ -10,8 +10,8 @@ class TestSqueezeConverter(DispatchTestCase): [ ("2d_dim", (0), (2, 1)), ("3d_one_dim", (0), (2, 2, 1)), - # ("3d_two_dim", (0, 1), (2, 2, 1)), - # ("4d_dim", (0, 1, 2), (2, 2, 2, 1)), + ("3d_two_dim", (0, 1), (2, 1, 1)), + ("4d_dim", (0, 1, 2), (2, 2, 1, 1)), ] ) def test_squeeze(self, _, dim, init_size): diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py index 6c050eee2f..0d4849c21f 100644 --- a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py @@ -8,49 +8,56 @@ class TestWhereConverter(DispatchTestCase): @parameterized.expand( [ - ("2d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2)), - ("2d_broadcast_condition_xshape_yshape", (x < 0), (2, 2), (2, 1)), - ("3d_condition_xshape_yshape", (x > 0), (2, 2, 1), (2, 2, 1)), - ("2d_3d_condition_xshape_yshape", (x < 0), (2, 2), (2, 2, 1)), + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), ] ) - def test_(self, _, condition, x_size, y_size): + def test_(self, _, x_size, y_size): class Where(nn.Module): - def forward(self, x): - return torch.where(x, dim) + def forward(self, condition, x, y): + return torch.where(condition, x, y) - inputX = [torch.randn(*x_size)] - inputOther = [torch.randn(*y_size)] - expected_op = {} + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 self.run_test( Where(), - inputs, - expected_ops=torch.ops.aten.where.self, + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, ) +# FIXME: How to specify condition for dynamic 
shape +# InputTensorSpec like case below where one input is dynamic another is not # class TestWhereConverter(DispatchTestCase): # @parameterized.expand( # [ -# ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), -# ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# ("2d_dim", (-1, 2), [((1, 2), (2, 2), (2, 2))], (2,2)) +# #("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), # #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), # ] # ) -# def test_where(self, _, dim, init_size, shape_range): -# class Squeeze(nn.Module): -# def forward(self, x): -# return torch.squeeze(x, dim) - -# input_specs = [ -# InputTensorSpec( -# shape=init_size, +# def test_where(self, _, x_size, x_size_range, y_size): +# class Where(nn.Module): +# def forward(self, condition, x, y): +# return torch.where(condition, x, y) +# inputX = InputTensorSpec( +# shape=x_size, # dtype=torch.float32, -# shape_ranges=shape_range, -# ), +# shape_ranges=x_size_range, +# ) +# inputOther = torch.randn(*y_size) +# condition = (inputOther < 0) +# input_specs = [ +# inputX, inputOther, condition # ] # self.run_test_with_dynamic_shape( -# Squeeze(), +# Where(), # input_specs, # expected_ops=torch.ops.aten.where.self, # ) + +# if __name__ == "__main__": +# run_tests() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py index 55c5e2df33..e53f0bc64e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py @@ -22,7 +22,7 @@ from torch._inductor.decomposition import decompositions DECOMPOSITIONS = decompositions.copy() -MAX_SPLITS_THRESHOLD = 10 +MAX_SPLITS_THRESHOLD = 100 def tensorrt_backend(gm, sample_inputs): From 1d78f436a9a423a8486338ecefd972fd08777f63 Mon Sep 17 00:00:00 2001 From: Michael Feliz <104801882+mfeliz-cruise@users.noreply.github.com> Date: Wed, 19 Apr 2023 15:30:32 -0700 Subject: [PATCH 07/15] feat: Add ts converter support for aten::all.dim (#1840) --- core/conversion/converters/impl/reduce.cpp | 76 +++++++++++++------ .../conversion/converters/test_reduce.cpp | 53 ++++++++++++- 2 files changed, 105 insertions(+), 24 deletions(-) diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp index 249ae916ef..e3c7498c47 100644 --- a/core/conversion/converters/impl/reduce.cpp +++ b/core/conversion/converters/impl/reduce.cpp @@ -9,6 +9,36 @@ namespace converters { namespace impl { namespace { +nvinfer1::ITensor* anyDimImplementation( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* in_tensor, + int dim, + bool keepdim) { + auto in_dims = in_tensor->getDimensions(); + LOG_DEBUG("Dim to reduce (original): " << dim); + dim = dim < 0 ? 
(in_dims.nbDims + dim) : dim; + LOG_DEBUG("Dim to reduce (converted): " << dim); + + uint32_t axis_mask = 1 << dim; + LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask)); + LOG_DEBUG("Keep dims: " << keepdim); + + // Reduce does not work on bool inputs + if (in_tensor->getType() == nvinfer1::DataType::kBOOL) { + in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str()); + } + auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim); + + TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n); + + sum_layer->setName(util::node_info(n).c_str()); + auto out_tensor = + castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str()); + out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor); + return out_tensor; +} + auto reduce_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() .pattern( @@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED = {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in_tensor = args[0].ITensorOrFreeze(ctx); - auto in_dims = in_tensor->getDimensions(); auto dim = args[1].unwrapToInt(); - LOG_DEBUG("Dim to reduce (original): " << dim); - dim = dim < 0 ? (in_dims.nbDims + dim) : dim; - LOG_DEBUG("Dim to reduce (converted): " << dim); - - uint32_t axis_mask = 1 << dim; - LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask)); - auto keepdim = args[2].unwrapToBool(); - LOG_DEBUG("Keep dims: " << keepdim); - - // Reduce does not work on bool inputs - if (in_tensor->getType() == nvinfer1::DataType::kBOOL) { - in_tensor = - castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str()); - } - auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim); - - TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n); - - sum_layer->setName(util::node_info(n).c_str()); - auto out_tensor = castITensor( - ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str()); + auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim); out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor); LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); return true; + }}) + .pattern( + {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // use Not(Any(Not(input))) to calculate all without a direct all reduction + auto in_tensor = args[0].ITensorOrFreeze(ctx); + auto dim = args[1].unwrapToInt(); + auto keepdim = args[2].unwrapToBool(); + if (in_tensor->getType() != nvinfer1::DataType::kBOOL) { + // unary not layer only supports bool inputs + in_tensor = castITensor( + ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str()); + } + auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT); + TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n); + not_input_layer->setName((util::node_info(n) + "_not_in").c_str()); + auto not_in = not_input_layer->getOutput(0); + auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim); + auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT); + TORCHTRT_CHECK(not_output_layer, "Unable to create 
logical_not layer from node: " << *n);
+                 auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
+                 LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+                 return true;
+               }});
 } // namespace
 } // namespace impl
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index 40835a8dea..47e8b8d154 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
     return (%5))IR";
 }
 
-void test_body(const std::string& graph, at::Tensor& in) {
+void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
   in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  std::vector<at::Tensor> trt_results;
+  if (dynamic) {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
+  } else {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  }
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 } // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAllDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::ones({2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in, true);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
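
The aten::all converter in the patch above computes "all" as Not(Any(Not(x))) because TensorRT exposes no direct "all" reduction; the rewrite is just De Morgan's law. A quick plain-PyTorch check of the identity being relied on (tensor contents are illustrative):

import torch

# all(x) == not(any(not(x))) along any reduction dim, by De Morgan's law.
x = torch.randint(0, 2, (64, 2)).bool()
assert torch.equal(x.all(dim=-1), ~((~x).any(dim=-1)))
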
From ce7f122aa47d632120ec7c36e91f359afbe612d8 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 12:33:02 -0700
Subject: [PATCH 08/15] Correcting the rsqrt and rsub operator tests

---
 .../fx/test/converters/aten_op/test_rsqrt_aten.py | 4 ++--
 .../fx/test/converters/aten_op/test_rsub_aten.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
index da3aa30cb7..c80216654c 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsqrt(self, _, x, alpha):
         class rsqrt(nn.Module):
             def forward(self, input):
-                return torch.rsqrt(input, input, alpha)
+                return torch.rsqrt(input)
 
         inputs = [torch.randn(x) + 1]
         self.run_test(
             rsqrt(),
             inputs,
-            expected_ops=torch.ops.aten.rsqrt,
+            expected_ops={torch.ops.aten.rsqrt.default},
         )
 
 
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
index 9be23fc419..dddd72f732 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -15,13 +15,13 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsub(self, _, x, alpha):
         class rsub(nn.Module):
             def forward(self, input):
-                return torch.rsub(input, input, alpha)
+                return torch.rsub(input, input, alpha = alpha)
 
         inputs = [torch.randn(x)]
         self.run_test(
             rsub(),
             inputs,
-            expected_ops=torch.ops.aten.rsub,
+            expected_ops={torch.ops.aten.rsub.Tensor},
         )
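The corrected expected_ops entries point at concrete overloads rather than overload packets:
torch.ops.aten.rsub is an OpOverloadPacket, while the dispatch tracer records specific
OpOverload objects such as torch.ops.aten.rsub.Tensor in the graph. A sketch of the
distinction (tensor values are illustrative):

    import torch

    packet = torch.ops.aten.rsub           # OpOverloadPacket: groups all overloads
    overload = torch.ops.aten.rsub.Tensor  # OpOverload: what traced graphs contain

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    # aten::rsub.Tensor(self, other, alpha) computes other - alpha * self
    assert torch.allclose(overload(a, b, alpha=2), b - 2 * a)
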
From 30c5fd6e654f0ac3a7025c49d60b24cd8f96df40 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 13:08:34 -0700
Subject: [PATCH 09/15] Fixing Python linting issues and removing the chunk test

---
 .../fx/converters/aten_ops_converters.py      | 25 ++------
 py/torch_tensorrt/fx/converters/operator.py   | 12 +++-
 .../converters/aten_op/test_chunk_aten.py     | 58 -------------------
 .../test/converters/aten_op/test_rsub_aten.py |  2 +-
 4 files changed, 15 insertions(+), 82 deletions(-)
 delete mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py

diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index d47f30a790..defa88d18b 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -672,23 +672,6 @@ def aten_ops_squeeze(
     return add_squeeze(network, target, kwargs_new, name)
 
 
-# FIXME: need to confirm lower basic passes
-# @tensorrt_converter(torch.ops.aten.chunk)
-# def aten_ops_chunk(
-#     network: TRTNetwork,
-#     target: Target,
-#     args: Tuple[Argument, ...],
-#     kwargs: Dict[str, Argument],
-#     name: str,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     kwargs_new = {
-#         "input": args[0],
-#         "chunks": args[1],
-#         "dim": args[2],
-#     }
-#     return add_chunk(network, target, kwargs_new, name)
-
-
 @tensorrt_converter(torch.ops.aten.where.self)
 def aten_ops_where(
@@ -705,7 +688,7 @@ def aten_ops_where(
     return add_where(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.rsub)
+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
 def aten_ops_rsub(
     network: TRTNetwork,
     target: Target,
@@ -713,15 +696,17 @@ def aten_ops_rsub(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # alpha may arrive positionally (args[2]) or as a kwarg; aten defaults it to 1
+    alpha = args[2] if len(args) > 2 else kwargs.get("alpha", 1)
     kwargs_new = {
         "input": args[0],
         "other": args[1],
-        "alpha": args[2],
+        "alpha": alpha,
     }
     return add_rsub(network, target, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.rsqrt)
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
 def aten_ops_rsqrt(
     network: TRTNetwork,
     target: Target,
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 1e53b1ccc5..ffd6a1bab5 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -1526,7 +1526,13 @@ def add_scale(network, target, kwargs, name):
 
 
 def add_rsub(network, target, kwargs, name):
-    scaled_tensor = add_scale(network, target, kwargs, name)
+    kwargs_new = {}
+    if "alpha" in kwargs:
+        kwargs_new["input"] = kwargs["other"]
+        kwargs_new["other"] = kwargs["alpha"]
+        scaled_tensor = add_mul(network, target, kwargs_new, name + "_mul")
+    else:
+        scaled_tensor = kwargs["other"]
     input = kwargs["input"]
     return add_binary_elementwise_layer(
         network,
@@ -1534,7 +1540,7 @@ def add_rsub(network, target, kwargs, name):
         scaled_tensor,
         trt.ElementWiseOperation.SUB,
         target,
-        name,
+        name + "_sub",
     )
 
 
@@ -1546,7 +1552,7 @@ def add_sqrt(network, target, kwargs, name):
 
 def add_rsqrt(network, target, kwargs, name):
     sqrt_trt = add_sqrt(network, target, kwargs, name)
-    div_trt = add_binary_elementwise_layer(
+    return add_binary_elementwise_layer(
         network,
         1,
         sqrt_trt,
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
deleted file mode 100644
index 8fae6da293..0000000000
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_chunk_aten.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import unittest
-
-import torch
-import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
-from parameterized import param, parameterized
-from torch.testing._internal.common_utils import run_tests
-from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
-
-
-class TestSelectConverterImplicitBatch(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("select_chunk_dim", 6, 0),
-        ]
-    )
-    def test_chunk(self, _, chunk, dim):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, input):
-                out = torch.ops.aten.chunk(input, chunk, dim)
-                return out
-
-        input = [torch.randn(11)]
-        self.run_test(
-            TestModule(),
-            input,
-            expected_ops={torch.ops.aten.chunk},
-        )
-
-
-class TestSelectConverterExplicitBatch(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("select_chunk_dim", 6, 0),
-        ]
-    )
-    def test_chunk(self, _, chunk, dim):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, input):
-                out = torch.ops.aten.chunk(input, chunk, dim)
-                return out
-
-        input = [torch.randn(12)]
-        self.run_test(
-            TestModule(),
-            input,
-            expected_ops={torch.ops.aten.chunk},
-            test_explicit_precision=True,
-        )
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
index dddd72f732..268df8ccfd 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -15,7 +15,7 @@ class TestRSubConverter(DispatchTestCase):
     def test_rsub(self, _, x, alpha):
         class rsub(nn.Module):
             def forward(self, input):
-                return torch.rsub(input, input, alpha = alpha)
+                return torch.rsub(input, input, alpha=alpha)
 
     inputs = [torch.randn(x)]
     self.run_test(
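The add_rsqrt change above returns the elementwise division directly instead of binding it to
an unused local, keeping the lowering of aten.rsqrt.default as a sqrt followed by a reciprocal
division. That matches the PyTorch identity below (inputs are kept positive here only to avoid
NaNs):

    import torch

    # rsqrt(x) == 1 / sqrt(x)
    x = torch.randn(2, 3).abs() + 1
    assert torch.allclose(torch.rsqrt(x), 1.0 / torch.sqrt(x))
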
From 7ab071d91cc09281d8d518ce4f0dd406c6537955 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 16:00:48 -0700
Subject: [PATCH 10/15] Correcting acc squeeze test

---
 .../fx/test/converters/acc_op/test_squeeze.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
index d265def896..c9b4776dd3 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
@@ -12,7 +12,12 @@ def forward(self, x):
             return x.squeeze(2)
 
         inputs = [torch.randn(1, 2, 1)]
-        self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze})
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops={acc_ops.squeeze},
+            test_implicit_batch_dim=False,
+        )
 
     # Testing with shape=(-1, -1, -1, -1) results in error:
     # AssertionError: We don't support squeeze dynamic dim.

From 36ac0cf341286865cdea67c87fac2f3f9cf8b8b9 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 17:23:15 -0700
Subject: [PATCH 11/15] Changing test_reshape expected ops to aten.reshape, since aten.view is removed in lowering

---
 .../fx/test/converters/aten_op/test_reshape_aten.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
index 538e575d6e..385ec05b8b 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
@@ -31,7 +31,7 @@ def forward(self, x):
         self.run_test(
             TestModule(target_shape),
             inputs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )
 
     @parameterized.expand(
@@ -64,7 +64,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(target_shape),
             input_specs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )
 
     @unittest.skipIf(
@@ -94,7 +94,7 @@ def forward(self, x, y):
         self.run_test_with_dynamic_shape(
             TestModule(),
             input_specs,
-            expected_ops={torch.ops.aten.view.default},
+            expected_ops={torch.ops.aten.reshape},
         )

From eb851b19880dbc00acf6c69e78dd509e87bd1e81 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 21:43:07 -0700
Subject: [PATCH 12/15] No longer removing aten.view in the lowering pass; reverting test_reshape expected ops to aten.view.default

---
 py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py | 1 -
 .../fx/test/converters/aten_op/test_reshape_aten.py  | 6 +++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
index 0d6b1c28de..6790962621 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
@@ -258,7 +258,6 @@ def remove_ops(
     for n in module.graph.nodes:
         if n.op == "call_function" and n.target in (
             torch.ops.aten._unsafe_view.default,
-            torch.ops.aten.view.default,
         ):
             modified = True
             node = n
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
index 385ec05b8b..538e575d6e 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
@@ -31,7 +31,7 @@ def forward(self, x):
         self.run_test(
             TestModule(target_shape),
             inputs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
 
     @parameterized.expand(
@@ -64,7 +64,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(target_shape),
             input_specs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
 
     @unittest.skipIf(
@@ -94,7 +94,7 @@ def forward(self, x, y):
         self.run_test_with_dynamic_shape(
             TestModule(),
             input_specs,
-            expected_ops={torch.ops.aten.reshape},
+            expected_ops={torch.ops.aten.view.default},
         )
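Patch 12 narrows the remove_ops lowering pass so that only aten._unsafe_view.default is
rewritten and aten.view.default reaches its own converter. As a rough illustration of the
rewrite idea, a standalone FX pass might look like the sketch below; this is a simplified
stand-in, not the torch_tensorrt implementation, which does more bookkeeping than an in-place
target swap:

    import torch
    import torch.fx

    def rewrite_unsafe_view(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Retarget _unsafe_view calls to reshape; leave aten.view.default alone
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is torch.ops.aten._unsafe_view.default:
                node.target = torch.ops.aten.reshape.default
        gm.graph.lint()
        gm.recompile()
        return gm
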
From 6b234e0f34a9a27851eb438d70327c316976368e Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 22:47:12 -0700
Subject: [PATCH 13/15] Moving the layer_norm dynamic shape test into the test class

---
 .../aten_op/test_layer_norm_aten.py           | 40 +++++++++----------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index cf97e828d0..e204f4ec8b 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -19,26 +19,26 @@ def forward(self, x):
         )
 
 
-def test_layernorm_with_dynamic_shape(self):
-    class TestModule(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.ln = torch.nn.LayerNorm([3, 224, 224])
-
-        def forward(self, x):
-            return self.ln(x)
-
-    input_specs = [
-        InputTensorSpec(
-            shape=(-1, 3, 224, 224),
-            dtype=torch.float32,
-            shape_ranges=[(1, 3, 1, 1)],
-        ),
-    ]
-
-    self.run_test_with_dynamic_shape(
-        TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
-    )
+    def test_layernorm_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[(1, 3, 1, 1)],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
+        )
 
 
 if __name__ == "__main__":

From 95c1adab0143bf7d2f1aeb99988bde00f54a6be4 Mon Sep 17 00:00:00 2001
From: apbose
Date: Thu, 20 Apr 2023 22:49:21 -0700
Subject: [PATCH 14/15] Correcting linting error in the layer_norm test

---
 .../fx/test/converters/aten_op/test_layer_norm_aten.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index e204f4ec8b..6662d91b9a 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -18,7 +18,6 @@ def forward(self, x):
             TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
         )
 
-
     def test_layernorm_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
             def __init__(self):

From 1a1b809b7b2f90043cbcab0318141f6302057021 Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 21 Apr 2023 05:06:07 -0700
Subject: [PATCH 15/15] Correcting the dynamic shape ranges in the layer_norm test

---
 .../fx/test/converters/aten_op/test_layer_norm_aten.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
index 6662d91b9a..fab398ac0f 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -31,12 +31,12 @@ def forward(self, x):
             InputTensorSpec(
                 shape=(-1, 3, 224, 224),
                 dtype=torch.float32,
-                shape_ranges=[(1, 3, 1, 1)],
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
             ),
         ]
 
         self.run_test_with_dynamic_shape(
-            TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm}
+            TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default}
        )
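The final fix reflects the InputTensorSpec contract: each shape_ranges entry is a full
(min_shape, opt_shape, max_shape) triple that the TensorRT optimization profile must cover, so
a bare (1, 3, 1, 1) tuple was never a valid range. For reference, a spec matching the
corrected test (assuming the public torch_tensorrt.fx export of InputTensorSpec):

    import torch
    from torch_tensorrt.fx import InputTensorSpec

    spec = InputTensorSpec(
        shape=(-1, 3, 224, 224),  # -1 marks the dynamic batch dimension
        dtype=torch.float32,
        # One (min, opt, max) triple per dynamic range
        shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
    )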