Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of store-based OpenScop transform #4

Merged
merged 2 commits into from
Oct 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions include/polymer/Support/OslScopStmtOpSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===- OslScopStmtOpSet.h ---------------------------------------*- C++ -*-===//
//
// This file declares the class OslScopStmtOpSet.
//
//===----------------------------------------------------------------------===//
#ifndef POLYMER_SUPPORT_OSLSCOPSTMTOPSET_H
#define POLYMER_SUPPORT_OSLSCOPSTMTOPSET_H

#include "llvm/ADT/SetVector.h"

using namespace llvm;

namespace mlir {
class Operation;
class LogicalResult;
class FlatAffineConstraints;
} // namespace mlir

namespace polymer {

/// This class contains a set of operations that will correspond to a single
/// OpenScop statement body. The underlying data structure is SetVector.
class OslScopStmtOpSet {
public:
using Set = SetVector<mlir::Operation *>;
using iterator = Set::iterator;
using reverse_iterator = Set::reverse_iterator;

OslScopStmtOpSet() {}

/// The core store op. There should be only one of it.
mlir::Operation *getStoreOp() { return storeOp; }

/// Insert.
void insert(mlir::Operation *op);

/// Count.
unsigned count(mlir::Operation *op) { return opSet.count(op); };

/// Size.
unsigned size() { return opSet.size(); }

/// Iterators.
iterator begin() { return opSet.begin(); }
iterator end() { return opSet.end(); }
reverse_iterator rbegin() { return opSet.rbegin(); }
reverse_iterator rend() { return opSet.rend(); }

mlir::Operation *get(unsigned i) { return opSet[i]; }

/// The domain of a stmtOpSet is the union of all load/store operations in
/// that set. We calculate such a union by concatenating the constraints of
/// domain defined by FlatAffineConstraints.
/// TODO: improve the interface.
mlir::LogicalResult getDomain(mlir::FlatAffineConstraints &domain);
mlir::LogicalResult
getDomain(mlir::FlatAffineConstraints &domain,
SmallVectorImpl<mlir::Operation *> &enclosingOps);

/// Get the enclosing operations for the opSet.
mlir::LogicalResult
getEnclosingOps(SmallVectorImpl<mlir::Operation *> &enclosingOps);

private:
Set opSet;

mlir::Operation *storeOp = nullptr;
};

} // namespace polymer

#endif
17 changes: 11 additions & 6 deletions include/polymer/Support/OslSymbolTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,33 @@ class Value;

