diff --git a/core/conversion/converters/impl/reduce.cpp b/core/conversion/converters/impl/reduce.cpp
index 249ae916ef..e3c7498c47 100644
--- a/core/conversion/converters/impl/reduce.cpp
+++ b/core/conversion/converters/impl/reduce.cpp
@@ -9,6 +9,36 @@ namespace converters {
 namespace impl {
 namespace {
 
+nvinfer1::ITensor* anyDimImplementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* in_tensor,
+    int dim,
+    bool keepdim) {
+  auto in_dims = in_tensor->getDimensions();
+  LOG_DEBUG("Dim to reduce (original): " << dim);
+  dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
+  LOG_DEBUG("Dim to reduce (converted): " << dim);
+
+  uint32_t axis_mask = 1 << dim;
+  LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
+  LOG_DEBUG("Keep dims: " << keepdim);
+
+  // Reduce does not work on bool inputs
+  if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
+    in_tensor = castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
+  }
+  auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
+
+  TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
+
+  sum_layer->setName(util::node_info(n).c_str());
+  auto out_tensor =
+      castITensor(ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+  out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
+  return out_tensor;
+}
+
 auto reduce_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
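TensorRT reduce layers select their reduction axes with a bitmask rather than an index, which is why the helper logs a std::bitset<32>. A standalone sketch of the mask arithmetic (plain Python, not part of the patch):

    dim = 1
    axis_mask = 1 << dim
    assert axis_mask == 0b0010  # reduce over dim 1 only
    assert format(axis_mask, "032b").count("1") == 1  # one-hot, like the logged bitset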
@@ -224,33 +254,35 @@ auto reduce_registrations TORCHTRT_UNUSED =
             {"aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in_tensor = args[0].ITensorOrFreeze(ctx);
-               auto in_dims = in_tensor->getDimensions();
                auto dim = args[1].unwrapToInt();
-               LOG_DEBUG("Dim to reduce (original): " << dim);
-               dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
-               LOG_DEBUG("Dim to reduce (converted): " << dim);
-
-               uint32_t axis_mask = 1 << dim;
-               LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));
                auto keepdim = args[2].unwrapToBool();
-               LOG_DEBUG("Keep dims: " << keepdim);
-
-               // Reduce does not work on bool inputs
-               if (in_tensor->getType() == nvinfer1::DataType::kBOOL) {
-                 in_tensor =
-                     castITensor(ctx, in_tensor, nvinfer1::DataType::kINT32, (util::node_info(n) + "_in").c_str());
-               }
-               auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
-
-               TORCHTRT_CHECK(sum_layer, "Unable to create sum layer from node: " << *n);
-
-               sum_layer->setName(util::node_info(n).c_str());
-               auto out_tensor = castITensor(
-                   ctx, sum_layer->getOutput(0), nvinfer1::DataType::kBOOL, (util::node_info(n) + "_out").c_str());
+               auto out_tensor = anyDimImplementation(ctx, n, in_tensor, dim, keepdim);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // use Not(Any(Not(input))) to calculate all without a direct all reduction
+               auto in_tensor = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
+               auto keepdim = args[2].unwrapToBool();
+               if (in_tensor->getType() != nvinfer1::DataType::kBOOL) {
+                 // unary not layer only supports bool inputs
+                 in_tensor = castITensor(
+                     ctx, in_tensor, nvinfer1::DataType::kBOOL, (util::node_info(n) + "_in_to_bool").c_str());
+               }
+               auto not_input_layer = ctx->net->addUnary(*in_tensor, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_input_layer, "Unable to create logical_not layer from node: " << *n);
+               not_input_layer->setName((util::node_info(n) + "_not_in").c_str());
+               auto not_in = not_input_layer->getOutput(0);
+               auto any_out = anyDimImplementation(ctx, n, not_in, dim, keepdim);
+               auto not_output_layer = ctx->net->addUnary(*any_out, nvinfer1::UnaryOperation::kNOT);
+               TORCHTRT_CHECK(not_output_layer, "Unable to create logical_not layer from node: " << *n);
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], not_output_layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
              }});
 } // namespace
 } // namespace impl
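The aten::all.dim converter above leans on De Morgan's law because TensorRT has no direct "all" reduction: all(x) == not(any(not(x))). The identity is easy to confirm in PyTorch (illustrative check only, not part of the patch):

    import torch

    x = torch.randint(0, 2, (64, 2), dtype=torch.bool)
    assert torch.equal(torch.all(x, dim=-1), ~torch.any(~x, dim=-1))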
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index e556e81bb5..a321bb8dfe 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -678,7 +678,13 @@ def acc_ops_batch_norm(
 
 
 @tensorrt_converter(acc_ops.layer_norm)
-def acc_ops_layer_norm(network, target, args, kwargs, name):
+def acc_ops_layer_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return add_layer_norm(network, target, kwargs, name)
 
 
@@ -690,37 +696,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_softmax(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.tile)
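get_softmax_dim in the removed body (it survives verbatim in add_softmax in operator.py below) mirrors PyTorch's legacy dim=None rule: 0-D, 1-D and 3-D inputs default to dim 0, everything else to dim 1. An illustrative check of that default on a 3-D input:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4)  # 3-D, so the legacy default dim is 0
    out = F.softmax(x, dim=0)
    assert torch.allclose(out.sum(dim=0), torch.ones(3, 4))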
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.UnaryOperation.SQRT
-    return add_unary_layer(network, input_val, operation_type, target, name)
+    return add_sqrt(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.reciprocal)
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
-    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
-    # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
-
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
-
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
-
-    output_shape = []
-    for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
-            continue
-        output_shape.append(s)
-    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_squeeze(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.add)
@@ -2022,89 +1963,7 @@ def acc_ops_where(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    condition_t = kwargs["condition"]
-    x_t = kwargs["x"]
-    y_t = kwargs["y"]
-
-    if type(x_t) != TRTTensor:
-        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
-
-    if type(y_t) != TRTTensor:
-        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
-
-    # get output shape
-
-    x_shape = list(x_t.shape)
-    y_shape = list(y_t.shape)
-    condition_shape = list(condition_t.shape)
-    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
-
-    # expand shape
-    if type(condition_t) != TRTTensor:
-        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
-        if condition_shape != output_shape:
-            condition_t.expand(output_shape)
-        condition_t = condition_t.to(torch.int32)
-        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
-        condition_layer = network.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
-    else:
-        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
-            condition_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": condition_t, "sizes": output_shape},
-                name=f"{name}_expand",
-            )
-        else:
-            condition_val = condition_t
-
-    if type(x_t) != TRTTensor:
-        if x_shape != output_shape:
-            # special case where 1 element in x_t
-            if len(x_t.shape) == 0:
-                x_t = x_t.unsqueeze(0)
-            x_t = x_t.expand(output_shape)
-        x_val = get_trt_tensor(network, x_t, f"{name}_x")
-    else:
-        x_val = x_t
-        if x_shape != output_shape:
-            x_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": x_val, "sizes": output_shape},
-                name=f"{name}_x_expand",
-            )
-
-    if type(y_t) != TRTTensor:
-        if y_shape != output_shape:
-            # special case where 1 element in y_t
-            if len(y_t.shape) == 0:
-                y_t = y_t.unsqueeze(0)
-            y_t = y_t.expand(output_shape)
-        y_val = get_trt_tensor(network, y_t, f"{name}_y")
-    else:
-        y_val = y_t
-        if y_shape != output_shape:
-            y_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": y_val, "sizes": output_shape},
-                name=f"{name}_y_expand",
-            )
-
-    select_layer = network.add_select(condition_val, x_val, y_val)
-
-    set_layer_name(select_layer, target, f"{name}_select")
-
-    return select_layer.get_output(0)
+    return add_where(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    chunks = cast(int, kwargs["chunks"])
-    dim = cast(int, kwargs["dim"])
-    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dynamic_shape = has_dynamic_shape(input_val.shape)
-    if network.has_implicit_batch_dimension:
-        input_dim_size += 1
-        dim = get_positive_dim(dim, input_dim_size)
-        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
-        dim -= 1
-    else:
-        if dynamic_shape:
-            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-        dim = get_positive_dim(dim, input_dim_size)
-
-    if chunks > input_val.shape[dim]:
-        warnings.warn(
-            f"Asked for {chunks} chunks along dimention "
-            f"{dim} on tensor with size {input_val.shape}, chunks "
-            f"will default to {input_val.shape[dim]}",
-            RuntimeWarning,
-        )
-        chunks = input_val.shape[dim]
-
-    start = [0] * len(input_val.shape)
-    stride = [1] * len(start)
-    offset = 0
-    split_size = (input_val.shape[dim] + chunks - 1) // chunks
-
-    max_offset = input_val.shape[dim]
-    # add slice layers
-    output = []
-    for i in range(chunks):
-        shape = list(input_val.shape)
-        shape[dim] = min(split_size, max_offset - offset)
-        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
-        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
-        if dynamic_shape:
-            layer.set_input(2, shape)
-        offset += split_size
-        set_layer_name(layer, target, f"{name}_{i}")
-        output.append(layer.get_output(0))
-    return output
+    return add_chunk(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)
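Every hunk in this file follows the same shape: the acc converter keeps its registration and signature but now delegates to a shared add_* helper in operator.py, so the new aten converters in the next file can reuse the identical TensorRT lowering. A toy sketch of the shared calling convention (some_op and both functions are hypothetical):

    def add_some_op(network, target, kwargs, name):
        # all lowering logic lives in the shared helper
        return ("lowered", kwargs["input"])

    def acc_ops_some_op(network, target, args, kwargs, name):
        # the frontend converter is a thin pass-through over kwargs
        return add_some_op(network, target, kwargs, name)

    assert acc_ops_some_op(None, None, (), {"input": 42}, "n") == ("lowered", 42)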
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 1dbfa14076..defa88d18b 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -620,3 +620,101 @@ def aten_ops_matmul(
         "other": args[1],
     }
     return add_matmul(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "normalized_shape": args[1],
+        "weight": args[2],
+        "bias": args[3],
+        "eps": args[4],
+    }
+    return add_layer_norm(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_softmax(network, target, kwargs_new, name)
+
+
+# FIXME: need to look at case where dim is tuple
+@tensorrt_converter(torch.ops.aten.squeeze.dim)
+@tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1],
+    }
+    return add_squeeze(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "condition": args[0],
+        "x": args[1],
+        "y": args[2],
+    }
+    return add_where(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    # alpha is optional; only forward it when the traced call supplied one
+    if "alpha" in kwargs:
+        kwargs_new["alpha"] = kwargs["alpha"]
+    return add_rsub(network, target, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+    }
+    return add_rsqrt(network, target, kwargs_new, name)
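Unlike the acc tracer, aten ops arrive with positional arguments in schema order, so each converter above rebuilds the kwargs dict the shared helper expects. For aten.layer_norm.default the schema order is (input, normalized_shape, weight, bias, eps, cudnn_enable); a sketch with hypothetical traced values:

    args = ("x", [3, 224, 224], "w", "b", 1e-5, True)  # hypothetical traced args
    kwargs_new = {
        "input": args[0],
        "normalized_shape": args[1],
        "weight": args[2],
        "bias": args[3],
        "eps": args[4],
    }
    assert kwargs_new["eps"] == 1e-5  # cudnn_enable (args[5]) is deliberately dropped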
diff --git a/py/torch_tensorrt/fx/converters/operator.py b/py/torch_tensorrt/fx/converters/operator.py
index 9487894506..ffd6a1bab5 100644
--- a/py/torch_tensorrt/fx/converters/operator.py
+++ b/py/torch_tensorrt/fx/converters/operator.py
@@ -582,7 +582,7 @@ def layer_norm(
     set_layer_name(mean_expected_layer, target, f"{name}_mean_expected")
 
     # X-E[x]
-    sub_trt = operator.add_binary_elementwise_layer(
+    sub_trt = add_binary_elementwise_layer(
         network,
         input_val,
         mean_expected_layer.get_output(0),
@@ -596,7 +596,7 @@ def layer_norm(
         trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
     )
     pow_tensor.name = f"{name}_power"
-    pow_var = operator.add_binary_elementwise_layer(
+    pow_var = add_binary_elementwise_layer(
         network,
         sub_trt,
         pow_tensor.get_output(0),
@@ -741,6 +741,7 @@ def add_layer_norm(network, target, kwargs, name):
         _LOGGER.error(
             "Unable to find layer norm plugin, fall back to TensorRT implementation."
         )
+        args = []
         return layer_norm(network, target, args, kwargs, name)
     layer = network.add_plugin_v2([input_val], plugin)
     layer.name = name
@@ -1130,7 +1131,6 @@ def add_expand(network, target, kwargs, name):
     ranks = len(shape)
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
 
-    inshape = tuple(input_val.shape)
     shape = tuple(shape)
     start = tuple([0] * ranks)
@@ -1266,3 +1266,297 @@ def add_matmul(network, target, kwargs, name):
     )
     set_layer_name(layer, target, name)
     return layer.get_output(0)
+
+
+def add_softmax(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"softmax received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
+    def get_softmax_dim(ndim: int) -> int:
+        if ndim == 0 or ndim == 1 or ndim == 3:
+            ret = 0
+        else:
+            ret = 1
+        return ret
+
+    if kwargs["dim"] is None:
+        dim = get_softmax_dim(input_ranks)
+    else:
+        dim = cast(int, kwargs["dim"])
+
+    dim = get_positive_dim(dim, input_ranks)
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
+        dim -= 1
+
+    layer = network.add_softmax(input_val)
+    layer.axes = 1 << dim
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_squeeze(network, target, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"squeeze received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dims = []
+    if "dim" in kwargs:
+        if isinstance(kwargs["dim"], int):
+            dims.append(cast(Optional[int], kwargs["dim"]))
+        else:
+            for dim in kwargs["dim"]:
+                dims.append(cast(Optional[int], dim))
+
+    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
+    # dim, which is a very rare case. For now we just claim not supporting dim=None.
+    assert len(dims) != 0, "We don't support dim=None right now for squeeze."
+
+    # normalize negative/implicit-batch dims once, then squeeze against the
+    # normalized list
+    new_dims = []
+    for dim in dims:
+        dim = get_positive_dim(
+            dim,
+            len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0),
+        )
+        if network.has_implicit_batch_dimension:
+            assert dim != 0, "We don't support squeeze batch dim when it's implicit."
+            dim -= 1
+
+        assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
+        assert (
+            len(get_dynamic_dims(input_val.shape)) <= 1
+        ), "Currently more than one dynamic dim for input to squeeze is not supported."
+        new_dims.append(dim)
+
+    output_shape = []
+    for i, s in enumerate(input_val.shape):
+        if (i in new_dims) and s == 1:
+            continue
+        output_shape.append(s)
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(output_shape)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_chunk(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    chunks = cast(int, kwargs["chunks"])
+    dim = cast(int, kwargs["dim"])
+    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
+
+    if not isinstance(input_val, TRTTensor):
+        raise RuntimeError(
+            f"chunk received input {input_val} that is not part "
+            "of the TensorRT region!"
+        )
+
+    dynamic_shape = has_dynamic_shape(input_val.shape)
+    if network.has_implicit_batch_dimension:
+        input_dim_size += 1
+        dim = get_positive_dim(dim, input_dim_size)
+        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
+        dim -= 1
+    else:
+        if dynamic_shape:
+            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
+        dim = get_positive_dim(dim, input_dim_size)
+
+    if chunks > input_val.shape[dim]:
+        warnings.warn(
+            f"Asked for {chunks} chunks along dimension "
+            f"{dim} on tensor with size {input_val.shape}, chunks "
+            f"will default to {input_val.shape[dim]}",
+            RuntimeWarning,
+        )
+        chunks = input_val.shape[dim]
+
+    start = [0] * len(input_val.shape)
+    stride = [1] * len(start)
+    offset = 0
+    split_size = (input_val.shape[dim] + chunks - 1) // chunks
+
+    max_offset = input_val.shape[dim]
+    # add slice layers
+    output = []
+    for i in range(chunks):
+        shape = list(input_val.shape)
+        shape[dim] = min(split_size, max_offset - offset)
+        if dynamic_shape:
+            shape = get_shape_with_dynamic_shape(
+                network, shape, input_val, target, f"{name}_{i}"
+            )
+        start[dim] = offset
+        layer = network.add_slice(
+            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
+        )
+        if dynamic_shape:
+            layer.set_input(2, shape)
+        offset += split_size
+        set_layer_name(layer, target, f"{name}_{i}")
+        output.append(layer.get_output(0))
+    return output
+
+
+def add_where(network, target, kwargs, name):
+    condition_t = kwargs["condition"]
+    x_t = kwargs["x"]
+    y_t = kwargs["y"]
+
+    if type(x_t) != TRTTensor:
+        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
+
+    if type(y_t) != TRTTensor:
+        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
+
+    assert broadcastable(x_t, y_t), "The two torch tensors should be broadcastable"
+
+    # Bring x_t and y_t to the same rank as the output shape so they can be
+    # fed to add_expand; condition_t will have the rank of either x_t or y_t.
+    x_t, y_t = broadcast(network, x_t, y_t, f"{name}_x", f"{name}_y")
+    if len(tuple(condition_t.shape)) != len(tuple(x_t.shape)):
+        condition_t, x_t = broadcast(
+            network, condition_t, x_t, f"{name}_condition", f"{name}_x"
+        )
+
+    # get output shape
+    x_shape = list(x_t.shape)
+    y_shape = list(y_t.shape)
+    condition_shape = list(condition_t.shape)
+    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
+
+    # expand shape
+    if type(condition_t) != TRTTensor:
+        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
+        if condition_shape != output_shape:
+            condition_t = condition_t.expand(output_shape)
+        condition_t = condition_t.to(torch.int32)
+        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
+        condition_layer = network.add_identity(condition_const)
+        condition_layer.set_output_type(0, trt.bool)
+        set_layer_name(condition_layer, target, f"{name}_condition")
+        condition_val = condition_layer.get_output(0)
+    else:
+        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
+        if condition_shape != output_shape:
+            condition_val = add_expand(
+                network,
+                target,
+                {"input": condition_t, "sizes": output_shape},
+                name=f"{name}_expand",
+            )
+        else:
+            condition_val = condition_t
+
+    if type(x_t) != TRTTensor:
+        if x_shape != output_shape:
+            # special case where 1 element in x_t
+            if len(x_t.shape) == 0:
+                x_t = x_t.unsqueeze(0)
+            x_t = x_t.expand(output_shape)
+        x_val = get_trt_tensor(network, x_t, f"{name}_x")
+    else:
+        x_val = x_t
+        if x_shape != output_shape:
+            x_val = add_expand(
+                network,
+                target,
+                {"input": x_val, "sizes": output_shape},
+                name=f"{name}_x_expand",
+            )
+
+    if type(y_t) != TRTTensor:
+        if y_shape != output_shape:
+            # special case where 1 element in y_t
+            if len(y_t.shape) == 0:
+                y_t = y_t.unsqueeze(0)
+            y_t = y_t.expand(output_shape)
+        y_val = get_trt_tensor(network, y_t, f"{name}_y")
+    else:
+        y_val = y_t
+        if y_shape != output_shape:
+            y_val = add_expand(
+                network,
+                target,
+                {"input": y_val, "sizes": output_shape},
+                name=f"{name}_y_expand",
+            )
+
+    select_layer = network.add_select(condition_val, x_val, y_val)
+
+    set_layer_name(select_layer, target, f"{name}_select")
+
+    return select_layer.get_output(0)
+
+
+def add_scale(network, target, kwargs, name):
+    other = kwargs["other"]
+    scale = kwargs["scale"]
+    is_other_trt_tensor = False
+    if isinstance(other, TRTTensor):
+        is_other_trt_tensor = True
+
+    if not is_other_trt_tensor:
+        warnings.warn(
+            "The value to be scaled is a constant; "
+            "in this case, please consider constant folding the model first."
+        )
+        return other * scale
+    layer = network.add_scale(other, trt.ScaleMode.UNIFORM, 0, scale, 1)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+def add_rsub(network, target, kwargs, name):
+    # rsub(input, other, *, alpha=1) computes other - alpha * input
+    if "alpha" in kwargs:
+        scaled_tensor = add_mul(
+            network,
+            target,
+            {"input": kwargs["input"], "other": kwargs["alpha"]},
+            name + "_mul",
+        )
+    else:
+        scaled_tensor = kwargs["input"]
+    return add_binary_elementwise_layer(
+        network,
+        kwargs["other"],
+        scaled_tensor,
+        trt.ElementWiseOperation.SUB,
+        target,
+        name + "_sub",
+    )
+
+
+def add_sqrt(network, target, kwargs, name):
+    input_val = kwargs["input"]
+    operation_type = trt.UnaryOperation.SQRT
+    return add_unary_layer(network, input_val, operation_type, target, name)
+
+
+def add_rsqrt(network, target, kwargs, name):
+    sqrt_trt = add_sqrt(network, target, kwargs, name)
+    return add_binary_elementwise_layer(
+        network,
+        1,
+        sqrt_trt,
+        trt.ElementWiseOperation.DIV,
+        target,
+        f"{name}_div_trt",
+    )
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
index d265def896..c9b4776dd3 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_squeeze.py
@@ -12,7 +12,12 @@ def forward(self, x):
             return x.squeeze(2)
 
         inputs = [torch.randn(1, 2, 1)]
-        self.run_test(Squeeze(), inputs, expected_ops={acc_ops.squeeze})
+        self.run_test(
+            Squeeze(),
+            inputs,
+            expected_ops={acc_ops.squeeze},
+            test_implicit_batch_dim=False,
+        )
 
     # Testing with shape=(-1, -1, -1, -1) results in error:
     # AssertionError: We don't support squeeze dynamic dim.
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
new file mode 100644
index 0000000000..fab398ac0f
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_layer_norm_aten.py
@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestLayerNormConverter(DispatchTestCase):
+    def test_layer_norm(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default}
+        )
+
+    def test_layernorm_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.ln = torch.nn.LayerNorm([3, 224, 224])
+
+            def forward(self, x):
+                return self.ln(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, 224, 224),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
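add_rsqrt above builds rsqrt out of the existing sqrt helper plus an elementwise division, i.e. rsqrt(x) = 1 / sqrt(x); the new test below exercises exactly that composition. Quick numeric sanity check:

    import torch

    x = torch.rand(2, 3) + 1  # keep inputs positive, as the test inputs do
    assert torch.allclose(torch.rsqrt(x), 1.0 / torch.sqrt(x))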
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
new file mode 100644
index 0000000000..c80216654c
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSqrtConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsqrt(self, _, x, alpha):
+        class rsqrt(nn.Module):
+            def forward(self, input):
+                return torch.rsqrt(input)
+
+        inputs = [torch.randn(x) + 1]
+        self.run_test(
+            rsqrt(),
+            inputs,
+            expected_ops={torch.ops.aten.rsqrt.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
new file mode 100644
index 0000000000..268df8ccfd
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsub_aten.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestRSubConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_alpha", (2, 1), 2),
+            ("3d_dim_alpha", (2, 1, 2), 2),
+        ]
+    )
+    def test_rsub(self, _, x, alpha):
+        class rsub(nn.Module):
+            def forward(self, input):
+                return torch.rsub(input, input, alpha=alpha)
+
+        inputs = [torch.randn(x)]
+        self.run_test(
+            rsub(),
+            inputs,
+            expected_ops={torch.ops.aten.rsub.Tensor},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py
new file mode 100644
index 0000000000..31e293fc91
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py
@@ -0,0 +1,44 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSoftMaxConverter(DispatchTestCase):
+    def test_softmax(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(1)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+    def test_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.softmax = torch.nn.Softmax(2)
+
+            def forward(self, x):
+                return self.softmax(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
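The squeeze tests below cover both overloads the converter now accepts: squeeze.dim takes a single int, squeeze.dims a tuple (the tuple form needs PyTorch >= 2.0). Behavioural sketch:

    import torch

    x = torch.randn(2, 1, 1)
    assert torch.squeeze(x, 1).shape == (2, 1)
    assert torch.squeeze(x, (1, 2)).shape == (2,)  # squeeze.dims overload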
("3d_one_dim", (0), (2, 2, 1)), + ("3d_two_dim", (0, 1), (2, 1, 1)), + ("4d_dim", (0, 1, 2), (2, 2, 1, 1)), + ] + ) + def test_squeeze(self, _, dim, init_size): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + inputs = [torch.randn(*init_size)] + expected_op = {} + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + self.run_test( + Squeeze(), + inputs, + expected_ops=expected_op, + ) + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), + ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), + # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), + ] + ) + def test_squeeze(self, _, dim, init_size, shape_range): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + input_specs = [ + InputTensorSpec( + shape=init_size, + dtype=torch.float32, + shape_ranges=shape_range, + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), + input_specs, + expected_ops=expected_op, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py new file mode 100644 index 0000000000..0d4849c21f --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_where_aten.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), + ] + ) + def test_(self, _, x_size, y_size): + class Where(nn.Module): + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 + self.run_test( + Where(), + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, + ) + + +# FIXME: How to specify condition for dynamic shape +# InputTensorSpec like case below where one input is dynamic another is not +# class TestWhereConverter(DispatchTestCase): +# @parameterized.expand( +# [ +# ("2d_dim", (-1, 2), [((1, 2), (2, 2), (2, 2))], (2,2)) +# #("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), +# #("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), +# ] +# ) +# def test_where(self, _, x_size, x_size_range, y_size): +# class Where(nn.Module): +# def forward(self, condition, x, y): +# return torch.where(condition, x, y) +# inputX = InputTensorSpec( +# shape=x_size, +# dtype=torch.float32, +# shape_ranges=x_size_range, +# ) +# inputOther = torch.randn(*y_size) +# condition = (inputOther < 0) +# input_specs = [ +# inputX, inputOther, condition +# ] +# self.run_test_with_dynamic_shape( +# Where(), +# input_specs, +# expected_ops=torch.ops.aten.where.self, +# ) + +# if __name__ == "__main__": +# run_tests() diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py 
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
index e60c8f8d13..356ddc978e 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
@@ -130,7 +130,7 @@ def trace(f, args, *rest):
 
 
 @req_torch_version("2.dev")
-def opt_trace(f, args, *rest):
+def opt_trace(f, args, perform_trace=True, *rest):
     """
     Optimized trace with necessary passes which re-compose some ops or replace some ops
     These passes should be general and functional purpose
@@ -148,7 +148,11 @@ def opt_trace(f, args, *rest):
         replace_inplace_ops,  # remove it once functionalization is enabled
     ]
 
-    fx_module, _ = trace(f, args)
+    if perform_trace:
+        fx_module, _ = trace(f, args)
+    else:
+        fx_module = f
+    print(fx_module.graph)
 
     for passes in passes_list:
         pr: PassResult = passes(fx_module)
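With perform_trace=False, opt_trace now accepts a module that is already an fx.GraphModule, such as the one AOTAutograd hands to the backend in the next file, and only runs the recomposition passes on it. A minimal illustration of that precondition (symbolic_trace stands in for the dynamo-produced module):

    import torch
    import torch.fx

    def f(x):
        return torch.nn.functional.relu(x)

    gm = torch.fx.symbolic_trace(f)
    assert isinstance(gm, torch.fx.GraphModule)  # what perform_trace=False expects
    print(gm.graph)  # the same graph dump opt_trace now emits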
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
new file mode 100644
index 0000000000..e53f0bc64e
--- /dev/null
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tensorrt_dynamo_backend.py
@@ -0,0 +1,114 @@
+import torch
+import traceback
+import torch._dynamo as td
+
+from torch_tensorrt.fx.fx2trt import (
+    InputTensorSpec,
+    TRTInterpreter,
+)
+import tensorrt as trt
+from torch_tensorrt.fx.tools.trt_splitter import (
+    TRTSplitter,
+    TRTSplitterSetting,
+)
+from torch_tensorrt.fx.tracer.dispatch_tracer import aten_tracer
+from torch_tensorrt.fx.trt_module import TRTModule
+from torch_tensorrt.fx.utils import LowerPrecision
+
+from torch._dynamo.backends.common import fake_tensor_unsupported
+
+from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+
+from torch._inductor.decomposition import decompositions
+
+DECOMPOSITIONS = decompositions.copy()
+MAX_SPLITS_THRESHOLD = 100
+
+
+def tensorrt_backend(gm, sample_inputs):
+    # Invoke AOTAutograd to compile the model
+    return aot_module_simplified(
+        gm,
+        sample_inputs,
+        fw_compiler=make_boxed_compiler(fx2trt_compiler),
+        decompositions=DECOMPOSITIONS,
+    )
+
+
+def fx2trt(gm: torch.fx.GraphModule, example_inputs, **kwargs):
+    model = gm
+    inputs = example_inputs
+
+    # Perform lowering pass on model
+    model = aten_tracer.opt_trace(model, inputs, perform_trace=False)
+
+    # Split out unsupported ops --> Needs rewrite/revision for ATEN
+    splitter_setting = TRTSplitterSetting()
+    splitter_setting.use_implicit_batch_dim = False
+    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
+
+    splitter.node_support_preview()
+    split_mod = splitter()
+    num_pieces = 0
+
+    for name, _ in split_mod.named_children():
+        print(f"Graph is split into {name}")
+        num_pieces += 1
+
+    # Select threshold above which segmentation is not beneficial and run graph in Torch
+    if num_pieces > MAX_SPLITS_THRESHOLD:
+        raise AssertionError(
+            f"The graph module is split into {num_pieces} pieces, which is larger than the \
+            threshold={MAX_SPLITS_THRESHOLD}. Falling back to non-TRT module."
+        )
+
+    precision = LowerPrecision.FP32
+
+    def get_submod_inputs(mod, submod, inputs):
+        acc_inputs = None
+
+        def get_input(self, inputs):
+            nonlocal acc_inputs
+            acc_inputs = inputs
+
+        handle = submod.register_forward_pre_hook(get_input)
+        mod(*inputs)
+        handle.remove()
+        return acc_inputs
+
+    for name, _ in split_mod.named_children():
+        if "_run_on_acc" in name:
+            submod = getattr(split_mod, name)
+            acc_inputs = get_submod_inputs(split_mod, submod, inputs)
+
+            interp = TRTInterpreter(
+                submod,
+                InputTensorSpec.from_tensors(acc_inputs),
+                explicit_batch_dimension=True,
+                logger_level=trt.Logger.VERBOSE,
+            )
+            r = interp.run(
+                max_workspace_size=20 << 30,
+                lower_precision=precision,
+                profiling_verbosity=trt.ProfilingVerbosity.VERBOSE,
+            )
+
+            trt_mod = TRTModule(*r)
+
+            setattr(split_mod, name, trt_mod)
+
+    return split_mod
+
+
+@td.register_backend
+@fake_tensor_unsupported
+def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
+    try:
+        trt_compiled = fx2trt(gm, example_inputs)
+        return trt_compiled
+    except Exception:
+        traceback.print_exc()
+        print(
+            "FX2TRT conversion failed on the subgraph. See trace above. "
+            "Returning GraphModule forward instead."
+        )
+        return gm.forward
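get_submod_inputs in the backend above records each accelerated submodule's inputs with a forward pre-hook before building its engine. A self-contained sketch of that trick on a plain module:

    import torch

    mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    captured = None

    def grab(module, inputs):  # pre-hooks receive (module, positional_inputs)
        global captured
        captured = inputs

    handle = mod[1].register_forward_pre_hook(grab)
    mod(torch.randn(2, 4))
    handle.remove()
    assert captured[0].shape == (2, 4)

Since @td.register_backend keys backends by function name, something like torch.compile(model, backend="fx2trt_compiler") should select this compiler; that usage is an assumption, not shown in this patch.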
diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp
index 40835a8dea..47e8b8d154 100644
--- a/tests/core/conversion/converters/test_reduce.cpp
+++ b/tests/core/conversion/converters/test_reduce.cpp
@@ -62,7 +62,7 @@ std::string gen_keepdim_graph(const std::string& op) {
   return (%5))IR";
 }
 
-void test_body(const std::string& graph, at::Tensor& in) {
+void test_body(const std::string& graph, at::Tensor& in, bool dynamic = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph, g.get());
 
@@ -71,7 +71,12 @@ void test_body(const std::string& graph, at::Tensor& in) {
 
   in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  std::vector<at::Tensor> trt_results;
+  if (dynamic) {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in});
+  } else {
+    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+  }
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
 } // namespace
@@ -344,6 +349,50 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
   test_body(graph, in);
 }
 
+TEST(Converters, ATenAllDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimKeepDimConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=0]()
+        %3 : bool = prim::Constant[value=1]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(-2, 2, {2, 32}, at::kCUDA).to(torch::kBool);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimAllTrueConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::ones({2, 32}, at::kCUDA);
+  test_body(graph, in);
+}
+
+TEST(Converters, ATenAllDimDynamicConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int = prim::Constant[value=-1]()
+        %3 : bool = prim::Constant[value=0]()
+        %5 : Tensor = aten::all(%0, %1, %3)
+        return (%5))IR";
+  auto in = at::randint(0, 2, {64, 2}, at::kCUDA).to(torch::kHalf);
+  test_body(graph, in, true);
+}
+
 TEST(Converters, UnpackVarLowersCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor):