diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h index f461c5640bb0..8a176cb3cee8 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tir/index_map.h @@ -70,6 +70,18 @@ class IndexMapNode : public Object { */ Array final_indices; + /*! + * \brief The inverse index map. + * + * When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. + * Otherwise, the inverse index map will be computed on the fly. + * It is the user's responsibility to ensure the correctness of the pre-defined inverse index + * map. + * + * \note ObjectRef is used here instead of IndexMap to avoid circular reference. + */ + Optional inverse_index_map; + /*! * \brief Default constructor * @@ -133,6 +145,7 @@ class IndexMapNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("initial_indices", &initial_indices); v->Visit("final_indices", &final_indices); + v->Visit("inverse_index_map", &inverse_index_map); } bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const { @@ -153,15 +166,24 @@ class IndexMapNode : public Object { class IndexMap : public ObjectRef { public: - IndexMap(Array initial_indices, Array final_indices); + /*! + * \brief The constructor + * \param initial_indices Variables representing the indices prior to remapping + * \param final_indices Expressions defining the indices after remapping. + * \param inverse_index_map The optional pre-defined inverse index map + */ + IndexMap(Array initial_indices, Array final_indices, + Optional inverse_index_map = NullOpt); /*! * \brief Create an index map from a packed function * \param ndim The number of dimensions * \param func The function to be applied + * \param inverse_index_map The optional pre-defined inverse index map * \return The created index map */ - static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func); + static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func, + Optional inverse_index_map = NullOpt); /*! 
\brief Generate the inverse mapping. * diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 12c8053e39cc..e525fc2cc31a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -271,6 +271,12 @@ class IndexMap(Object): Variables representing the indices prior to remapping. final_indices : List[PrimExpr] Expressions defining the indices after remapping. + inverse_index_map : Union[Callable, Optional[IndexMap]] + The optional pre-defined inverse index map. + When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. + Otherwise, the inverse index map will be computed on the fly. + It is the user's responsibility to ensure the correctness of the pre-defined inverse + index map. """ initial_indices: List[Var] @@ -281,11 +287,19 @@ class IndexMap(Object): # Stage.transform_layout for more details. AXIS_SEPARATOR = "axis_separator" - def __init__(self, initial_indices, final_indices): - self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices) + def __init__(self, initial_indices, final_indices, inverse_index_map): + if isinstance(inverse_index_map, Callable): + inverse_index_map = IndexMap.from_func(inverse_index_map) + self.__init_handle_by_constructor__( + _ffi_api.IndexMap, initial_indices, final_indices, inverse_index_map + ) @staticmethod - def from_func(mapping_function: Callable, ndim: Optional[int] = None): + def from_func( + mapping_function: Callable, + ndim: Optional[int] = None, + inverse_index_map: Union[Callable, Optional["IndexMap"]] = None, + ): """Create an index map from a function Parameters @@ -305,6 +319,13 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): mapping_function does not use variadic arguments, ndim is optional. + inverse_index_map : Union[Callable, Optional[IndexMap]] + The optional pre-defined inverse index map. + When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. 
+ Otherwise, the inverse index map will be computed on the fly. + It is the user's responsibility to ensure the correctness of the pre-defined inverse + index map. + Returns ------- index_map: IndexMap @@ -312,7 +333,9 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): Returns an IndexMap representing the `mapping_function`. """ - index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim) + index_map, axis_separators = IndexMap.from_func_with_separators( + mapping_function, ndim, inverse_index_map + ) assert not axis_separators, ( "The mapping_function provided to IndexMap.from_func " "may not return IndexMap.AXIS_SEPARATOR. " @@ -321,7 +344,11 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None): return index_map @staticmethod - def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None): + def from_func_with_separators( + mapping_function: Callable, + ndim: Optional[int] = None, + inverse_index_map: Union[Callable, Optional["IndexMap"]] = None, + ): """Create an index map from a function Parameters @@ -341,6 +368,13 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = mapping_function does not use variadic arguments, ndim is optional. + inverse_index_map : Union[Callable, Optional[IndexMap]] + The optional pre-defined inverse index map. + When this is defined, IndexMap::Inverse will return the pre-defined inverse index map. + Otherwise, the inverse index map will be computed on the fly. + It is the user's responsibility to ensure the correctness of the pre-defined inverse + index map. + Returns ------- ret: Tuple[IndexMap, List[int]] @@ -401,7 +435,7 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = f"Instead received {val} of type {type(val)}." 
) - return IndexMap(initial_indices, final_indices), axis_separators + return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators def is_equivalent_to(self, other_map: "IndexMap") -> bool: """Return if the index maps are equivalent. diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 0e3c3b2774c8..cceff72ec82f 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -34,20 +34,23 @@ namespace tvm { namespace tir { -IndexMap::IndexMap(Array initial_indices, Array final_indices) { +IndexMap::IndexMap(Array initial_indices, Array final_indices, + Optional inverse_index_map) { auto n = make_object(); n->initial_indices = std::move(initial_indices); n->final_indices = std::move(final_indices); + n->inverse_index_map = std::move(inverse_index_map); data_ = std::move(n); } -IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func) { +IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func, + Optional inverse_index_map) { Array initial_indices; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32))); } - return IndexMap(initial_indices, func(initial_indices)); + return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map)); } std::pair IndexMap::NonSurjectiveInverse(Array initial_ranges) const { @@ -114,6 +117,10 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia } IndexMap IndexMap::Inverse(Array initial_ranges) const { + if ((*this)->inverse_index_map.defined()) { + // return the pre-defined inverse index map if exists. + return Downcast((*this)->inverse_index_map.value()); + } // Dummy variables to represent the inverse's inputs. Array output_vars; for (size_t i = 0; i < (*this)->final_indices.size(); i++) { @@ -232,7 +239,14 @@ Array IndexMapNode::MapShape(const Array& shape, return output; } -String IndexMapNode::ToPythonString() const { +/*! 
+ * \brief Auxiliary function to convert an index map to a lambda expression in Python. + * \param initial_indices The initial indices in the index map. + * \param final_indices The final indices in the index map. + * \return The lambda expression string. + */ +std::string IndexMap2PythonLambdaExpr(const Array& initial_indices, + const Array& final_indices) { std::unordered_set used_names; Map var_remap; for (const Var& initial_index : initial_indices) { @@ -259,10 +273,28 @@ String IndexMapNode::ToPythonString() const { } oss << ": ("; for (size_t i = 0; i < final_indices.size(); ++i) { + if (i != 0) { + oss << " "; + } oss << Substitute(final_indices[i], var_remap); - oss << ", "; + oss << ","; } oss << ")"; + return oss.str(); +} + +String IndexMapNode::ToPythonString() const { + std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices); + if (!inverse_index_map.defined()) { + return String(lambda_expr); + } + // Also convert the inverse index map. + IndexMap inverse = Downcast(inverse_index_map.value()); + std::string inverse_lambda_expr = + IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices); + std::ostringstream oss; + oss << "tvm.tir.IndexMap.from_func(" << lambda_expr + << ", inverse_index_map=" << inverse_lambda_expr << ")"; return String(oss.str()); } @@ -275,8 +307,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(IndexMapNode); TVM_REGISTER_GLOBAL("tir.IndexMap") - .set_body_typed([](Array initial_indices, Array final_indices) { - return IndexMap(initial_indices, final_indices); + .set_body_typed([](Array initial_indices, Array final_indices, + Optional inverse_index_map) { + return IndexMap(initial_indices, final_indices, inverse_index_map); }); TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices") diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc index b0cafac3151f..b071b2d7e4a1 100644 --- a/src/tir/schedule/analysis/layout.cc +++ 
b/src/tir/schedule/analysis/layout.cc @@ -167,20 +167,25 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } return a.lower_factor > b.lower_factor; }); + // Compute the inverse permutation by argsort + std::vector inverse_order = order; + std::sort(inverse_order.begin(), inverse_order.end(), + [&order](int _a, int _b) -> bool { return order[_a] < order[_b]; }); // Step 5. Create the indexing mapping auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // - split_exprs = std::move(split_exprs), // - order = std::move(order), // - shape = buffer->shape, // + &split_exprs, // + &order, // + & shape = buffer->shape, // analyzer // ](Array indices) -> Array { ICHECK_EQ(indices.size(), shape.size()); for (int i = 0, n = indices.size(); i < n; ++i) { analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); } + // Step 5.1: Fuse all indices into a flattened one PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); int ndim = split_exprs.size(); - // Step 5.1. Split the flattened index according to `split_exprs` + // Step 5.2. Split the flattened index according to `split_exprs` std::vector split; split.reserve(ndim); for (int i = ndim - 1; i >= 0; --i) { @@ -190,7 +195,7 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& index = floordiv(index, extent); } std::reverse(split.begin(), split.end()); - // Step 5.2. Reorder the indexing pattern according to `order` + // Step 5.3. Reorder the indexing pattern according to `order` Array results; results.reserve(ndim); for (int i = 0; i < ndim; ++i) { @@ -198,7 +203,39 @@ Optional SuggestIndexMap(const Buffer& buffer, const Array& } return results; }; - return IndexMap::FromFunc(ndim, f_alter_layout); + // Step 6: Create the inverse index mapping. + auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape, + analyzer](Array indices) -> Array { + ICHECK_EQ(indices.size(), split_exprs.size()); + // Step 6.1: Reorder the indices according to `inverse_order`. 
This is the inverse of Step 5.3. + // After the inverse permutation, indices[i] corresponds to split_exprs[i] + Array inv_permuted_indices; + inv_permuted_indices.reserve(indices.size()); + for (int i = 0, n = indices.size(); i < n; ++i) { + const Var& index = indices[inverse_order[i]]; + inv_permuted_indices.push_back(index); + analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent))); + } + + // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2. + PrimExpr flattened_index = make_const(indices[0]->dtype, 0); + int64_t stride = 1; + for (int i = static_cast(split_exprs.size()) - 1; i >= 0; --i) { + flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index; + stride *= split_exprs[i].extent; + } + // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. + Array result; + result.reserve(shape.size()); + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i])); + flattened_index = floordiv(flattened_index, shape[i]); + result.push_back(index); + } + return Array(result.rbegin(), result.rend()); + }; + IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse); + return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); } TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 5524abbaf094..378e5183b49c 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -101,6 +101,47 @@ def test_suggest_index_map_bijective(): assert index_map.is_equivalent_to(expected_index_map) +def test_suggest_index_map_winograd(): + """use case in winograd conv where the indices are complicated""" + fused_outer, i3_3_fused, i4_0, i4_1 = _make_vars("fused_outer", "i3_3_fused", "i4_0", "i4_1") + eps = 
floordiv(fused_outer, 336) * 2 + floordiv(floormod(fused_outer, 16), 8) + nu = floordiv(floormod(fused_outer, 336), 112) * 2 + floordiv(floormod(fused_outer, 8), 4) + co = floormod(fused_outer, 4) * 32 + i3_3_fused + ci = (i4_0 * 32) + i4_1 + buffer = decl_buffer(shape=[6, 6, 128, 128]) + index_map = suggest_index_map( + buffer=buffer, + indices=[eps, nu, co, ci], + loops=_make_loops( + loop_vars=[fused_outer, i3_3_fused, i4_0, i4_1], + extents=[1008, 32, 4, 32], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda i0, i1, i2, i3: ( + floordiv(i0, 2), + floordiv(i1, 2), + floormod(i0, 2), + floormod(((i1 * 4) + floordiv(i2, 32)), 8), + floormod(i2, 32), + floordiv(i3, 32), + floormod(i3, 32), + ) + ) + assert index_map.is_equivalent_to(expected_index_map) + inverse_index_map = index_map.inverse(buffer.shape) + expected_inverse_index_map = IndexMap.from_func( + lambda i0, i1, i2, i3, i4, i5, i6: ( + ((i0 * 2) + i2), + ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)), + floormod(((i3 * 32) + i4), 128), + ((i5 * 32) + i6), + ) + ) + assert inverse_index_map.is_equivalent_to(expected_inverse_index_map) + + @tvm.script.ir_module class DenseVNNIModule: @T.prim_func