Skip to content

Commit

Permalink
[ARITH] Introduce iterator (quasi)affine map detection. (#6667)
Browse files Browse the repository at this point in the history
* [ARITH] Introduce iterator (quasi)affine map detection.

The loop transformations (split, fuse) create bijective
maps from a collection of source iterators to target iterators.

DetectIterMap is a function that detects such bijective mappings
from the lowered index expression.

We choose the term quasi affine to be consistent with the
terminology used by in polyhedral compilation.
DetectIterMap can handle symbolic integers(in split/fuse) to some extent.

The utility can be useful in detecting loop transformation
patterns and data layout change patterns in TIR.

* Update per feedback
  • Loading branch information
tqchen authored Oct 14, 2020
1 parent 78bba3c commit 60ed926
Show file tree
Hide file tree
Showing 8 changed files with 1,298 additions and 13 deletions.
277 changes: 277 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/arith/iter_affine_map.h
* \brief Iterator quasi-affine mapping patterns.
*
* This file defines a collection of mapping patterns
* maps a collection of independent iterators to another
* collection of independent iterators.
*
* There are two main kinds of mapping patterns:
*
* - Fuse: fuse a collection of iterators into a single one
*
* domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
* fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
* domain(y) = [0, 24)
*
* - Split: split an iterator into multiple ones
*
* domain(x) = [0, 24)
* split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
* domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
*
* We use the name "(quasi)affine" to be consistent with
* the terminology used in the polyhedral compilation.
* Notably, fuse is an affine transformation,
* while split corresponds to additional floordiv/mod operations
* that can appear in quasi-affine transformations.
*/
#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
#define TVM_ARITH_ITER_AFFINE_MAP_H_

#include <tvm/ir/expr.h>

namespace tvm {
namespace arith {

/*!
* \brief Base class of all iter map expressions.
*
* An IterMapExpr is a special expression to store
* the result of IterMapDetection.
* It should not appear in a legal TIR PrimFunc.
*/
class IterMapExprNode : public PrimExprNode {
public:
// overrides
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "arith.IterMapExpr";
static constexpr const uint32_t _type_child_slots = 3;
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
};

/*!
* \brief Managed reference to IterMapExprNode.
* \sa IterMapExprNode
*/
class IterMapExpr : public PrimExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode);
};

/*!
* \brief Mark the source as an iterator in [0, extent).
*
* IterMark is used to mark source expression as a valid
* iterator to make future analysis easy.
*/
class IterMarkNode : public Object {
public:
/*!
* \brief The source expression, can either be
* a IterSumExpr or a Var.
*/
PrimExpr source;
/*!
* \brief The extent of the iteration.
*/
PrimExpr extent;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("extent", &extent);
}

bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(source, other->source) && equal(extent, other->extent);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(source);
hash_reduce(extent);
}

static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const char* _type_key = "arith.IterMark";
TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object);
};

/*!
* \brief Managed reference to IterMarkExprNode.
* \sa IterMarkExprNode
*/
class IterMark : public ObjectRef {
public:
/*!
* \brief constructor.
* \param source The source expression.
* \param extent The extent of the iterator.
*/
TVM_DLL IterMark(PrimExpr source, PrimExpr extent);

TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
};

/*!
* \brief Split of an iterator.
*
* result = floormod(floordiv(source, lower_factor), extent) * scale
*/
class IterSplitExprNode : public IterMapExprNode {
public:
/*! \brief The source marked iterator. */
IterMark source;
/*! \brief The lower factor to split the source. */
PrimExpr lower_factor;
/*! \brief The extent of the split. */
PrimExpr extent;
/*! \brief Additional scale. */
PrimExpr scale;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source", &source);
v->Visit("lower_factor", &lower_factor);
v->Visit("extent", &extent);
v->Visit("scale", &scale);
}

bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const {
return equal(source, other->source) && equal(lower_factor, other->lower_factor) &&
equal(extent, other->extent) && equal(scale, other->scale);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(source);
hash_reduce(lower_factor);
hash_reduce(extent);
hash_reduce(scale);
}

