3 changes: 3 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -373,6 +373,9 @@ def create_convert_map(
"mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
"ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
"ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
"outer.default": lambda node: self.block_builder.emit(
relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
),
"pow.Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
3 changes: 3 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -749,6 +749,9 @@ def create_convert_map(
"mod": self._binary_op(relax.op.floor_mod, operator.mod),
"mul": self._binary_op(relax.op.multiply, operator.mul),
"ne": self._binary_op(relax.op.not_equal, operator.ne),
"outer": lambda node: self.block_builder.emit(
relax.op.outer(self.env[node.args[0]], self.env[node.args[1]])
),
"pow": self._binary_op(relax.op.power, operator.pow),
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
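
For context (this example is not part of the diff), a hedged sketch of how the new mapping is reached through the FX frontend; the module and variable names are hypothetical:

import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.outer(x, y)

# Tracing produces a call_function node whose target is torch.outer; the
# "outer" entry above dispatches it to relax.op.outer on the converted inputs.
graph_model = fx.symbolic_trace(M())
mod = from_fx(graph_model, [([3], "float32"), ([4], "float32")])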
2 changes: 1 addition & 1 deletion python/tvm/relax/op/__init__.py
@@ -84,7 +84,7 @@
)
from .datatype import astype, wrap_param
from .index import dynamic_strided_slice, strided_slice, take
from .linear_algebra import einsum, linear, matmul
from .linear_algebra import einsum, linear, matmul, outer
from .manipulate import (
broadcast_to,
collapse_sum_like,
27 changes: 27 additions & 0 deletions python/tvm/relax/op/linear_algebra.py
@@ -110,3 +110,30 @@ def einsum(operands, subscripts):
operands = RxTuple(operands)

return _ffi_api.einsum(operands, subscripts) # type: ignore


def outer(x1: Expr, x2: Expr) -> Expr:
    """
    Compute the outer product of two 1-D input expressions.

    Parameters
    ----------
    x1 : relax.Expr
        The first input expression, a 1-D tensor of shape (n,).

    x2 : relax.Expr
        The second input expression, a 1-D tensor of shape (m,).

    Returns
    -------
    result : relax.Expr
        The outer product, a 2-D tensor of shape (n, m) in which
        result[i, j] = x1[i] * x2[j].

    Notes
    -----
    This operator follows the semantics of `torch.outer`: it expands two
    1-D inputs into a 2-D result whose elements are the pairwise products
    of elements from `x1` and `x2`.
    """
    return _ffi_api.outer(x1, x2)
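
To make the semantics concrete (this example is not part of the diff), the NumPy equivalent for the shapes used in the tests below:

import numpy as np

a = np.array([1.0, 2.0, 3.0], dtype="float32")            # shape (3,)
b = np.array([10.0, 20.0, 30.0, 40.0], dtype="float32")   # shape (4,)

# relax.op.outer matches np.outer for 1-D inputs: out[i, j] = a[i] * b[j]
out = np.outer(a, b)
assert out.shape == (3, 4)
assert out[1, 2] == a[1] * b[2]  # 2.0 * 30.0 == 60.0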
19 changes: 19 additions & 0 deletions python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -115,3 +115,22 @@ def _einsum(bb: BlockBuilder, call: Call) -> Expr:
t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
)
return bb.call_te(topi.einsum, call.attrs.subscripts, *fields)


@register_legalize("relax.outer")
def _outer(bb: BlockBuilder, call: Call) -> Expr:
def te_outer(a: te.Tensor, b: te.Tensor) -> te.Tensor:
a_shape = list(a.shape)
b_shape = list(b.shape)
assert len(a_shape) == 1 and len(b_shape) == 1, "outer requires 1D tensors"

n = a_shape[0]
m = b_shape[0]

def compute_fn(i, j):
return a[i] * b[j]

return te.compute((n, m), compute_fn, name="outer")

lhs, rhs = call.args
return bb.call_te(te_outer, lhs, rhs, primfunc_name_hint="outer")
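
As a sanity-check sketch (not part of the diff), the legalized compute can be exercised end to end roughly as follows, assuming a TVM build with this patch applied; the module name and shapes are hypothetical:

import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Mod:
    @R.function
    def main(a: R.Tensor((3,), "float32"), b: R.Tensor((4,), "float32")):
        gv = R.outer(a, b)
        return gv

# LegalizeOps lowers relax.outer to the te_outer PrimFunc defined above.
mod = relax.transform.LegalizeOps()(Mod)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
a = tvm.nd.array(np.random.rand(3).astype("float32"))
b = tvm.nd.array(np.random.rand(4).astype("float32"))
np.testing.assert_allclose(
    vm["main"](a, b).numpy(), np.outer(a.numpy(), b.numpy()), rtol=1e-5
)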
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@
ones,
ones_like,
one_hot,
outer,
permute_dims,
power,
print,
@@ -826,6 +827,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"one_hot",
"opencl",
"output",
"outer",
"permute_dims",
"power",
"prim_value",
38 changes: 38 additions & 0 deletions src/relax/op/tensor/linear_algebra.cc
@@ -251,5 +251,43 @@ TVM_REGISTER_OP("relax.einsum")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEinsum)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.outer */

Expr outer(Expr x1, Expr x2) {
static const Op& op = Op::Get("relax.outer");
return Call(op, {std::move(x1), std::move(x2)}, {});
}

TVM_REGISTER_GLOBAL("relax.op.outer").set_body_typed(outer);

StructInfo InferStructInfoOuter(const Call& call, const BlockBuilder& ctx) {
auto input_sinfo = GetInputTensorStructInfo(call, ctx);
auto x1_sinfo = input_sinfo[0];
auto x2_sinfo = input_sinfo[1];

  // Ensure both inputs are 1-D tensors; inputs of unknown rank are allowed through.
  if ((!x1_sinfo->IsUnknownNdim() && x1_sinfo->ndim != 1) ||
      (!x2_sinfo->IsUnknownNdim() && x2_sinfo->ndim != 1)) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "relax.outer requires both inputs to be 1D tensors.");
  }

// Determine output shape
auto x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
auto x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
if (!x1_shape || !x2_shape) {
return TensorStructInfo(x1_sinfo->dtype, 2);
}
Array<PrimExpr> output_shape = {x1_shape->values[0], x2_shape->values[0]};
return TensorStructInfo(ShapeExpr(output_shape), x1_sinfo->dtype);
}

TVM_REGISTER_OP("relax.outer")
.set_num_inputs(2)
.add_argument("x1", "Tensor", "The first input tensor.")
.add_argument("x2", "Tensor", "The second input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOuter)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
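
A brief sketch (not part of the diff) of the inference behavior above, assuming this patch is applied; the variable names are hypothetical:

import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([3], "float32"))
y = relax.Var("y", relax.TensorStructInfo([4], "float32"))
with bb.function("main", [x, y]):
    gv = bb.emit(relax.op.outer(x, y))
    bb.emit_func_output(gv)
# gv is inferred as R.Tensor((3, 4), dtype="float32"); if either input
# shape is unknown, only the rank (2) and the dtype of x are inferred.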
8 changes: 8 additions & 0 deletions src/relax/op/tensor/linear_algebra.h
@@ -51,6 +51,14 @@ Expr matmul(Expr x1, Expr x2, Optional<DataType> out_dtype);
*/
Expr einsum(Expr operands, String subscripts);

/*!
* \brief Compute the outer product of two input expressions.
* \param x1 The first input expression.
* \param x2 The second input expression.
* \return The resulting expression representing the outer product.
*/
Expr outer(Expr x1, Expr x2);

} // namespace relax
} // namespace tvm

24 changes: 24 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -2166,6 +2166,30 @@ def main(
verify_model(Einsum2(), example_args, {}, Expected2)


def test_outer():
class Outer(torch.nn.Module):
def forward(self, x, y):
return torch.outer(x, y)

@tvm.script.ir_module
class expected:
@R.function
def main(
a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (
torch.randn(3, dtype=torch.float32),
torch.randn(4, dtype=torch.float32),
)
verify_model(Outer(), example_args, {}, expected)


def test_embedding():
class Embedding(Module):
def __init__(self):
21 changes: 21 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -874,6 +874,27 @@ def main(
verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)


def test_outer():
class Outer(torch.nn.Module):
def forward(self, x, y):
return torch.outer(x, y)

@tvm.script.ir_module
class expected:
@R.function
def main(
a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
) -> R.Tensor((3, 4), dtype="float32"):
with R.dataflow():
lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
gv: R.Tensor((3, 4), dtype="float32") = lv
R.output(gv)
return gv

input_infos = [([3], "float32"), ([4], "float32")]
verify_model(Outer(), input_infos, {}, expected)


@tvm.testing.requires_gpu
def test_softplus():
import torch