diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h new file mode 100644 index 000000000000..93fc9a36c5dc --- /dev/null +++ b/include/tvm/relax/nested_msg.h @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/nested_msg.h + * \brief Helper container to store nested message for robust tuple-aware analysis. + * + * Please see NestedMsg for description of usage. + * + * \sa NestedMsg + */ +#ifndef TVM_RELAX_NESTED_MSG_H_ +#define TVM_RELAX_NESTED_MSG_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Container that stores possibly nested message with leaf message type T. + * + * NestedMsg is a helper structure to store intermediate + * message state in pass analysis so we can robustly handle message + * passing with the presence of nested tuple types. + * + * Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]]. + * Each nested message corresponds to the same nesting structure as + * the nested tuple types when we encounter them in analysis. + * + * Relax support nested tuple structures in the IR. Nested tuple structure + * is important to support advanced groupings in cases such as gradient calculation + * and other scenarios. + * + * The possible presence of nested tuple does mean that we need to + * to robustly handle analysis that contains nested tuple structures + * in a dataflow graph. + * + * \code + * + * v1 = relu(v0) + * v2 = exp(v0) + * t = ((v0, v1), (v2,), v0) + * t1 = t[0] + * v3 = concat(t1) + * v4 = t[2] + * v5 = add(v4, v3) + * + * \endcode + * + * Consider the above code sequence that contains a mixture of tuple + * nesting and normal operations. A common message-passing-based analysis + * will track messages attached to each intermediate variable. + * + * Because the intermediate value can contain nested-tuples, we need to have + * abilities to nest messages according to tuple structure and propagate them + * along the way. In python, this simply corresponds to using a tuple to hold + * nested messages. This class provides a helper wrapper in C++ to present such + * possibly nested message for a given leaf message. + * + * This design pattern is necessary to handle tuple values regardless of + * the normal form design of the IR to enable different messages for each + * tuple component without enforcing all tuple elements to have the same message. + * + * Please consider the following patterns in our pass: + * + * On a forward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan forward + * - input_msg = [MapToNestedMsg(x, lookup_map) for x in call->args] + * - output_msg = ForwardProp[call->op](input_msg, call) + * - map[binding->var] = output_msg + * - Use MapToNestedMsg to remap the remaining body. + * + * On a backward propagation message passing analysis: + * - Create map [leafnode=>NestedMsg], scan backward + * - output_msg = lookup map(binding->var) + * - handle case when output_msg is null + * - input_msg = BackProp[call->op](out_msg, call) + * - for arg, msg in zip(call->args, input_msg), + * DecomposeNestedMessage(arg, msg, lambda node, m: update_map(node, m)) + * - update_map(node, m) => CombineNestedMessage(map[node], m) + * + * Here leafnode is a node that you would like to propagate messages to + * such as constant, var and should not include tuple. + * + * We also recommend writing unit-test cases that involve nested tuple composition + * and decomposition. + * + * \sa MapToNestedMsg, DecomposeNestedMsg, CombineNestedMsg, ForEachLeaf, Equal + * + * \note If you want to write robust message passing-based analysis for + * programs that can contain nested tuples, you likely need to + * use this class or logic of a similar kind. + */ +template +class NestedMsg : public ObjectRef { + public: + // default constructors. + NestedMsg() = default; + NestedMsg(const NestedMsg&) = default; + NestedMsg(NestedMsg&&) = default; + NestedMsg& operator=(const NestedMsg&) = default; + NestedMsg& operator=(NestedMsg&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already satisfies the constraint + * \param ptr + */ + explicit NestedMsg(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + NestedMsg(runtime::NullOptType) {} // NOLINT(*) + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit NestedMsg(std::nullptr_t) {} + NestedMsg& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + NestedMsg(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // Array> handling + NestedMsg(Array, void> other) // NOLINT(*) + : ObjectRef(std::move(other)) {} + NestedMsg& operator=(Array, void> other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + + // initializer list handling + NestedMsg(std::initializer_list> other) // NOLINT(*) + : NestedMsg(Array, void>(other)) {} + NestedMsg& operator=(std::initializer_list> other) { + return operator=(Array, void>(other)); + } + + // delete the int constructor + // since NestedMsg(0) is ambiguous + // 0 can be implicitly casted to nullptr_t + explicit NestedMsg(int val) = delete; + NestedMsg& operator=(int val) = delete; + // operator overloadings + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + /*! \return Whether the nested message is not-null leaf value */ + bool IsLeaf() const { return data_ != nullptr && data_->IsInstance(); } + + /*! \return Whether the nested message is null */ + bool IsNull() const { return data_ == nullptr; } + + /*! \return Whether the nested message is nested */ + bool IsNested() const { return data_ != nullptr && data_->IsInstance(); } + + /*! + * \return The underlying leaf value. + * \note This function checks if the msg is leaf. + */ + T LeafValue() const { + ICHECK(IsLeaf()); + return T(data_); + } + + /*! + * \return a corresponding nested array. + * \note This checks if the underlying data type is array. + */ + Array, void> NestedArray() const { + ICHECK(IsNested()); + return Array, void>(data_); + } + + using ContainerType = Object; + using LeafContainerType = typename T::ContainerType; + + static_assert(std::is_base_of::value, "NestedMsg is only defined for ObjectRef."); + + static constexpr bool _type_is_nullable = true; +}; + +/*! + * \brief Apply fvisit for each leaf elements in the nested message. + * \param fvisit The visit callback. + * \param msg The input nested message. + * \tparam T the content type of nested msg + * \tparam FType the visitor type with signature void fvisit(T) + */ +template +void ForEachLeaf(const NestedMsg& msg, FType fvisit) { + if (msg == nullptr) return; + if (msg.IsLeaf()) { + fvisit(msg.LeafValue()); + } else { + for (NestedMsg x : msg.NestedArray()) { + ForEachLeaf(x, fvisit); + } + } +} + +/*! + * \brief Recursively compare two nested messages. + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fequal The equal functor with signature bool fequal(T, T) + * \tparam T the content type of nested msg + * \tparam FType the equal comparator type + */ +template +bool Equal(const NestedMsg& lhs, const NestedMsg& rhs, FType fequal) { + if (lhs.IsNull()) return rhs.IsNull(); + if (rhs.IsNull()) return lhs.IsNull(); + if (lhs.IsLeaf()) { + return rhs.IsLeaf() && fequal(lhs.LeafValue(), rhs.LeafValue()); + } else { + if (!rhs.IsNested()) return false; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + if (arr_lhs.size() != arr_rhs.size()) return false; + for (size_t i = 0; i < arr_lhs.size(); ++i) { + if (!Equal(arr_lhs[i], arr_rhs[i], fequal)) return false; + } + return true; + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive tuples and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. + * + * The nesting structure will corresponds to the tuple structure. + * + * \param expr The input expression. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(Expr expr, FType fmapleaf) { + if (auto* tuple = expr.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (Expr x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map structinfo with possible nested-sinfo to nested message. + * + * This function will unpack recursive sinfo and run fmapleaf for each leaf, + * then recursively combines the results together into a NestedMsg. + * + * The nesting structure will corresponds to the tuple structure. + * + * \param sinfo The input struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmap(StructInfo)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsg(StructInfo sinfo, FType fmapleaf) { + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (StructInfo x : tuple->fields) { + res.push_back(MapToNestedMsg(x, fmapleaf)); + } + return res; + } else { + return fmapleaf(sinfo); + } +} + +/*! + * \brief Map expr with possible nested-tuple to nested message. + * + * This function will unpack recursive expr by its struct info and + * run fmapleaf for each leaf, then recursively combines the results + * together into a NestedMsg. + * + * The nesting structure will corresponds to the struct info of expr. + * + * \param expr The input expression which should have struct info. + * \param fmapleaf The mapping function for each leaf with signature `NestedMsg fmapleaf(Expr)` + * \tparam T the content type of nested msg + * \tparam FType The mapping function type + */ +template +NestedMsg MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) { + auto sinfo = GetStructInfo(expr); + if (auto* tuple = sinfo.as()) { + Array> res; + res.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = expr.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(expr, i); + UpdateStructInfo(field, tuple->fields[i]); + } + res.push_back(MapToNestedMsgBySInfo(field, fmapleaf)); + } + return res; + } else { + return fmapleaf(expr); + } +} + +/*! + * \brief Map nested message back to the expr. + * + * This function will decompose the nested message and + * run fmapleaf for each leaf message and get the leaf expr, + * then recursively combines the results as tuple expr. + * + * \param msg The input nested message. + * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional)`. + * \tparam T the content type of nested msg. + * \tparam FType The mapping function type. + */ +template +Expr NestedMsgToExpr(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return fmapleaf(NullOpt); + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array subexpr; + subexpr.reserve(arr.size()); + for (size_t i = 0; i < arr.size(); ++i) { + subexpr.push_back(NestedMsgToExpr(arr[i], fmapleaf)); + } + Optional simplified_tuple; + bool simplified_flag = false; + if (subexpr.size() >= 1) { + simplified_flag = true; + for (size_t i = 0; i < subexpr.size() && simplified_flag; ++i) { + auto* node = subexpr[i].as(); + if (node == nullptr || node->index != static_cast(i)) { + simplified_flag = false; + } else { + if (simplified_tuple.defined()) { + simplified_flag &= (simplified_tuple == node->tuple); + } else { + simplified_tuple = node->tuple; + ICHECK(simplified_tuple.defined()); + } + } + } + } + return simplified_flag ? simplified_tuple.value() : Tuple(subexpr); + } +} + +/*! + * \brief Recursively combine two nested message into one. + * + * This function requires the two messages to be compatible with each other. + * The combination rule is as follows: + * - combine(null, msg) => msg + * - combine(leaf1, leaf2) => fcombine(leaf1, leaf2) + * - combine(array1, array2) => [combine(x, y) for x, y in zip(array1, array2)] + * - This function will throw an error if array have different size + * + * \param lhs The left operand. + * \param rhs The right operand. + * \param fcombine with signature T fcombine(T lhs, T rhs) + * \tparam T the content type of nested msg + * \tparam FType combine function type. + */ +template +NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine) { + if (lhs.IsNull()) return rhs; + if (rhs.IsNull()) return lhs; + + if (lhs.IsLeaf()) { + ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested"; + return NestedMsg(fcombine(lhs.LeafValue(), rhs.LeafValue())); + } else { + ICHECK(lhs.IsNested()); + ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested"; + Array> arr_lhs = lhs.NestedArray(); + Array> arr_rhs = rhs.NestedArray(); + ICHECK_EQ(arr_lhs.size(), arr_rhs.size()) + << "Cannot combine two nested array with different sizes"; + Array> res; + res.reserve(arr_lhs.size()); + for (size_t i = 0; i < arr_lhs.size(); ++i) { + res.push_back(CombineNestedMsg(arr_lhs[i], arr_rhs[i], fcombine)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively map a nested message to another one, with leaf mapped by the input fmapleaf. + * \param msg The nested message to be mapped. + * \param fmapleaf The leaf map function, with signature NestedMsg fmapleaf(T msg) + * \tparam T The content type of nested message. + * \tparam FType The leaf map function type. + * \return The new nested message. + */ +template +NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return msg; + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array> res; + res.reserve(arr.size()); + for (int i = 0; i < static_cast(arr.size()); ++i) { + res.push_back(MapNestedMsg(arr[i], fmapleaf)); + } + return NestedMsg(res); + } +} + +/*! + * \brief Recursively decompose the tuple structure in expr and msg along with it. + * + * This function will call fvisitleaf for each leaf expression in expr. + * This function will throw an error if the nesting structure in msg does not + * match the tuple nesting structure in expr. + * + * \param expr The input expression to be decomposed. + * \param msg The input nested message. + * \param fvisitleaf with signature fvisitleaf(Expr expr, NestedMsg msg) + * \tparam T the content type of nested msg + * \tparam FType The visit function type. + */ +template +void DecomposeNestedMsg(Expr expr, NestedMsg msg, FType fvisitleaf) { + if (auto* tuple = expr.as()) { + ICHECK(msg.IsNested()) << "Expected nested to match tuple"; + Array> arr = msg.NestedArray(); + ICHECK_EQ(arr.size(), tuple->fields.size()) << "Expected nested array size to match tuple size"; + for (size_t i = 0; i < arr.size(); ++i) { + DecomposeNestedMsg(tuple->fields[i], arr[i], fvisitleaf); + } + } else { + fvisitleaf(expr, msg); + } +} + +/*! + * \brief Recursively transform the tuple structure in expr and msgs along with it. + * + * This function will call ftransleaf for each leaf expression in expr. + * This function will throw an error if the nesting structure in msg does not + * match the tuple nesting structure in expr. + * + * \param expr The input expression to be transform.  + * \param msgs The input messages to guide the transformation. + * \param ftransleaf with signature ftransleaf(Expr, Array>)->Expr + * \tparam T the content type of nested msg + * \tparam N the number of messages + * \tparam FType The visit function type. + */ +template +Expr TransformTupleLeaf(Expr expr, std::array, N> msgs, FType ftransleaf) { + StructInfo sinfo = GetStructInfo(expr); + if (const auto* tuple = sinfo.as()) { + std::array>, N> msg_arrays; + for (size_t i = 0; i < N; ++i) { + ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple"; + msg_arrays[i] = msgs[i].NestedArray(); + } + bool same = true; + Array fields; + fields.reserve(tuple->fields.size()); + for (size_t i = 0; i < tuple->fields.size(); ++i) { + Expr field; + if (const auto* expr_tuple = expr.as()) { + field = expr_tuple->fields[i]; + } else { + field = TupleGetItem(expr, i); + UpdateStructInfo(field, tuple->fields[i]); + } + std::array, N> sub_msgs; + for (size_t j = 0; j < N; ++j) { + sub_msgs[j] = msg_arrays[j][i]; + } + fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf)); + same &= (fields.back().same_as(field)); + } + return same ? expr : Tuple(fields); + } else { + for (const auto& msg : msgs) { + ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple"; + } + return ftransleaf(expr, msgs); + } +} + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_NESTED_MSG_H_ diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc new file mode 100644 index 000000000000..48af552007fd --- /dev/null +++ b/tests/cpp/nested_msg_test.cc @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace tvm; +using namespace tvm::runtime; +using namespace tvm::relax; + +TEST(NestedMsg, Basic) { + // start with no annotation + relax::Var x("x", NullOpt), y("y", NullOpt); + + // constructor from array, T and nullopt. + NestedMsg msg({x, NullOpt, x}); + + EXPECT_TRUE(msg.IsNested()); + EXPECT_FALSE(msg.IsLeaf()); + EXPECT_TRUE(msg != nullptr); + + EXPECT_ANY_THROW(msg.LeafValue()); + + auto arr = msg.NestedArray(); + EXPECT_TRUE(arr[0].same_as(x)); + EXPECT_TRUE(arr[1] == nullptr); + EXPECT_TRUE(arr[1].IsNull()); + + EXPECT_TRUE(arr[2].LeafValue().same_as(x)); + + auto a0 = arr[0]; + EXPECT_TRUE(a0.IsLeaf()); + + // assignment + // assign null + a0 = NullOpt; + EXPECT_TRUE(a0 == nullptr); + + // assign array + a0 = {x, {x, NullOpt, y}}; + EXPECT_TRUE(a0.IsNested()); + auto t0 = a0.NestedArray()[1]; + EXPECT_TRUE(t0.IsNested()); + EXPECT_TRUE(t0.NestedArray()[2].same_as(y)); + + // assign leaf + a0 = x; + + EXPECT_TRUE(a0.IsLeaf()); + EXPECT_TRUE(a0.same_as(x)); +} + +TEST(NestedMsg, ForEachLeaf) { + relax::Var x("x", NullOpt), y("y", NullOpt); + NestedMsg msg = {x, {x, y}, NullOpt, {x, {x, y}}}; + + int x_count = 0, y_count = 0; + + ForEachLeaf(msg, [&](const Expr& v) { + if (v.same_as(x)) ++x_count; + if (v.same_as(y)) ++y_count; + }); + EXPECT_EQ(x_count, 4); + EXPECT_EQ(y_count, 2); +} + +TEST(NestedMsg, Equal) { + relax::Var x("x", NullOpt), y("y", NullOpt); + relax::Var z("z", NullOpt); + + auto fequal = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); }; + + using M = NestedMsg; + + EXPECT_TRUE(Equal(M(NullOpt), M(NullOpt), fequal)); + + EXPECT_TRUE(Equal(M(x), M(x), fequal)); + + EXPECT_TRUE(Equal(M({x, y}), M({x, y}), fequal)); + + EXPECT_TRUE(Equal(M({x, NullOpt}), M({x, NullOpt}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}}), fequal)); + + EXPECT_TRUE(Equal(M({x, {NullOpt, y}, {x, z}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + // type mismatch + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, x}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}, {x, NullOpt}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}, {x, z}}), fequal)); + + EXPECT_FALSE(Equal(M(x), M(NullOpt), fequal)); + + EXPECT_FALSE(Equal(M(NullOpt), M(x), fequal)); + + EXPECT_FALSE(Equal(M(x), M(Array({x})), fequal)); + + EXPECT_FALSE(Equal(M(Array({x})), M(x), fequal)); +} + +TEST(NestedMsg, MapAndDecompose) { + relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16))); + relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32))); + relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64))); + + BlockBuilder bb = BlockBuilder::Create(NullOpt); + relax::Expr t0 = bb->Normalize(Tuple({x, y})); + relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0})); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + auto output = MapToNestedMsg(t1, [&](Expr value) { + if (value.same_as(x)) return c0; + if (value.same_as(y)) return c1; + return c2; + }); + + NestedMsg expected = {{c0, c1}, c0, c2, {c0, c1}}; + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + auto output2 = + MapToNestedMsg(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg { + const auto* prim_sinfo = sinfo.as(); + if (prim_sinfo == nullptr) return NullOpt; + int bits = prim_sinfo->dtype.bits(); + if (bits == 16) return c0; + if (bits == 32) return c1; + if (bits == 64) return c2; + return NullOpt; + }); + + EXPECT_TRUE(Equal(output2, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); + + int x_count = 0, y_count = 0, z_count = 0; + + DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg msg) { + if (value.same_as(x)) { + EXPECT_TRUE(msg.same_as(c0)); + ++x_count; + } else if (value.same_as(y)) { + EXPECT_TRUE(msg.same_as(c1)); + ++y_count; + } else { + EXPECT_TRUE(msg.same_as(c2)); + ++z_count; + } + }); + EXPECT_EQ(x_count, 3); + EXPECT_EQ(y_count, 2); + EXPECT_EQ(z_count, 1); +} + +TEST(NestedMsg, MapToNestedMsgBySInfo) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + auto sf2 = TupleStructInfo({sf0, sf0}); + auto x = relax::Var("x", TupleStructInfo({sf1, sf2, sf0})); + + auto msg = MapToNestedMsgBySInfo(x, [](Expr value) { return value; }); + + EXPECT_TRUE(msg.IsNested()); + auto arr = msg.NestedArray(); + + EXPECT_TRUE(arr[1].IsNested()); + auto arr1 = arr[1].NestedArray(); + + EXPECT_TRUE(arr1[0].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr1[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0))); + + EXPECT_TRUE(arr[2].IsLeaf()); + EXPECT_TRUE(StructuralEqual()(arr[2].LeafValue(), TupleGetItem(x, 2))); +} + +TEST(NestedMsg, NestedMsgToExpr) { + auto sf0 = TensorStructInfo(DataType::Float(32), /*ndim=*/0); + auto sf1 = TupleStructInfo({sf0, sf0}); + + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + relax::Var x("x", sf0), y("y", sf0), z("z", sf0); + + NestedMsg msg = {c0, {c0, c1}, {c0, {c1, c2}}}; + auto expr = NestedMsgToExpr(msg, [&](Optional leaf) { + ICHECK(leaf.defined()); + int value = leaf.value().IntValue(); + switch (value) { + case 0: + return x; + case 1: + return y; + default: + return z; + } + }); + + Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})}); + EXPECT_TRUE(StructuralEqual()(expr, expected)); + + // test simplified + relax::Var t("t", sf1); + NestedMsg msg1 = {TupleGetItem(t, 0), TupleGetItem(t, 1)}; + auto expr1 = NestedMsgToExpr(msg1, [](Optional leaf) { return leaf.value(); }); + EXPECT_TRUE(StructuralEqual()(expr1, t)); +} + +TEST(NestedMsg, CombineNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + + NestedMsg lhs = {c0, {c0, c1}, NullOpt, {c0, {c1, c2}}}; + NestedMsg rhs = {c1, {c2, NullOpt}, NullOpt, {c1, {c2, c2}}}; + NestedMsg expected = {c1, {c2, c1}, NullOpt, {c1, {c2, c2}}}; + + auto output = CombineNestedMsg(lhs, rhs, [](Integer x, Integer y) { + if (x->value > y->value) return x; + return y; + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, MapNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + auto c3 = Integer(3); + + NestedMsg msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}}; + + auto output = MapNestedMsg(msg, [](Integer x) { + if (x->value == 0) { + return NestedMsg(Integer(3)); + } else if (x->value == 1) { + return NestedMsg(); + } else { + return NestedMsg(x); + } + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + +TEST(NestedMsg, TransformTupleLeaf) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + using NInt = NestedMsg; + + NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}}; + NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}}; + + PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32)); + relax::Var x("x", s), y("y", s), z("z", s); + BlockBuilder bb = BlockBuilder::Create(NullOpt); + Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})})); + + auto ftransleaf = [&](Expr value, std::array msgs) -> Expr { + int lhs = Downcast(msgs[0].LeafValue())->value; + int rhs = Downcast(msgs[1].LeafValue())->value; + if (lhs > rhs) + return z; + else if (lhs == rhs) + return value; + else + return y; + }; + + Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})}); + + EXPECT_TRUE(StructuralEqual()( + TransformTupleLeaf(expr, std::array({msg1, msg2}), ftransleaf), expected)); + + EXPECT_TRUE( + expr.same_as(TransformTupleLeaf(expr, std::array({msg1, msg1}), ftransleaf))); +}