static constexpr const char* _type_key = "arith.IterSplitExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode);
};

/*!
* \brief Managed reference to IterSplitExprNode.
* \sa IterSplitExprNode
*/
class IterSplitExpr : public IterMapExpr {
public:
/*!
* \brief constructor from just source.
* \param source The source expression.
*/
TVM_DLL explicit IterSplitExpr(IterMark source);
/*!
* \brief constructor
* \param source The source expression.
* \param lower_factor The lower factor to split the source.
* \param extent The extent of the split.
* \param scale The additional scaling factor.
*/
TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
PrimExpr scale);

TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode);
};

/*!
* \brief Fuse multiple iterators by summing them with scaling.
*
* result = sum(args) + base
*/
class IterSumExprNode : public IterMapExprNode {
public:
/*! \brief The args to the sum. */
Array<IterSplitExpr> args;
/*! \brief The base offset. */
PrimExpr base;

// overrides
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("args", &args);
v->Visit("base", &base);
}

bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const {
return equal(args, other->args) && equal(base, other->base);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(args);
hash_reduce(base);
}

static constexpr const char* _type_key = "arith.IterSumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode);
};

/*!
* \brief Managed reference to IterSumExprNode.
* \sa IterSumExprNode
*/
class IterSumExpr : public IterMapExpr {
public:
/*!
* \brief constructor.
* \param args The args to the sum.
* \param base The base offset.
*/
TVM_DLL IterSumExpr(Array<IterSplitExpr> args, PrimExpr base);

TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};

/*!
* \brief Detect if indices can be written as
*
* [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
*
* Here y = some-quasi-affine-iter-map(input_iters)
* and c are symbolic constants.
*
* We also requires that y_i and y_j to be independent for i != j.
*
* For returned value rv, the following is always true:
* - rv[i]->args.size() <=1: only one iterator per element.
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param analyzer Analyzer used to get context information.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
2 changes: 2 additions & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import detect_iter_map
108 changes: 108 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Iterator (quasi)affine mapping patterns."""
import tvm._ffi
from tvm.runtime import Object
from tvm.ir import PrimExpr
from . import _ffi_api


class IterMapExpr(PrimExpr):
"""Base class of all IterMap expressions."""


@tvm._ffi.register_object("arith.IterMark")
class IterMark(Object):
"""Mark the source as an iterator in [0, extent).
Parameters
----------
source : PrimExpr.
The source expression.
extent : PrimExpr
The extent of the iterator.
"""

def __init__(self, source, extent):
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)


@tvm._ffi.register_object("arith.IterSplitExpr")
class IterSplitExpr(IterMapExpr):
"""Split of an iterator.
result = floormod(floordiv(source, lower_factor), extent) * scale
Parameters
----------
source : IterMark
The source marked iterator.
lower_factor : PrimExpr
The lower factor to split the domain.
extent : PrimExpr
The extent of the split.
scale : PrimExpr
Additional scale to the split.
"""

def __init__(self, source, lower_factor, extent, scale):
self.__init_handle_by_constructor__(
_ffi_api.IterSplitExpr, source, lower_factor, extent, scale
)


@tvm._ffi.register_object("arith.IterSumExpr")
class IterSumExpr(IterMapExpr):
"""Fuse multiple iterators by summing them with scaling.
result = sum(args) + base
Parameters
----------
args : List[IterSplitExpr]
The input to the sum expression.
base : PrimExpr
The base offset.
"""

def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)


def detect_iter_map(indices, input_iters):
"""Detect if indices can be written mapped iters from input_iters.
Parameters
----------
indices : List[PrimExpr]
The input indices.
input_iters : Map[Var, Range]
The domain of each input iterators.
Returns
-------
results : List[IterSumExpr]
The iter map matching result.
Empty array if no match can be found.
"""
return _ffi_api.DetectIterMap(indices, input_iters)
Loading

0 comments on commit 60ed926

Please sign in to comment.