diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index cfbd44529515..ce7a425c94f9 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -394,6 +394,15 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span()); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); +/*! + * \brief Compute log(exp(a) + exp(b)). + * + * \param a Left operand. + * \param b Right operand. + * \param span The location of this operation in the source. + * \return The result expression. + */ +TVM_DLL PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute ceil(a / b) * @@ -404,6 +413,7 @@ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); * \note this function does eager constant folding for * index types(int32, int64) when possible. */ + TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute the remainder of floordiv @@ -1071,6 +1081,7 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(logaddexp); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*) diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index d27b6f1a3cfe..9be7256b446e 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -257,6 +257,22 @@ TOPI_DEFINE_BCAST_OP(floor_divide, { } }); +/*! + * \fn log_add_exp + * \brief Compute log(exp(A) + exp(B)) with auto-broadcasting. + * + * Two-operand form of log-sum-exp, common in probabilistic and statistical models. + * It is evaluated as written (no max-shift), so exp() may overflow for large inputs. 
+ * + * \param A The first input tensor, or Expr. + * \param B The second input tensor, or Expr. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return The computed log-sum-exp result. + */ +TOPI_DEFINE_BCAST_OP(log_add_exp, { return logaddexp(a, b); }); + /*! * \fn trunc divide * \brief Compute trunc(A / B) with auto-broadcasting. diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 97ccc6393cbb..733649264ae1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -295,6 +295,7 @@ def create_convert_map( "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), "floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv), + "logaddexp.default": self._binary_op(relax.op.log_add_exp, torch.logaddexp), "ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge), "ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge), "gt.Scalar": self._binary_op(relax.op.greater, operator.gt), diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97f18a239640..ddfdfc2b05d8 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -50,6 +50,7 @@ divide, equal, floor_divide, + log_add_exp, floor_mod, greater, greater_equal, diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index 7a41c8b0953c..d18aac863535 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -85,6 +85,25 @@ def floor_divide(x1: Expr, x2: Expr) -> Expr: return _ffi_api.floor_divide(x1, x2) # type: ignore +def log_add_exp(x1: Expr, x2: Expr) -> Expr: + """ + Compute the log of the sum of exponentials of the inputs, element-wise. 
+ + Parameters + ---------- + x1 : Expr + The first input tensor. + x2 : Expr + The second input tensor. + + Returns + ------- + Expr + The element-wise log-sum-exp of `x1` and `x2`. + """ + return _ffi_api.log_add_exp(x1, x2) + + def multiply(x1: Expr, x2: Expr) -> Expr: """Multiplication with numpy-style broadcasting. diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py index 41e317f1e0ef..1acbddb2190b 100644 --- a/python/tvm/relax/transform/legalize_ops/binary.py +++ b/python/tvm/relax/transform/legalize_ops/binary.py @@ -44,6 +44,7 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.add", _binary(topi.add)) register_legalize("relax.divide", _binary(topi.divide)) register_legalize("relax.floor_divide", _binary(topi.floor_divide)) +register_legalize("relax.log_add_exp", _binary(topi.log_add_exp)) register_legalize("relax.multiply", _binary(topi.multiply)) register_legalize("relax.power", _binary(topi.power)) register_legalize("relax.subtract", _binary(topi.subtract)) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ddc534cf6086..6fa3cc61cbbc 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -112,6 +112,7 @@ less_equal, linear, log, + log_add_exp, logical_and, logical_not, logical_or, @@ -794,6 +795,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "less_equal", "linear", "log", + "log_add_exp", "logical_and", "logical_not", "logical_or", diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index b31853bea666..362419bebf58 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -24,7 +24,7 @@ from tvm.tir import asin, asinh, acos, acosh, atan, atanh from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf -from tvm.tir import div, indexdiv, indexmod, 
truncdiv, truncmod, floordiv, floormod +from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, logaddexp from tvm.tir import comm_reducer, min, max, sum from tvm.tir import add, subtract, multiply diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 4f56ec3c15bc..5ceb48127038 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -90,7 +90,7 @@ from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else from .op import likely, isnan, isnullptr, isfinite, isinf, copysign -from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv +from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp from .op import comm_reducer, min, max, sum from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 53c92fff86dc..3770a8be5fd2 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3221,6 +3221,28 @@ def floordiv(a, b, span=None): return _ffi_api._OpFloorDiv(a, b, span) # type: ignore +def logaddexp(a, b, span=None): + """Compute log(exp(a) + exp(b)) for two floating-point expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + """ + return _ffi_api._OpLogAddExp(a, b, span) # type: ignore + + def floormod(a, b, span=None): """Compute the floormod of two expressions. 
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index 2b350ff817d9..e2982ecfc21b 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -135,6 +135,25 @@ def floor_divide(lhs, rhs): return _cpp.floor_divide(lhs, rhs) +def log_add_exp(lhs, rhs): + """Log-sum-exp operation with auto-broadcasting. + + Parameters + ---------- + lhs : tvm.te.Tensor or Expr + The first input tensor or expression. + rhs : tvm.te.Tensor or Expr + The second input tensor or expression. + + Returns + ------- + ret : tvm.te.Tensor or Expr + Returns an Expr if both operands are Expr. + Otherwise, returns a Tensor. + """ + return _cpp.log_add_exp(lhs, rhs) + + def mod(lhs, rhs): """Modulus with auto-broadcasting diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 4a63993d507c..e7fab8f166e1 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -193,6 +193,7 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index b66eb96f8452..6b106f760d5f 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -70,6 +70,9 @@ Expr divide(Expr x1, Expr x2); /*! \brief Floor division with numpy-style broadcasting. */ Expr floor_divide(Expr x1, Expr x2); +/*! \brief Log Add Exponent with numpy-style broadcasting. */ +Expr log_add_exp(Expr x1, Expr x2); + /*! \brief Multiplication with numpy-style broadcasting. 
*/ Expr multiply(Expr x1, Expr x2); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 46c15cb3dfc3..47aecf480988 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -507,6 +507,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { return tir::FloorDiv(a, b, span); } +PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) { + ICHECK(a.dtype().is_float()) << a; + ICHECK(b.dtype().is_float()) << b; + BinaryOpMatchTypes(a, b, span); + PrimExpr exp_sum = add(exp(a), exp(b)); + PrimExpr log_exp_sum = log(exp_sum); + return log_exp_sum; +} + PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; @@ -1134,6 +1143,7 @@ REGISTER_MAKE_BINARY_OP(_OpMod, truncmod); REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv); REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod); REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv); +REGISTER_MAKE_BINARY_OP(_OpLogAddExp, logaddexp); REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod); REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv); REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod); diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index f6a28c7722af..2105172aed40 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -52,6 +52,7 @@ TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply); TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide); TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide); +TOPI_REGISTER_BCAST_OP("topi.log_add_exp", topi::log_add_exp); TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod); TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod); TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 98f0f1d9cac6..46029c856e5c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ 
b/tests/python/relax/test_frontend_from_exported_program.py @@ -682,6 +682,32 @@ def main( verify_model(LeakyReLU1(), example_args, {}, expected) +def test_logaddexp(): + class LogAddExp(Module): + def forward(self, input1, input2): + return torch.logaddexp(input1, input2) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + input_2: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log_add_exp(input_1, input_2) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 10, 10, dtype=torch.float32), + torch.randn(1, 3, 10, 10, dtype=torch.float32), + ) + verify_model(LogAddExp(), example_args, {}, expected) + + def test_logsoftmax(): class LogSoftmax(Module): def __init__(self):