Skip to content

Commit

Permalink
*: add support for windowing of tensors
Browse files Browse the repository at this point in the history
This commit adds support for windowing of tensors in the existing index
notation DSL. For example:

```
A(i, j) = B(i(1, 4), j) * C(i, j(5, 10))
```

causes `B` to be windowed along its first mode, and `C` to be windowed
along its second mode. In this commit any mix of windowed and
non-windowed modes are supported, along with windowing the same tensor
in different ways in the same expression. The windowing expressions
correspond to the `:` operator to slice dimensions in `numpy`.

Currently, only windowing by integers is supported.

Windowing is achieved by tying windowing information to particular
`Iterator` objects, as these are created for each `Tensor`-`IndexVar`
pair. When iterating over an `Iterator` that may be windowed, extra
steps are taken to either generate an index into the windowed space, or
to recover an index from a point in the windowed space.
  • Loading branch information
rohany committed Jan 21, 2021
1 parent 468ad7f commit 4fd7744
Show file tree
Hide file tree
Showing 14 changed files with 756 additions and 23 deletions.
80 changes: 79 additions & 1 deletion include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TACO_INDEX_NOTATION_H
#define TACO_INDEX_NOTATION_H

#include <functional>
#include <ostream>
#include <string>
#include <memory>
Expand Down Expand Up @@ -30,6 +31,7 @@ class Format;
class Schedule;

class IndexVar;
class WindowedIndexVar;
class TensorVar;

