Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Change-Id: Id69ae51bad05b85e95e35132ea43434a17c7a89e
  • Loading branch information
Giuseppe Rossini committed May 25, 2021
1 parent 836648e commit cfa991a
Showing 1 changed file with 69 additions and 64 deletions.
133 changes: 69 additions & 64 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@ namespace relay {
namespace backend {

/**
* Struct to contain information about intermediate variables in the
* Struct to contain information about the intermediate tensors in the
* runner function
*/
struct StorageInfo {
/*! \brief unique integer identifier of the particular intermediate variable */
std::vector<int> ids;
/*! \brief storage integer identifier of the particular intermediate buffer */
int sid;
/*! \brief exact size of the temporary, in bytes */
std::vector<int> sizes_bytes;
/*! \brief device type of the temporary variable */
std::vector<int> dev_types;
int size_bytes;
/*! \brief device type of the intermediate tensor */
int dev_type;
};

/*! \brief array of Integer values */
using IntegerArray = Array<Integer>;
/*! \brief map from an integer key to a Target (presumably keyed by device type - TODO confirm) */
using TargetsMap = std::unordered_map<int, Target>;
/*! \brief map from a Relay expression to the storage recorded for it */
using StorageMap =
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
using StorageMap = std::unordered_map<Expr, std::vector<StorageInfo>, runtime::ObjectPtrHash,
runtime::ObjectPtrEqual>;

/**
* This is an on demand allocator for AOT. A new temporary
Expand All @@ -74,26 +74,26 @@ class AOTOnDemandAllocator : public ExprVisitor {
node_device_map_ = CollectDeviceInfo(func);

for (Expr param : func->params) {
CreateSid(param.operator->());
CreateStorage(param.operator->());
}

GetSid(func->body);
GetStorage(func->body);
}

/*! \brief Return a copy of the storage ids last recorded by AssignReturnSid */
std::vector<int> GetReturnIds() const { return return_ids_; }

/*! \brief Return a copy of the expression -> storage map built by the visitor */
StorageMap GetStorageMap() const { return storage_device_map_; }

/*! \brief Create storage for a constant and record its ids as the current return ids */
void VisitExpr_(const ConstantNode* op) final {
CreateSid(op);
CreateStorage(op);
AssignReturnSid(GetRef<Expr>(op));
}

/*! \brief Create storage for a call's result, visit its arguments, then record the call's ids */
void VisitExpr_(const CallNode* op) final {
// create the storage entry for the result of the call node.
CreateSid(op);
CreateStorage(op);
// Visiting each argument makes sure it has an entry in storage_device_map_.
for (Expr arg : op->args) {
GetSid(arg);
GetStorage(arg);
}
// Assigned after the arguments so the call's own ids overwrite whatever they recorded.
AssignReturnSid(GetRef<Expr>(op));
}
Expand All @@ -116,28 +116,22 @@ class AOTOnDemandAllocator : public ExprVisitor {
}

/*!
* \brief Record the storage of a tuple as the concatenation of its fields' storage.
*
* No new storage is created here: the tuple's entry is built by appending the
* storage of every field, in field order.
*/
void VisitExpr_(const TupleNode* op) final {
StorageInfo field_sid;
std::vector<StorageInfo> field_sids;
Expr expr = GetRef<Expr>(op);
for (Expr field : op->fields) {
// Visits the field (if needed) and appends all of its storage entries.
auto sid = GetSid(field);
field_sid.ids.insert(field_sid.ids.end(), sid.ids.begin(), sid.ids.end());
field_sid.dev_types.insert(field_sid.dev_types.end(), sid.dev_types.begin(),
sid.dev_types.end());
field_sid.sizes_bytes.insert(field_sid.sizes_bytes.end(), sid.sizes_bytes.begin(),
sid.sizes_bytes.end());
auto sid = GetStorage(field);
field_sids.insert(field_sids.end(), sid.begin(), sid.end());
}

storage_device_map_[expr] = field_sid;
storage_device_map_[expr] = field_sids;
AssignReturnSid(expr);
}

/*!
* \brief Record the storage of a tuple projection as the single entry selected by op->index.
*
* NOTE(review): op->index is applied to the flattened storage list of the tuple;
* if a field could itself own several entries the positions would not line up.
* Confirm nested tuples are not expected here.
*/
void VisitExpr_(const TupleGetItemNode* op) final {
Expr expr = GetRef<Expr>(op);
const auto& sid = GetSid(op->tuple);
// The index must address an existing entry of the tuple's storage.
ICHECK_LT(static_cast<size_t>(op->index), sid.ids.size());
storage_device_map_[expr].ids = {sid.ids[op->index]};
storage_device_map_[expr].sizes_bytes = {sid.sizes_bytes[op->index]};
storage_device_map_[expr].dev_types = {sid.dev_types[op->index]};
const auto& sids = GetStorage(op->tuple);
ICHECK_LT(static_cast<size_t>(op->index), sids.size());
storage_device_map_[expr] = {sids[op->index]};
AssignReturnSid(expr);
}

Expand All @@ -147,9 +141,13 @@ class AOTOnDemandAllocator : public ExprVisitor {

private:
/*!
* \brief Overwrite return_ids_ with the storage ids recorded for an expression.
*
* Does nothing when the expression has no entry in storage_device_map_. Every
* successful call replaces the previous value of return_ids_.
* \param e The expression whose storage ids become the return ids.
*/
void AssignReturnSid(Expr e) {
auto iter = storage_device_map_.find(e);
if (iter != storage_device_map_.end()) {
return_ids_ = (*iter).second.ids;
if (storage_device_map_.find(e) != storage_device_map_.end()) {
auto buffers = storage_device_map_[e];
// Collect the storage id of every buffer associated with the expression.
std::vector<int> return_ids;
for (auto buffer : buffers) {
return_ids.push_back(buffer.sid);
}
return_ids_ = return_ids;
}
}
/*!
Expand Down Expand Up @@ -178,45 +176,53 @@ class AOTOnDemandAllocator : public ExprVisitor {
return size;
}
/*!
* \brief Get the necessary token.
* \brief Get the necessary storage for the expression.
* \param expr The expression.
* \return The storage recorded for the expression in storage_device_map_ (returned by value).
*/
StorageInfo GetSid(const Expr& expr) {
std::vector<StorageInfo> GetStorage(const Expr& expr) {
// Visit first so the expression's storage exists before the lookup below.
this->VisitExpr(expr);
auto it = storage_device_map_.find(expr);
// A visited expression is expected to have an entry; fail the check otherwise.
ICHECK(it != storage_device_map_.end());
return it->second;
}

void CreateSid(const ExprNode* op) {
StorageInfo sid;
/*!
* \brief Create storage for the expression.
*
* One storage entry is created per tensor: a tuple-typed expression gets one
* entry per field, any other expression gets a single entry. Each entry takes
* a fresh id from sid_, and the result is stored in storage_device_map_.
* \param op The expression node to create storage for.
*/
void CreateStorage(const ExprNode* op) {
std::vector<StorageInfo> buffers;
Expr expr = GetRef<Expr>(op);
// Device type falls back to 0 when the expression has no entry in node_device_map_.
int device_type = node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
// Every tuple field must be a tensor type; anything else fails the check.
const auto* ttype = t.as<TensorTypeNode>();
ICHECK(ttype);
sid.ids.push_back(sid_++);
sid.dev_types.push_back(device_type);
sid.sizes_bytes.push_back(GetMemorySize(ttype));
StorageInfo buffer;
buffer.sid = sid_++;
buffer.size_bytes = GetMemorySize(ttype);
buffer.dev_type = device_type;
buffers.push_back(buffer);
}
} else {
// Non-tuple expressions must have a tensor type and get exactly one entry.
const auto* ttype = op->checked_type().as<TensorTypeNode>();
ICHECK(ttype);
sid.ids.push_back(sid_++);
sid.dev_types.push_back(device_type);
sid.sizes_bytes.push_back(GetMemorySize(ttype));
StorageInfo buffer;
buffer.sid = sid_++;
buffer.size_bytes = GetMemorySize(ttype);
buffer.dev_type = device_type;
buffers.push_back(buffer);
}
storage_device_map_[expr] = sid;
storage_device_map_[expr] = buffers;
}
/*! \brief mapping of expression -> storageInfo*/
StorageMap storage_device_map_;
/*! \brief mapping of expression -> device type*/
Map<Expr, Integer> node_device_map_;
/*! \brief current id of the temporary allocated*/
int sid_{0};
/*! \brief the set of identifiers that are return variables */
/*! \brief the set of intermediate tensors that are return variables */
std::vector<int> return_ids_;
};

Expand Down Expand Up @@ -248,18 +254,18 @@ class AOTExecutorCodegen : public ExprVisitor {
* \brief Return a vector of variables that represents the sids for the given Relay Expr
*/
std::vector<tir::Var> PackSid(Expr expr) {
auto sids = storage_device_map_[expr];
std::vector<tir::Var> sid_vars;
auto buffers = storage_device_map_[expr];
std::vector<tir::Var> buffer_vars;

// Note that an expression can have multiple sids associated with it
// e.g., returning multiple values from a function
for (const auto& sid : sids.ids) {
for (const auto& buffer : buffers) {
// Determine if an sid is an output buffer
int sid_int = sid;
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
int sid = buffer.sid;
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
sid_vars.push_back(main_signature_[input_vars_.size() + output_index]);
buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]);
continue;
}
// Pack the sid inside the TVMValue
Expand All @@ -269,9 +275,9 @@ class AOTExecutorCodegen : public ExprVisitor {
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, sid_value});
stmts_.push_back(tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
sid_vars.push_back(sid_array);
buffer_vars.push_back(sid_array);
}
return sid_vars;
return buffer_vars;
}

/*!
Expand Down Expand Up @@ -518,8 +524,7 @@ class AOTExecutorCodegen : public ExprVisitor {
}

ICHECK_GE(storage_device_map_.count(expr), 0);
auto& device_type = storage_device_map_[expr].dev_types;
auto call_dev_type = device_type[0];
auto call_dev_type = storage_device_map_[expr][0].dev_type;
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
Expand Down Expand Up @@ -556,14 +561,14 @@ class AOTExecutorCodegen : public ExprVisitor {

// If the Var node is an output node, we need to copy the content of the variable to the output.
// It's safe to check the SID here because Var StorageTokens are never reallocated.
auto sids = storage_device_map_[expr];
auto buffers = storage_device_map_[expr];

auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
auto var_expr = FindExpr(expr);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
sids.sizes_bytes[0]);
buffers[0].size_bytes);
}
}

Expand All @@ -572,18 +577,18 @@ class AOTExecutorCodegen : public ExprVisitor {
size_t index = params_.size();
std::string name = "p" + std::to_string(index);

param_storage_ids_[name] = storage_device_map_[expr].ids[0];
param_storage_ids_[name] = storage_device_map_[expr][0].sid;
params_[name] = op->data;
params_by_expr_.Set(expr, name);

// If the Constant node is an output node we need to copy the content of the parameter to the
// output. A Constant node can only produce a single output.
auto sids = storage_device_map_[expr];
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
auto buffers = storage_device_map_[expr];
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr),
sids.sizes_bytes[0]);
buffers[0].size_bytes);
}
}

Expand Down Expand Up @@ -639,9 +644,9 @@ class AOTExecutorCodegen : public ExprVisitor {
continue;
}

for (unsigned int i = 0; i < kv.second.ids.size(); i++) {
int size = kv.second.sizes_bytes[i];
int sid = kv.second.ids[i];
for (unsigned int i = 0; i < kv.second.size(); i++) {
int size = kv.second[i].size_bytes;
int sid = kv.second[i].sid;

if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
continue;
Expand Down Expand Up @@ -733,9 +738,9 @@ class AOTExecutorCodegen : public ExprVisitor {

// Define the storage allocator ids
for (auto kv : storage_device_map_) {
for (const auto& sid : kv.second.ids) {
te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
sids_table_[sid] = sid_var;
for (const auto& buffer : kv.second) {
te::Var buffer_var(MakeString("sid_", buffer.sid), PointerType(PrimType(DataType::Int(8))));
sids_table_[buffer.sid] = buffer_var;
}
}

Expand Down

0 comments on commit cfa991a

Please sign in to comment.