Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Construct the inverse in SuggestIndexMap (apache#12797)
Browse files Browse the repository at this point in the history
Computing the inverse mapping requires arithmetic analysis which is not guaranteed to cover all cases. We provide the pre-defined inverse index map instead.
  • Loading branch information
vinx13 authored and xinetzone committed Nov 25, 2022
1 parent 0e504a5 commit 5f3f89b
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 21 deletions.
26 changes: 24 additions & 2 deletions include/tvm/tir/index_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class IndexMapNode : public Object {
*/
Array<PrimExpr> final_indices;

/*!
* \brief The inverse index map.
*
* When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
* Otherwise, the inverse index map will be computed on the fly.
* It is the user's responsibility to ensure the correctness of the pre-defined inverse index
* map.
*
* \note ObjectRef is used here instead of IndexMap to avoid circular reference.
*/
Optional<ObjectRef> inverse_index_map;

/*!
* \brief Default constructor
*
Expand Down Expand Up @@ -133,6 +145,7 @@ class IndexMapNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("initial_indices", &initial_indices);
v->Visit("final_indices", &final_indices);
v->Visit("inverse_index_map", &inverse_index_map);
}

bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
Expand All @@ -153,15 +166,24 @@ class IndexMapNode : public Object {

class IndexMap : public ObjectRef {
public:
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices);
/*!
* \brief The constructor
* \param initial_indices Variables representing the indices prior to remapping
* \param final_indices Expressions defining the indices after remapping.
* \param inverse_index_map The optional pre-defined inverse index map
*/
IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map = NullOpt);

/*!
* \brief Create an index map from a packed function
* \param ndim The number of dimensions
* \param func The function to be applied
* \param inverse_index_map The optional pre-defined inverse index map
* \return The created index map
*/
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func);
static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
Optional<IndexMap> inverse_index_map = NullOpt);