class IndexExpr;
Expand Down Expand Up @@ -228,6 +230,16 @@ class Access : public IndexExpr {
/// Returns the index variables used to index into the Access's TensorVar.
const std::vector<IndexVar>& getIndexVars() const;

/// hasWindowedModes returns true if any accessed modes are windowed.
bool hasWindowedModes() const;

/// Returns whether or not the input mode (0-indexed) is windowed.
bool isModeWindowed(int mode) const;

/// Return the {lower,upper} bound of the window on the input mode (0-indexed).
int getWindowLowerBound(int mode) const;
int getWindowUpperBound(int mode) const;

/// Assign the result of an expression to a left-hand-side tensor access.
/// ```
/// a(i) = b(i) * c(i);
Expand Down Expand Up @@ -800,11 +812,67 @@ class Multi : public IndexStmt {
/// Create a multi index statement.
Multi multi(IndexStmt stmt1, IndexStmt stmt2);

/// IndexVarInterface is a marker superclass for IndexVar-like objects.
/// It is intended to be used in situations where many IndexVar-like objects
/// must be stored together, like when building an Access AST node where some
/// of the access variables are windowed. Use cases for IndexVarInterface
/// will inspect the underlying type of the IndexVarInterface. For sake of
/// completeness, the current implementers of IndexVarInterface are:
/// * IndexVar
/// * WindowedIndexVar
/// If this set changes, make sure to update the match function.
class IndexVarInterface {
public:
virtual ~IndexVarInterface() = default;

/// match performs a dynamic case analysis of the implementers of IndexVarInterface
/// as a utility for handling the different values within. It mimics the dynamic
/// type assertion of Go.
static void match(
std::shared_ptr<IndexVarInterface> ptr,
std::function<void(std::shared_ptr<IndexVar>)> ivarFunc,
std::function<void(std::shared_ptr<WindowedIndexVar>)> wvarFunc
) {
auto iptr = std::dynamic_pointer_cast<IndexVar>(ptr);
auto wptr = std::dynamic_pointer_cast<WindowedIndexVar>(ptr);
if (iptr != nullptr) {
ivarFunc(iptr);
} else if (wptr != nullptr) {
wvarFunc(wptr);
} else {
taco_iassert("IndexVarInterface was not IndexVar or WindowedIndexVar");
}
}
};

/// WindowedIndexVar represents an IndexVar that has been windowed. For example,
/// A(i) = B(i(2, 4))
/// In this case, i(2, 4) is a WindowedIndexVar. WindowedIndexVar is defined
/// before IndexVar so that IndexVar can return objects of type WindowedIndexVar.
class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public IndexVarInterface {
public:
WindowedIndexVar(IndexVar base, int lo = -1, int hi = -1);
~WindowedIndexVar() = default;

/// getIndexVar returns the underlying IndexVar.
IndexVar getIndexVar() const;

/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
/// this index variable.
int getLowerBound() const;
int getUpperBound() const;

private:
struct Content;
std::shared_ptr<Content> content;
};

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar> {
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);

/// Returns the name of the index variable.
Expand All @@ -813,6 +881,8 @@ class IndexVar : public util::Comparable<IndexVar> {
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi);

private:
struct Content;
Expand All @@ -823,7 +893,15 @@ struct IndexVar::Content {
std::string name;
};

struct WindowedIndexVar::Content {
IndexVar base;
int lo;
int hi;
};

std::ostream& operator<<(std::ostream&, const std::shared_ptr<IndexVarInterface>&);
std::ostream& operator<<(std::ostream&, const IndexVar&);
std::ostream& operator<<(std::ostream&, const WindowedIndexVar&);

/// A suchthat statement provides a set of IndexVarRel that constrain
/// the iteration space for the child concrete index notation
Expand Down
17 changes: 17 additions & 0 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ struct AccessNode : public IndexExprNode {

TensorVar tensorVar;
std::vector<IndexVar> indexVars;

// An AccessNode carries the windowing information for an IndexVar + TensorVar
// combination. windowedModes contains the lower and upper bounds of each
// windowed mode (0-indexed).
struct Window {
int lo;
int hi;
friend bool operator==(const Window& a, const Window& b) {
return a.lo == b.lo && a.hi == b.hi;
}
};
std::map<int, Window> windowedModes;

protected:
/// Initialize an AccessNode with just a TensorVar. If this constructor is used,
/// then indexVars must be set afterwards.
explicit AccessNode(TensorVar tensorVar) : IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar) {}
};

struct LiteralNode : public IndexExprNode {
Expand Down
15 changes: 15 additions & 0 deletions include/taco/lower/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,17 @@ class Iterator : public util::Comparable<Iterator> {
/// Returns true if the iterator is defined, false otherwise.
bool defined() const;

/// Methods for querying and operating on windowed tensor modes.

/// isWindowed returns true if this iterator is operating over a window
/// of a tensor mode.
bool isWindowed() const;

/// getWindow{Lower,Upper}Bound return the {Lower,Upper} bound of the
/// window that this iterator operates over.
ir::Expr getWindowLowerBound() const;
ir::Expr getWindowUpperBound() const;

friend bool operator==(const Iterator&, const Iterator&);
friend bool operator<(const Iterator&, const Iterator&);
friend std::ostream& operator<<(std::ostream&, const Iterator&);
Expand All @@ -169,6 +180,10 @@ class Iterator : public util::Comparable<Iterator> {

Iterator(std::shared_ptr<Content> content);
void setChild(const Iterator& iterator) const;

friend class Iterators;
/// setWindowBounds sets the window bounds of this iterator.
void setWindowBounds(ir::Expr lo, ir::Expr hi);
};

/**
Expand Down
23 changes: 22 additions & 1 deletion include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,30 @@ class LowererImpl : public util::Uncopyable {
/// Create an expression to index into a tensor value array.
ir::Expr generateValueLocExpr(Access access) const;

/// Expression that evaluates to true if none of the iteratators are exhausted
/// Expression that evaluates to true if none of the iterators are exhausted
ir::Expr checkThatNoneAreExhausted(std::vector<Iterator> iterators);

/// Expression that returns the beginning of a window to iterate over
/// in a compressed iterator. It is used when operating over windows of
/// tensors, instead of the full tensor.
ir::Expr searchForStartOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end);

/// Statement that guards against going out of bounds of the window that
/// the input iterator was configured with.
ir::Stmt upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access);

/// Expression that recovers a canonical index variable from a position in
/// a windowed position iterator. A windowed position iterator iterates over
/// values in the range [lo, hi). This expression projects values in that
/// range back into the canonical range of [0, n).
ir::Expr projectWindowedPositionToCanonicalSpace(Iterator iterator, ir::Expr expr);

// projectCanonicalSpaceToWindowedPosition is the opposite of
// projectWindowedPositionToCanonicalSpace. It takes an expression ranging
// through the canonical space of [0, n) and projects it up to the windowed
// range of [lo, hi).
ir::Expr projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr);

private:
bool assemble;
bool compute;
Expand Down
83 changes: 83 additions & 0 deletions include/taco/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ class TensorBase {
/// Create an index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<IndexVar>& indices);

/// Create a possibly windowed index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);

/// Create an index expression that accesses (reads) this (scalar) tensor.
Access operator()();

Expand Down Expand Up @@ -621,6 +624,20 @@ class Tensor : public TensorBase {
template <typename... IndexVars>
Access operator()(const IndexVars&... indices);

/// The below two Access methods are used to allow users to access tensors
/// with a mix of IndexVar's and WindowedIndexVar's. This allows natural
/// expressions like
/// A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
/// to be constructed without adjusting the original API.

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const WindowedIndexVar& first, const IndexVars&... indices);

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const IndexVar& first, const IndexVars&... indices);

ScalarAccess<CType> operator()(const std::vector<int>& indices);

/// Create an index expression that accesses (reads) this tensor.
Expand All @@ -629,6 +646,15 @@ class Tensor : public TensorBase {

/// Assign an expression to a scalar tensor.
void operator=(const IndexExpr& expr);

private:
/// The _access method family is the template level implementation of
/// Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> _access(const First& first, const Rest&... rest);
std::vector<std::shared_ptr<IndexVarInterface>> _access();
template <typename... Args>
Access _access_wrapper(const Args&... args);
};

template <typename CType>
Expand Down Expand Up @@ -1084,6 +1110,63 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
return TensorBase::operator()(std::vector<IndexVar>{indices...});
}

/// The _access() methods perform primitive recursion on the input variadic template.
/// This means that each instance of the _access method matches on the first element
/// of the variadic template parameter pack, performs an "action", then recurses
/// with the remaining elements in the parameter pack through a recursive call
/// to _access. Since this is recursion, we need a base case. The empty argument
/// instance of _access returns an empty value of the desired type, in this case
/// a vector of IndexVarInterface.
template <typename CType>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
return std::vector<std::shared_ptr<IndexVarInterface>>{};
}

/// The recursive case of _access matches on the first element, and attempts to
/// create a shared_ptr out of it. It then makes a recursive call to get a
/// vector with the rest of the elements. Then, it pushes the first element onto
/// the back of the vector -- this check ensures that the type First is indeed
/// a member of IndexVarInterface.
template <typename CType>
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
auto var = std::make_shared<First>(first);
auto ret = _access(rest...);
ret.push_back(var);
return ret;
}

/// _access_wrapper just calls into _access and reverses the result to get the initial
/// order of the arguments.
template <typename CType>
template <typename... Args>
Access Tensor<CType>::_access_wrapper(const Args&... args) {
auto resultReversed = this->_access(args...);
std::vector<std::shared_ptr<IndexVarInterface>> result;
result.reserve(resultReversed.size());
for (auto it = resultReversed.rbegin(); it != resultReversed.rend(); it++) {
result.push_back(*it);
}
return TensorBase::operator()(result);
}

/// We have to case on whether the first argument is an IndexVar or a WindowedIndexVar
/// so that the template engine can differentiate between the two versions.
// TODO (rohany): I think that there is a chance here that I might not need these
// two methods if I have _access. I think that instead I would just have to remove
// the other operator() methods that also take in IndexVar... so that there isn't
// any confusion.
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const IndexVar& first, const IndexVars&... indices) {
return this->_access_wrapper(first, indices...);
}
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const WindowedIndexVar& first, const IndexVars&... indices) {
return this->_access_wrapper(first, indices...);
}

template <typename CType>
ScalarAccess<CType> Tensor<CType>::operator()(const std::vector<int>& indices) {
taco_uassert(indices.size() == (size_t)getOrder())
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ void CodeGen_C::visit(const Allocate* op) {
stream << ", ";
}
else {
stream << "malloc(";
stream << "calloc(1, ";
}
stream << "sizeof(" << elementType << ")";
stream << " * ";
Expand Down
8 changes: 8 additions & 0 deletions src/error/error_checks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultV
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
IndexVar var = readNode->indexVars[mode];
Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode);

// If this access has windowed modes, use the dimensions of those windows
// as the shape, rather than the shape of the underlying tensor.
auto a = Access(readNode);
if (a.isModeWindowed(mode)) {
dimension = Dimension(a.getWindowUpperBound(mode) - a.getWindowLowerBound(mode));
}

if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
} else {
Expand Down
Loading

0 comments on commit 4fd7744

Please sign in to comment.