diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 42a57273af4e..3d89eb8a3ff0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -373,6 +373,9 @@ def create_convert_map(
             "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
             "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
             "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
+            "outer.default": lambda node: self.block_builder.emit(
+                relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+            ),
             "pow.Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 199e58cb1d9f..a10dbd6e02d1 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -749,6 +749,9 @@ def create_convert_map(
             "mod": self._binary_op(relax.op.floor_mod, operator.mod),
             "mul": self._binary_op(relax.op.multiply, operator.mul),
             "ne": self._binary_op(relax.op.not_equal, operator.ne),
+            "outer": lambda node: self.block_builder.emit(
+                relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
+            ),
             "pow": self._binary_op(relax.op.power, operator.pow),
             "or_": self._binary_op(relax.op.bitwise_or, operator.or_),
             "rshift": self._binary_op(relax.op.right_shift, operator.rshift),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index bfc0a997dfc8..0a2f0980fd08 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -84,7 +84,7 @@
 )
 from .datatype import astype, wrap_param
 from .index import dynamic_strided_slice, strided_slice, take
-from .linear_algebra import einsum, linear, matmul
+from .linear_algebra import einsum, linear, matmul, outer
 from .manipulate import (
     broadcast_to,
     collapse_sum_like,
diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py
index efb5085c7882..9b091195763e 100644
--- a/python/tvm/relax/op/linear_algebra.py
+++ b/python/tvm/relax/op/linear_algebra.py
@@ -110,3 +110,30 @@
         operands = RxTuple(operands)
     return _ffi_api.einsum(operands, subscripts)  # type: ignore
+
+
+def outer(x1: Expr, x2: Expr) -> Expr:
+    """
+    Computes the outer product of two input expressions.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input expression.
+
+    x2 : relax.Expr
+        The second input expression.
+
+    Notes
+    -----
+    This operation computes the outer product between two expressions,
+    resulting in a tensor where each element is the product of elements
+    from `x1` and `x2`. It is commonly used in tensor and matrix operations
+    to expand lower-dimensional inputs into higher-dimensional representations.
+
+    Returns
+    -------
+    result : relax.Expr
+        The resulting expression representing the outer product.
+    """
+    return _ffi_api.outer(x1, x2)
diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index 318c9521f31a..154afa9dffca 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -115,3 +115,22 @@ def _einsum(bb: BlockBuilder, call: Call) -> Expr:
         t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
     )
     return bb.call_te(topi.einsum, call.attrs.subscripts, *fields)
+
+
+@register_legalize("relax.outer")
+def _outer(bb: BlockBuilder, call: Call) -> Expr:
+    def te_outer(a: te.Tensor, b: te.Tensor) -> te.Tensor:
+        a_shape = list(a.shape)
+        b_shape = list(b.shape)
+        assert len(a_shape) == 1 and len(b_shape) == 1, "outer requires 1D tensors"
+
+        n = a_shape[0]
+        m = b_shape[0]
+
+        def compute_fn(i, j):
+            return a[i] * b[j]
+
+        return te.compute((n, m), compute_fn, name="outer")
+
+    lhs, rhs = call.args
+    return bb.call_te(te_outer, lhs, rhs, primfunc_name_hint="outer")
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index d1e86cc7f456..b696d73031b9 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@
     ones,
     ones_like,
     one_hot,
+    outer,
     permute_dims,
     power,
     print,
@@ -826,6 +827,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "one_hot",
     "opencl",
     "output",
+    "outer",
     "permute_dims",
     "power",
     "prim_value",
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
index 0fdbee1c6aac..4ca42bffec90 100644
--- a/src/relax/op/tensor/linear_algebra.cc
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -251,5 +251,43 @@ TVM_REGISTER_OP("relax.einsum")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEinsum)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.outer */
+
+Expr outer(Expr x1, Expr x2) {
+  static const Op& op = Op::Get("relax.outer");
+  return Call(op, {std::move(x1), std::move(x2)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer);
+
+StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) {
+  auto input_sinfo = GetInputTensorStructInfo(call, ctx);
+  auto x1_sinfo = input_sinfo[0];
+  auto x2_sinfo = input_sinfo[1];
+
+  // Ensure both inputs are 1D tensors
+  if (x1_sinfo->ndim != 1 || x2_sinfo->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "torch.outer requires both inputs to be 1D tensors.");
+  }
+
+  // Determine output shape
+  auto x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  auto x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  if (!x1_shape || !x2_shape) {
+    return TensorStructInfo(x1_sinfo->dtype, 2);
+  }
+  Array<PrimExpr> output_shape = {x1_shape->values[0], x2_shape->values[0]};
+  return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.outer")
+    .set_num_inputs(2)
+    .add_argument("x1", "Tensor", "The first input tensor.")
+    .add_argument("x2", "Tensor", "The second input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOuter)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h
index 638e5af8f87e..eb003fed1c76 100644
--- a/src/relax/op/tensor/linear_algebra.h
+++ b/src/relax/op/tensor/linear_algebra.h
@@ -51,6 +51,14 @@ Expr matmul(Expr x1, Expr x2, Optional<DataType> out_dtype);
  */
 Expr einsum(Expr operands, String subscripts);
 
+/*!
+ * \brief Compute the outer product of two input expressions.
+ * \param x1 The first input expression.
+ * \param x2 The second input expression.
+ * \return The resulting expression representing the outer product.
+ */
+Expr outer(Expr x1, Expr x2);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c375992dcaa7..e9b33ac2dcd8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2166,6 +2166,30 @@ def main(
     verify_model(Einsum2(), example_args, {}, Expected2)
 
 
+def test_outer():
+    class Outer(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(3, dtype=torch.float32),
+        torch.randn(4, dtype=torch.float32),
+    )
+    verify_model(Outer(), example_args, {}, expected)
+
+
 def test_embedding():
     class Embedding(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 643372750bd6..7695addd3f43 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -874,6 +874,27 @@ def main(
     verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
 
 
+def test_outer():
+    class Outer(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.outer(x, y)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+        ) -> R.Tensor((3, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
+                gv: R.Tensor((3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_infos = [([3], "float32"), ([4], "float32")]
+    verify_model(Outer(), input_infos, {}, expected)
+
+
 @tvm.testing.requires_gpu
 def test_softplus():
     import torch