/*! \brief Generate the inverse mapping.
*
Expand Down
46 changes: 40 additions & 6 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ class IndexMap(Object):
Variables representing the indices prior to remapping.
final_indices : List[PrimExpr]
Expressions defining the indices after remapping.
inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.
"""

initial_indices: List[Var]
Expand All @@ -281,11 +287,19 @@ class IndexMap(Object):
# Stage.transform_layout for more details.
AXIS_SEPARATOR = "axis_separator"

def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap, initial_indices, final_indices)
def __init__(self, initial_indices, final_indices, inverse_index_map):
if isinstance(inverse_index_map, Callable):
inverse_index_map = IndexMap.from_func(inverse_index_map)
self.__init_handle_by_constructor__(
_ffi_api.IndexMap, initial_indices, final_indices, inverse_index_map
)

@staticmethod
def from_func(mapping_function: Callable, ndim: Optional[int] = None):
def from_func(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
):
"""Create an index map from a function
Parameters
Expand All @@ -305,14 +319,23 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
mapping_function does not use variadic arguments, ndim is
optional.
inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.
Returns
-------
index_map: IndexMap
Returns an IndexMap representing the `mapping_function`.
"""
index_map, axis_separators = IndexMap.from_func_with_separators(mapping_function, ndim)
index_map, axis_separators = IndexMap.from_func_with_separators(
mapping_function, ndim, inverse_index_map
)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
"may not return IndexMap.AXIS_SEPARATOR. "
Expand All @@ -321,7 +344,11 @@ def from_func(mapping_function: Callable, ndim: Optional[int] = None):
return index_map

@staticmethod
def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = None):
def from_func_with_separators(
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
):
"""Create an index map from a function
Parameters
Expand All @@ -341,6 +368,13 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
mapping_function does not use variadic arguments, ndim is
optional.
inverse_index_map : Union[Callable, Optional[IndexMap]]
The optional pre-defined inverse index map.
When this is defined, IndexMap::Inverse will return the pre-defined inverse index map.
Otherwise, the inverse index map will be computed on the fly.
It is the user's responsibility to ensure the correctness of the pre-defined inverse
index map.
Returns
-------
ret: Tuple[IndexMap, List[int]]
Expand Down Expand Up @@ -401,7 +435,7 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
f"Instead received {val} of type {type(val)}."
)

return IndexMap(initial_indices, final_indices), axis_separators
return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators

def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
Expand Down
47 changes: 40 additions & 7 deletions src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,23 @@
namespace tvm {
namespace tir {

IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices) {
IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map) {
auto n = make_object<IndexMapNode>();
n->initial_indices = std::move(initial_indices);
n->final_indices = std::move(final_indices);
n->inverse_index_map = std::move(inverse_index_map);
data_ = std::move(n);
}

IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func) {
IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
Optional<IndexMap> inverse_index_map) {
Array<Var> initial_indices;
initial_indices.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
}
return IndexMap(initial_indices, func(initial_indices));
return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
}

std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
Expand Down Expand Up @@ -114,6 +117,10 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
}

IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
if ((*this)->inverse_index_map.defined()) {
// return the pre-defined inverse index map if exists.
return Downcast<IndexMap>((*this)->inverse_index_map.value());
}
// Dummy variables to represent the inverse's inputs.
Array<Var> output_vars;
for (size_t i = 0; i < (*this)->final_indices.size(); i++) {
Expand Down Expand Up @@ -232,7 +239,14 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
return output;
}

String IndexMapNode::ToPythonString() const {
/*!
* \brief Auxilarry function to comvert an index map to lambda expression in Python.
* \param initial_indices The initial indices in the index map.
* \param final_indices The final indices in the index map.
* \return The lambda expression string.
*/
std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
const Array<PrimExpr>& final_indices) {
std::unordered_set<std::string> used_names;
Map<Var, PrimExpr> var_remap;
for (const Var& initial_index : initial_indices) {
Expand All @@ -259,10 +273,28 @@ String IndexMapNode::ToPythonString() const {
}
oss << ": (";
for (size_t i = 0; i < final_indices.size(); ++i) {
if (i != 0) {
oss << " ";
}
oss << Substitute(final_indices[i], var_remap);
oss << ", ";
oss << ",";
}
oss << ")";
return oss.str();
}

String IndexMapNode::ToPythonString() const {
std::string lambda_expr = IndexMap2PythonLambdaExpr(initial_indices, final_indices);
if (!inverse_index_map.defined()) {
return String(lambda_expr);
}
// Also convert the inverse index map.
IndexMap inverse = Downcast<IndexMap>(inverse_index_map.value());
std::string inverse_lambda_expr =
IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
std::ostringstream oss;
oss << "tvm.tir.IndexMap.from_func(" << lambda_expr
<< ", inverse_index_map=" << inverse_lambda_expr << ")";
return String(oss.str());
}

Expand All @@ -275,8 +307,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IndexMapNode);

TVM_REGISTER_GLOBAL("tir.IndexMap")
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices) {
return IndexMap(initial_indices, final_indices);
.set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices,
Optional<IndexMap> inverse_index_map) {
return IndexMap(initial_indices, final_indices, inverse_index_map);
});

TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
Expand Down
49 changes: 43 additions & 6 deletions src/tir/schedule/analysis/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,25 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
}
return a.lower_factor > b.lower_factor;
});
// Compute the inverse permutation by argsort
std::vector<int> inverse_order = order;
std::sort(inverse_order.begin(), inverse_order.end(),
[&order](int _a, int _b) -> bool { return order[_a] < order[_b]; });
// Step 5. Create the indexing mapping
auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
split_exprs = std::move(split_exprs), //
order = std::move(order), //
shape = buffer->shape, //
&split_exprs, //
&order, //
& shape = buffer->shape, //
analyzer //
](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), shape.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
}
// Step 5.1: Fuse all indices into a flattened one
PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
int ndim = split_exprs.size();
// Step 5.1. Split the flattened index according to `split_exprs`
// Step 5.2. Split the flattened index according to `split_exprs`
std::vector<PrimExpr> split;
split.reserve(ndim);
for (int i = ndim - 1; i >= 0; --i) {
Expand All @@ -190,15 +195,47 @@ Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>&
index = floordiv(index, extent);
}
std::reverse(split.begin(), split.end());
// Step 5.2. Reorder the indexing pattern according to `order`
// Step 5.3. Reorder the indexing pattern according to `order`
Array<PrimExpr> results;
results.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
results.push_back(split[order[i]]);
}
return results;
};
return IndexMap::FromFunc(ndim, f_alter_layout);
// Step 6: Create the inverse index mapping.
auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape,
analyzer](Array<Var> indices) -> Array<PrimExpr> {
ICHECK_EQ(indices.size(), split_exprs.size());
// Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3.
// After the inverse permutation, indices[i] corresponds to split_exprs[i]
Array<Var> inv_permuted_indices;
inv_permuted_indices.reserve(indices.size());
for (int i = 0, n = indices.size(); i < n; ++i) {
const Var& index = indices[inverse_order[i]];
inv_permuted_indices.push_back(index);
analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent)));
}

// Step 6.2: Fuse all the indices. This is the inverse of Step 5.2.
PrimExpr flattened_index = make_const(indices[0]->dtype, 0);
int64_t stride = 1;
for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) {
flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index;
stride *= split_exprs[i].extent;
}
// Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1.
Array<PrimExpr> result;
result.reserve(shape.size());
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i]));
flattened_index = floordiv(flattened_index, shape[i]);
result.push_back(index);
}
return Array<PrimExpr>(result.rbegin(), result.rend());
};
IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse);
return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map);
}

TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,47 @@ def test_suggest_index_map_bijective():
assert index_map.is_equivalent_to(expected_index_map)


def test_suggest_index_map_winograd():
"""use case in winograd conv where the indices are complicated"""
fused_outer, i3_3_fused, i4_0, i4_1 = _make_vars("fused_outer", "i3_3_fused", "i4_0", "i4_1")
eps = floordiv(fused_outer, 336) * 2 + floordiv(floormod(fused_outer, 16), 8)
nu = floordiv(floormod(fused_outer, 336), 112) * 2 + floordiv(floormod(fused_outer, 8), 4)
co = floormod(fused_outer, 4) * 32 + i3_3_fused
ci = (i4_0 * 32) + i4_1
buffer = decl_buffer(shape=[6, 6, 128, 128])
index_map = suggest_index_map(
buffer=buffer,
indices=[eps, nu, co, ci],
loops=_make_loops(
loop_vars=[fused_outer, i3_3_fused, i4_0, i4_1],
extents=[1008, 32, 4, 32],
),
predicate=True,
)
expected_index_map = IndexMap.from_func(
lambda i0, i1, i2, i3: (
floordiv(i0, 2),
floordiv(i1, 2),
floormod(i0, 2),
floormod(((i1 * 4) + floordiv(i2, 32)), 8),
floormod(i2, 32),
floordiv(i3, 32),
floormod(i3, 32),
)
)
assert index_map.is_equivalent_to(expected_index_map)
inverse_index_map = index_map.inverse(buffer.shape)
expected_inverse_index_map = IndexMap.from_func(
lambda i0, i1, i2, i3, i4, i5, i6: (
((i0 * 2) + i2),
((i1 * 2) + floordiv(((i3 * 32) + i4), 128)),
floormod(((i3 * 32) + i4), 128),
((i5 * 32) + i6),
)
)
assert inverse_index_map.is_equivalent_to(expected_inverse_index_map)


@tvm.script.ir_module
class DenseVNNIModule:
@T.prim_func
Expand Down

0 comments on commit 5f3f89b

Please sign in to comment.