namespace polymer {

class OslScopStmtOpSet;

class OslSymbolTable {
public:
enum SymbolType { LoopIV, Memref, StmtOp };
using OpSet = OslScopStmtOpSet;
using OpSetPtr = std::unique_ptr<OpSet>;

enum SymbolType { LoopIV, Memref, StmtOpSet };

Value getValue(StringRef key);

Operation *getOperation(StringRef key);
OpSet getOpSet(StringRef key);

void setValue(StringRef key, Value val, SymbolType type);

void setOperation(StringRef key, Operation *val, SymbolType type);
void setOpSet(StringRef key, OpSet val, SymbolType type);

unsigned getNumValues(SymbolType type);

unsigned getNumOperations(SymbolType type);
unsigned getNumOpSets(SymbolType type);

void getValueSymbols(SmallVectorImpl<StringRef> &symbols);

void getOperationSymbols(SmallVectorImpl<StringRef> &symbols);
void getOpSetSymbols(SmallVectorImpl<StringRef> &symbols);

private:
StringMap<Operation *> nameToStmtOp;
StringMap<OpSet> nameToStmtOpSet;
StringMap<Value> nameToLoopIV;
StringMap<Value> nameToMemref;
};
Expand Down
4 changes: 4 additions & 0 deletions lib/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
add_mlir_library(PolymerSupport
OslScop.cc
OslScopStmtOpSet.cc
OslSymbolTable.cc

ADDITIONAL_HEADER_DIRS
${POLYMER_MAIN_INCLUDE_DIR}/polymer/Support

LINK_LIBS PUBLIC
MLIRAnalysis
MLIRLoopAnalysis

libosl
libcloog
libisl
Expand Down
64 changes: 64 additions & 0 deletions lib/Support/OslScopStmtOpSet.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- OslScopStmtOpSet.cc --------------------------------------*- C++ -*-===//
//
// This file implements the class OslScopStmtOpSet.
//
//===----------------------------------------------------------------------===//

#include "polymer/Support/OslScopStmtOpSet.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

using namespace llvm;
using namespace mlir;
using namespace polymer;

void OslScopStmtOpSet::insert(mlir::Operation *op) {
opSet.insert(op);
if (isa<mlir::AffineStoreOp>(op)) {
assert(!storeOp && "There should be only one AffineStoreOp in the set.");
storeOp = op;
}
}

LogicalResult OslScopStmtOpSet::getEnclosingOps(
SmallVectorImpl<mlir::Operation *> &enclosingOps) {
SmallVector<Operation *, 8> ops;
SmallPtrSet<Operation *, 8> visited;
for (auto op : opSet) {
if (isa<mlir::AffineLoadOp, mlir::AffineStoreOp>(op)) {
ops.clear();
getEnclosingAffineForAndIfOps(*op, &ops);
for (auto enclosingOp : ops) {
if (visited.find(enclosingOp) == visited.end()) {
visited.insert(enclosingOp);
enclosingOps.push_back(enclosingOp);
}
}
}
}

return success();
}

LogicalResult
OslScopStmtOpSet::getDomain(FlatAffineConstraints &domain,
SmallVectorImpl<mlir::Operation *> &enclosingOps) {
return getIndexSet(enclosingOps, &domain);
}

LogicalResult OslScopStmtOpSet::getDomain(FlatAffineConstraints &domain) {
SmallVector<Operation *, 8> enclosingOps;
if (failed(getEnclosingOps(enclosingOps)))
return failure();

return getDomain(domain, enclosingOps);
}
33 changes: 16 additions & 17 deletions lib/Support/OslSymbolTable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//===----------------------------------------------------------------------===//

#include "polymer/Support/OslSymbolTable.h"
#include "polymer/Support/OslScopStmtOpSet.h"

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
Expand All @@ -26,12 +27,11 @@ Value OslSymbolTable::getValue(StringRef key) {
return nullptr;
}

Operation *OslSymbolTable::getOperation(StringRef key) {
OslSymbolTable::OpSet OslSymbolTable::getOpSet(StringRef key) {
// If key corresponds to an Op of a statement.
if (nameToStmtOp.find(key) != nameToStmtOp.end())
return nameToStmtOp.lookup(key);

return nullptr;
assert(nameToStmtOpSet.find(key) != nameToStmtOpSet.end() &&
"Key is not found.");
return nameToStmtOpSet.lookup(key);
}

void OslSymbolTable::setValue(StringRef key, Value val, SymbolType type) {
Expand All @@ -47,14 +47,13 @@ void OslSymbolTable::setValue(StringRef key, Value val, SymbolType type) {
}
}

void OslSymbolTable::setOperation(StringRef key, Operation *val,
SymbolType type) {
void OslSymbolTable::setOpSet(StringRef key, OpSet val, SymbolType type) {
switch (type) {
case StmtOp:
nameToStmtOp[key] = val;
case StmtOpSet:
nameToStmtOpSet[key] = val;
break;
default:
assert(false && "Symbole type for Operation not recognized.");
assert(false && "Symbole type for OpSet not recognized.");
}
}

Expand All @@ -69,12 +68,12 @@ unsigned OslSymbolTable::getNumValues(SymbolType type) {
}
}

unsigned OslSymbolTable::getNumOperations(SymbolType type) {
unsigned OslSymbolTable::getNumOpSets(SymbolType type) {
switch (type) {
case StmtOp:
return nameToStmtOp.size();
case StmtOpSet:
return nameToStmtOpSet.size();
default:
assert(false && "Symbole type for Operation not recognized.");
assert(false && "Symbole type for OpSet not recognized.");
}
}

Expand All @@ -86,10 +85,10 @@ void OslSymbolTable::getValueSymbols(SmallVectorImpl<StringRef> &symbols) {
for (auto &it : nameToMemref)
symbols.push_back(it.first());
}
void OslSymbolTable::getOperationSymbols(SmallVectorImpl<StringRef> &symbols) {
symbols.reserve(getNumOperations(StmtOp));
void OslSymbolTable::getOpSetSymbols(SmallVectorImpl<StringRef> &symbols) {
symbols.reserve(getNumOpSets(StmtOpSet));

for (auto &it : nameToStmtOp)
for (auto &it : nameToStmtOpSet)
symbols.push_back(it.first());
}
} // namespace polymer
52 changes: 19 additions & 33 deletions lib/Target/OpenScop/ConvertFromOpenScop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "osl/osl.h"

#include "polymer/Support/OslScop.h"
#include "polymer/Support/OslScopStmtOpSet.h"
#include "polymer/Support/OslSymbolTable.h"
#include "polymer/Target/OpenScop.h"

Expand Down Expand Up @@ -369,22 +370,19 @@ LogicalResult Importer::processStmt(clast_user_stmt *userStmt) {
// Create the callee.
// First, we create the callee function type.
unsigned numArgs = args.size();
llvm::SmallVector<mlir::Type, 8> calleeArgTypes(numArgs);
llvm::SmallVector<mlir::Type, 8> calleeArgTypes;

for (unsigned i = 0; i < numArgs; i++) {
if (isMemrefArg(args[i])) {
// Memref
// Memref. A memref name and its number of dimensions.
auto memName = args[i];
auto memShape = std::vector<int64_t>(numArgs - i - 1, -1);
auto memShape = std::vector<int64_t>(std::stoi(args[i + 1]), -1);
MemRefType memType = MemRefType::get(memShape, b.getF32Type());
calleeArgTypes[i] = memType;
} else if (isResultArg(args[i])) {
// Result from other statements.
// TODO: we just assume all data types are scalar float32.
calleeArgTypes[i] = b.getF32Type();
calleeArgTypes.push_back(memType);
i++;
} else {
// Loop IV.
calleeArgTypes[i] = b.getIndexType();
calleeArgTypes.push_back(b.getIndexType());
}
}

Expand All @@ -401,13 +399,13 @@ LogicalResult Importer::processStmt(clast_user_stmt *userStmt) {

// Initialise all the caller arguments. The first argument should be the
// memory object, which is set to be a BlockArgument.
llvm::SmallVector<mlir::Value, 8> callerArgs(numArgs);
llvm::SmallVector<mlir::Value, 8> callerArgs;
auto &entryBlock = *func.getBlocks().begin();

for (unsigned i = 0; i < numArgs; i++) {
if (isMemrefArg(args[i])) {
// TODO: refactorize this.
auto memShape = std::vector<int64_t>(numArgs - i - 1, -1);
auto memShape = std::vector<int64_t>(std::stoi(args[i + 1]), -1);
MemRefType memType = MemRefType::get(memShape, b.getF32Type());

// TODO: refactorize these two lines into a single API.
Expand All @@ -416,31 +414,16 @@ LogicalResult Importer::processStmt(clast_user_stmt *userStmt) {
memref = entryBlock.addArgument(memType);
symTable->setValue(args[i], memref, OslSymbolTable::Memref);
}
callerArgs[i] = memref;
} else if (isResultArg(args[i])) {
// TODO: remove this branch since it won't be triggered in the latest
// design.
auto srcOp = symTable->getOperation(args[i]);
if (!srcOp)
return failure();

auto caller = dyn_cast<mlir::CallOp>(srcOp);
auto srcCallee = dyn_cast<mlir::FuncOp>(calleeMap[caller.getCallee()]);
if (srcCallee.getNumResults() == 0) {
// TODO: still, we assume that the returned value is of type F32.
auto newCalleeType =
b.getFunctionType(srcCallee.getArgumentTypes(), b.getF32Type());
srcCallee.setType(newCalleeType);
}
callerArgs[i] = srcOp->getResult(0);
callerArgs.push_back(memref);
i++;
} else if (auto val = symTable->getValue(args[i])) {
// The rest of the arguments are access indices. They could be the loop
// IVs or the parameters. Loop IV
callerArgs[i] = val;
callerArgs.push_back(val);
// Symbol.
// TODO: manage sym name by the symTable.
} else if (symNameToArg.find(args[i]) != symNameToArg.end()) {
callerArgs[i] = symNameToArg.lookup(args[i]);
callerArgs.push_back(symNameToArg.lookup(args[i]));
// TODO: what if an index is a constant?
} else { // TODO: error handling
llvm::errs() << "Cannot find " << args[i]
Expand All @@ -455,7 +438,10 @@ LogicalResult Importer::processStmt(clast_user_stmt *userStmt) {
auto callOp = b.create<CallOp>(UnknownLoc::get(context), callee, callerArgs);

// Update StmtOpMap.
symTable->setOperation(calleeName, callOp, OslSymbolTable::StmtOp);
OslScopStmtOpSet opSet;
opSet.insert(callOp);
opSet.insert(callee);
symTable->setOpSet(calleeName, opSet, OslSymbolTable::StmtOpSet);

return success();
}
Expand Down Expand Up @@ -583,8 +569,8 @@ polymer::translateOpenScopToModule(std::unique_ptr<OslScop> scop,
FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));

OslSymbolTable symTable;
if (createFuncOpFromOpenScop(std::move(scop), module.get(), symTable,
context))
if (!createFuncOpFromOpenScop(std::move(scop), module.get(), symTable,
context))
return {};

return module;
Expand Down
Loading