Skip to content

Commit

Permalink
Merge pull request #372 from rohany/windowing-different-notation
Browse files Browse the repository at this point in the history
*: add support for windowing of tensors
  • Loading branch information
stephenchouca authored Feb 12, 2021
2 parents 48824d5 + 2d4a7d3 commit baf174c
Show file tree
Hide file tree
Showing 18 changed files with 930 additions and 35 deletions.
84 changes: 82 additions & 2 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TACO_INDEX_NOTATION_H
#define TACO_INDEX_NOTATION_H

#include <functional>
#include <ostream>
#include <string>
#include <memory>
Expand Down Expand Up @@ -30,13 +31,15 @@ class Format;
class Schedule;

class IndexVar;
class WindowedIndexVar;
class TensorVar;

class IndexExpr;
class Assignment;
class Access;

struct AccessNode;
struct AccessWindow;
struct LiteralNode;
struct NegNode;
struct SqrtNode;
Expand Down Expand Up @@ -220,14 +223,25 @@ class Access : public IndexExpr {
Access() = default;
Access(const Access&) = default;
Access(const AccessNode*);
Access(const TensorVar& tensorVar, const std::vector<IndexVar>& indices={});
Access(const TensorVar &tensorVar, const std::vector<IndexVar> &indices = {},
const std::map<int, AccessWindow> &windows = {});

/// Return the Access expression's TensorVar.
const TensorVar &getTensorVar() const;

/// Returns the index variables used to index into the Access's TensorVar.
const std::vector<IndexVar>& getIndexVars() const;

/// hasWindowedModes returns true if any accessed modes are windowed.
bool hasWindowedModes() const;

/// Returns whether or not the input mode (0-indexed) is windowed.
bool isModeWindowed(int mode) const;

/// Return the {lower,upper} bound of the window on the input mode (0-indexed).
int getWindowLowerBound(int mode) const;
int getWindowUpperBound(int mode) const;

/// Assign the result of an expression to a left-hand-side tensor access.
/// ```
/// a(i) = b(i) * c(i);
Expand Down Expand Up @@ -800,11 +814,67 @@ class Multi : public IndexStmt {
/// Create a multi index statement.
Multi multi(IndexStmt stmt1, IndexStmt stmt2);

/// IndexVarInterface is a marker superclass for IndexVar-like objects.
/// It is intended to be used in situations where many IndexVar-like objects
/// must be stored together, like when building an Access AST node where some
/// of the access variables are windowed. Use cases for IndexVarInterface
/// will inspect the underlying type of the IndexVarInterface. For sake of
/// completeness, the current implementers of IndexVarInterface are:
/// * IndexVar
/// * WindowedIndexVar
/// If this set changes, make sure to update the match function.
class IndexVarInterface {
public:
virtual ~IndexVarInterface() = default;

/// match performs a dynamic case analysis of the implementers of IndexVarInterface
/// as a utility for handling the different values within. It mimics the dynamic
/// type assertion of Go.
static void match(
std::shared_ptr<IndexVarInterface> ptr,
std::function<void(std::shared_ptr<IndexVar>)> ivarFunc,
std::function<void(std::shared_ptr<WindowedIndexVar>)> wvarFunc
) {
auto iptr = std::dynamic_pointer_cast<IndexVar>(ptr);
auto wptr = std::dynamic_pointer_cast<WindowedIndexVar>(ptr);
if (iptr != nullptr) {
ivarFunc(iptr);
} else if (wptr != nullptr) {
wvarFunc(wptr);
} else {
taco_iassert("IndexVarInterface was not IndexVar or WindowedIndexVar");
}
}
};

/// WindowedIndexVar represents an IndexVar that has been windowed. For example,
/// A(i) = B(i(2, 4))
/// In this case, i(2, 4) is a WindowedIndexVar. WindowedIndexVar is defined
/// before IndexVar so that IndexVar can return objects of type WindowedIndexVar.
class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public IndexVarInterface {
public:
WindowedIndexVar(IndexVar base, int lo = -1, int hi = -1);
~WindowedIndexVar() = default;

/// getIndexVar returns the underlying IndexVar.
IndexVar getIndexVar() const;

/// get{Lower,Upper}Bound returns the {lower,upper} bound of the window of
/// this index variable.
int getLowerBound() const;
int getUpperBound() const;

private:
struct Content;
std::shared_ptr<Content> content;
};

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar> {
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);

/// Returns the name of the index variable.
Expand All @@ -813,6 +883,8 @@ class IndexVar : public util::Comparable<IndexVar> {
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi);

private:
struct Content;
Expand All @@ -823,7 +895,15 @@ struct IndexVar::Content {
std::string name;
};

struct WindowedIndexVar::Content {
IndexVar base;
int lo;
int hi;
};

std::ostream& operator<<(std::ostream&, const std::shared_ptr<IndexVarInterface>&);
std::ostream& operator<<(std::ostream&, const IndexVar&);
std::ostream& operator<<(std::ostream&, const WindowedIndexVar&);

/// A suchthat statement provides a set of IndexVarRel that constrain
/// the iteration space for the child concrete index notation
Expand Down
21 changes: 19 additions & 2 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,21 @@

namespace taco {

// An AccessNode carries the windowing information for an IndexVar + TensorVar
// combination. An AccessWindow contains the lower and upper bounds of each
// windowed mode (0-indexed). AccessWindow is extracted from AccessNode so that
// it can be referenced externally.
struct AccessWindow {
int lo;
int hi;
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
return a.lo == b.lo && a.hi == b.hi;
}
};

struct AccessNode : public IndexExprNode {
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices)
: IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices) {}
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices, const std::map<int, AccessWindow>& windows={})
: IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices), windowedModes(windows) {}

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
Expand All @@ -26,6 +37,12 @@ struct AccessNode : public IndexExprNode {

TensorVar tensorVar;
std::vector<IndexVar> indexVars;
std::map<int, AccessWindow> windowedModes;

protected:
/// Initialize an AccessNode with just a TensorVar. If this constructor is used,
/// then indexVars must be set afterwards.
explicit AccessNode(TensorVar tensorVar) : IndexExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar) {}
};

struct LiteralNode : public IndexExprNode {
Expand Down
3 changes: 2 additions & 1 deletion include/taco/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,10 @@ struct Allocate : public StmtNode<Allocate> {
Expr num_elements;
Expr old_elements; // used for realloc in CUDA
bool is_realloc;
bool clear; // Whether to use calloc to allocate this memory.

static Stmt make(Expr var, Expr num_elements, bool is_realloc=false,
Expr old_elements=Expr());
Expr old_elements=Expr(), bool clear=false);

static const IRNodeType _type_info = IRNodeType::Allocate;
};
Expand Down
15 changes: 15 additions & 0 deletions include/taco/lower/iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,17 @@ class Iterator : public util::Comparable<Iterator> {
/// Returns true if the iterator is defined, false otherwise.
bool defined() const;

/// Methods for querying and operating on windowed tensor modes.

/// isWindowed returns true if this iterator is operating over a window
/// of a tensor mode.
bool isWindowed() const;

/// getWindow{Lower,Upper}Bound return the {Lower,Upper} bound of the
/// window that this iterator operates over.
ir::Expr getWindowLowerBound() const;
ir::Expr getWindowUpperBound() const;

friend bool operator==(const Iterator&, const Iterator&);
friend bool operator<(const Iterator&, const Iterator&);
friend std::ostream& operator<<(std::ostream&, const Iterator&);
Expand All @@ -170,6 +181,10 @@ class Iterator : public util::Comparable<Iterator> {

Iterator(std::shared_ptr<Content> content);
void setChild(const Iterator& iterator) const;

friend class Iterators;
/// setWindowBounds sets the window bounds of this iterator.
void setWindowBounds(ir::Expr lo, ir::Expr hi);
};

/**
Expand Down
23 changes: 22 additions & 1 deletion include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,30 @@ class LowererImpl : public util::Uncopyable {
/// Create an expression to index into a tensor value array.
ir::Expr generateValueLocExpr(Access access) const;

/// Expression that evaluates to true if none of the iteratators are exhausted
/// Expression that evaluates to true if none of the iterators are exhausted
ir::Expr checkThatNoneAreExhausted(std::vector<Iterator> iterators);

/// Expression that returns the beginning of a window to iterate over
/// in a compressed iterator. It is used when operating over windows of
/// tensors, instead of the full tensor.
ir::Expr searchForStartOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end);

/// Statement that guards against going out of bounds of the window that
/// the input iterator was configured with.
ir::Stmt upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access);

/// Expression that recovers a canonical index variable from a position in
/// a windowed position iterator. A windowed position iterator iterates over
/// values in the range [lo, hi). This expression projects values in that
/// range back into the canonical range of [0, n).
ir::Expr projectWindowedPositionToCanonicalSpace(Iterator iterator, ir::Expr expr);

// projectCanonicalSpaceToWindowedPosition is the opposite of
// projectWindowedPositionToCanonicalSpace. It takes an expression ranging
// through the canonical space of [0, n) and projects it up to the windowed
// range of [lo, hi).
ir::Expr projectCanonicalSpaceToWindowedPosition(Iterator iterator, ir::Expr expr);

private:
bool assemble;
bool compute;
Expand Down
79 changes: 79 additions & 0 deletions include/taco/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ class TensorBase {
/// Create an index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<IndexVar>& indices);

/// Create a possibly windowed index expression that accesses (reads or writes) this tensor.
Access operator()(const std::vector<std::shared_ptr<IndexVarInterface>>& indices);

/// Create an index expression that accesses (reads) this (scalar) tensor.
Access operator()();

Expand Down Expand Up @@ -621,6 +624,20 @@ class Tensor : public TensorBase {
template <typename... IndexVars>
Access operator()(const IndexVars&... indices);

/// The below two Access methods are used to allow users to access tensors
/// with a mix of IndexVar's and WindowedIndexVar's. This allows natural
/// expressions like
/// A(i, j(1, 3)) = B(i(2, 4), j) * C(i(5, 7), j(7, 9))
/// to be constructed without adjusting the original API.

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const WindowedIndexVar& first, const IndexVars&... indices);

/// Create an index expression that accesses (reads, writes) this tensor.
template <typename... IndexVars>
Access operator()(const IndexVar& first, const IndexVars&... indices);

ScalarAccess<CType> operator()(const std::vector<int>& indices);

/// Create an index expression that accesses (reads) this tensor.
Expand All @@ -629,6 +646,15 @@ class Tensor : public TensorBase {

/// Assign an expression to a scalar tensor.
void operator=(const IndexExpr& expr);

private:
/// The _access method family is the template level implementation of
/// Access() expressions containing mixes of IndexVar and WindowedIndexVar objects.
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> _access(const First& first, const Rest&... rest);
std::vector<std::shared_ptr<IndexVarInterface>> _access();
template <typename... Args>
Access _access_wrapper(const Args&... args);
};

template <typename CType>
Expand Down Expand Up @@ -1084,6 +1110,59 @@ Access Tensor<CType>::operator()(const IndexVars&... indices) {
return TensorBase::operator()(std::vector<IndexVar>{indices...});
}

/// The _access() methods perform primitive recursion on the input variadic template.
/// This means that each instance of the _access method matches on the first element
/// of the variadic template parameter pack, performs an "action", then recurses
/// with the remaining elements in the parameter pack through a recursive call
/// to _access. Since this is recursion, we need a base case. The empty argument
/// instance of _access returns an empty value of the desired type, in this case
/// a vector of IndexVarInterface.
template <typename CType>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access() {
return std::vector<std::shared_ptr<IndexVarInterface>>{};
}

/// The recursive case of _access matches on the first element, and attempts to
/// create a shared_ptr out of it. It then makes a recursive call to get a
/// vector with the rest of the elements. Then, it pushes the first element onto
/// the back of the vector -- this check ensures that the type First is indeed
/// a member of IndexVarInterface.
template <typename CType>
template <typename First, typename... Rest>
std::vector<std::shared_ptr<IndexVarInterface>> Tensor<CType>::_access(const First& first, const Rest&... rest) {
auto var = std::make_shared<First>(first);
auto ret = _access(rest...);
ret.push_back(var);
return ret;
}

/// _access_wrapper just calls into _access and reverses the result to get the initial
/// order of the arguments.
template <typename CType>
template <typename... Args>
Access Tensor<CType>::_access_wrapper(const Args&... args) {
auto resultReversed = this->_access(args...);
std::vector<std::shared_ptr<IndexVarInterface>> result;
result.reserve(resultReversed.size());
for (auto& it : util::reverse(resultReversed)) {
result.push_back(it);
}
return TensorBase::operator()(result);
}

/// We have to case on whether the first argument is an IndexVar or a WindowedIndexVar
/// so that the template engine can differentiate between the two versions.
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const IndexVar& first, const IndexVars&... indices) {
return this->_access_wrapper(first, indices...);
}
template <typename CType>
template <typename... IndexVars>
Access Tensor<CType>::operator()(const WindowedIndexVar& first, const IndexVars&... indices) {
return this->_access_wrapper(first, indices...);
}

template <typename CType>
ScalarAccess<CType> Tensor<CType>::operator()(const std::vector<int>& indices) {
taco_uassert(indices.size() == (size_t)getOrder())
Expand Down
8 changes: 7 additions & 1 deletion src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,13 @@ void CodeGen_C::visit(const Allocate* op) {
stream << ", ";
}
else {
stream << "malloc(";
// If the allocation was requested to clear the allocated memory,
// use calloc instead of malloc.
if (op->clear) {
stream << "calloc(1, ";
} else {
stream << "malloc(";
}
}
stream << "sizeof(" << elementType << ")";
stream << " * ";
Expand Down
9 changes: 7 additions & 2 deletions src/codegen/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,9 +1293,14 @@ void CodeGen_CUDA::visit(const Call* op) {
stream << op->func << "(";
parentPrecedence = Precedence::CALL;

// Need to print cast to type so that arguments match
// Need to print cast to type so that arguments match.
if (op->args.size() > 0) {
if (op->type != op->args[0].type() || isa<Literal>(op->args[0])) {
// However, the binary search arguments take int* as their first
// argument. This pointer information isn't carried anywhere in
// the argument expressions, so we need to special case and not
// emit an invalid cast for that argument.
auto opIsBinarySearch = op->func == "taco_binarySearchAfter" || op->func == "taco_binarySearchBefore";
if (!opIsBinarySearch && (op->type != op->args[0].type() || isa<Literal>(op->args[0]))) {
stream << "(" << printCUDAType(op->type, false) << ") ";
}
op->args[0].accept(this);
Expand Down
Loading

0 comments on commit baf174c

Please sign in to comment.