Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

[WIP] make ngraph_bridge aware of storage type #146

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_mapRef,
const std::unordered_map<std::string, int>& arg_dtype_mapRef,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::unordered_map<std::string, int>& arg_stype_mapRef,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& shared_arg_names,
std::vector<NDArray>* in_arg_vec,
Expand All @@ -1040,13 +1040,14 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
// make copies so that ngraph compilation can modify shape / dtype
std::unordered_map<std::string, TShape> arg_shape_map = arg_shape_mapRef;
std::unordered_map<std::string, int> arg_dtype_map = arg_dtype_mapRef;
std::unordered_map<std::string, int> arg_stype_map = arg_stype_mapRef;

#if MXNET_USE_NGRAPH == 1
// TODO(mbrookhart): Remove this when hetr can handle multiple contexts
auto multi_context = multi_context_check(default_ctx, in_arg_ctxes,
arg_grad_ctxes, aux_state_ctxes);
ngraph_bridge::SimpleBindArg simplebind(num_forward_inputs_, arg_shape_map,
arg_dtype_map);
arg_dtype_map, arg_stype_map);
ngraph_bridge::Compiler compiler(
g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), simplebind,
default_ctx);
Expand All @@ -1057,6 +1058,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
// modify shape / dtype with ngraph version
arg_shape_map = compiler.GetNgraphShape();
arg_dtype_map = compiler.GetNgraphDtype();
arg_stype_map = compiler.GetNgraphStype();

// create "device" and "context" attrs for the graph
g = InitFullGraph(g, compiler.GetInputs(), grad_req_types);
Expand Down
37 changes: 23 additions & 14 deletions src/ngraph/ngraph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,20 @@ void Compiler::Infer(const BindArg* bind) {
if (mutable_nodes.count(nid)) {
shapes_.push_back(bind->aux_states_[aux_top].shape());
dtypes_.push_back(bind->aux_states_[aux_top].dtype());
stypes_.push_back(bind->aux_states_[aux_top].storage_type());
++aux_top;
} else {
shapes_.push_back(bind->in_args_[arg_top].shape());
dtypes_.push_back(bind->in_args_[arg_top].dtype());
stypes_.push_back(bind->in_args_[arg_top].storage_type());
++arg_top;
}
}

// append default shapes / dtypes so that vector size = graph size
// append default shapes / types so that vector size = graph size
shapes_.resize(idx.input_nodes().size(), nnvm::TShape());
dtypes_.resize(idx.input_nodes().size(), -1);
stypes_.resize(idx.num_node_entries(), mxnet::kDefaultStorage);
}

// infer nnvm::Graph shape and dtype for simple bind case
Expand All @@ -110,6 +113,7 @@ void Compiler::Infer(const SimpleBindArg* simplebind) {
const auto& idx = graph_.indexed_graph();
shapes_.resize(idx.input_nodes().size(), nnvm::TShape());
dtypes_.resize(idx.input_nodes().size(), -1);
stypes_.resize(idx.num_node_entries(), mxnet::kDefaultStorage);
size_t arg_top = 0, aux_top = 0;
for (size_t i = 0; i < simplebind->kNumForwardInputs; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
Expand All @@ -122,6 +126,10 @@ void Compiler::Infer(const SimpleBindArg* simplebind) {
if (simplebind->dtype_map_.end() != it2) {
dtypes_[i] = it2->second;
}
auto it3 = simplebind->stype_map_.find(name);
if (simplebind->stype_map_.end() != it3) {
stypes_[i] = it3->second;
}
}
}

Expand All @@ -136,6 +144,9 @@ Compiler::Compiler(const nnvm::Graph& graph, const NDArrayMap& feed_dict,
: ngraph_("ngraph_" + randomString(6), context) {
DeepCopy(graph);

graph_.attrs["context"] = std::make_shared<nnvm::any>(
mxnet::exec::ContextVector(graph_.indexed_graph().num_nodes(), context));

// infer nnvm::Graph shape and type
auto bind = dynamic_cast<const BindArg*>(&bindbase);
auto simplebind = dynamic_cast<const SimpleBindArg*>(&bindbase);
Expand All @@ -151,19 +162,11 @@ Compiler::Compiler(const nnvm::Graph& graph, const NDArrayMap& feed_dict,
void Compiler::ProcessGraph(const NDArrayMap& feed_dict) {
graph_ = mxnet::exec::InferShape(std::move(graph_), std::move(shapes_),
"__shape__");
// TODO(adstraw): may or may not need error checking
// if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
// HandleInferShapeError(num_forward_inputs, g.indexed_graph(),
// g.GetAttr<nnvm::ShapeVector>("shape"));
//}

graph_ = mxnet::exec::InferType(std::move(graph_), std::move(dtypes_),
"__dtype__");
// TODO(adstraw): may or may not need error checking
// if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
// HandleInferTypeError(num_forward_inputs, g.indexed_graph(),
// g.GetAttr<nnvm::DTypeVector>("dtype"));
//}

graph_.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes_));
graph_ = mxnet::exec::InferStorageType(std::move(graph_), std::move(mxnet::StorageTypeVector()), "");

MakeCopiedFeedDict(feed_dict);
ParseNnvmGraph();
Expand Down Expand Up @@ -201,6 +204,8 @@ nnvm::Graph Compiler::Compile() {
ngraph_shape_[node->name_] = node->shape_;
ngraph_dtype_[node->name_] = node->dtype_;
}
// TODO: all nodes, right?
ngraph_stype_[node->name_] = node->stype_;
}

// find the subgraphs
Expand Down Expand Up @@ -309,11 +314,13 @@ void Compiler::CheckInNgraph() {
if (node->type_ == NodeType::kOp) {
if (compiler_.ngraph_op_funcs_.count(node->operation_)) {
node->in_ngraph_ = true;
if (node->dtype_ == mshadow::kFloat16) {
if (node->dtype_ == mshadow::kFloat16 ||
node->stype_ != mxnet::kDefaultStorage) {
node->in_ngraph_ = false;
} else {
for (auto input : node->inputs_) {
if (input->dtype_ == mshadow::kFloat16) {
if (input->dtype_ == mshadow::kFloat16 ||
input->stype_ != mxnet::kDefaultStorage) {
node->in_ngraph_ = false;
}
}
Expand Down Expand Up @@ -392,11 +399,13 @@ void Compiler::ParseNnvmGraph() {
const auto inferred_shapes =
graph_.GetAttr<std::vector<nnvm::TShape>>("shape");
const auto inferred_dtypes = graph_.GetAttr<std::vector<int>>("dtype");
const auto inferred_stypes = graph_.GetAttr<std::vector<int>>("storage_type");
for (auto node : this->ngraph_.nodes_) {
const uint32_t nid = idx.node_id(node->orig_node_.get());
const uint32_t eid = idx.entry_id(nid, 0);
node->shape_ = inferred_shapes[eid];
node->dtype_ = inferred_dtypes[eid];
node->stype_ = inferred_stypes[nid]; // <- TODO: nid or eid?
}
}

Expand Down
25 changes: 16 additions & 9 deletions src/ngraph/ngraph_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using LayerGraphs = std::map<std::string, std::function<Graph(const NodePtr)>>;
using NodeMap = std::map<const nnvm::Node*, std::shared_ptr<nnvm::Node>>;
using NNVMNodeVec = std::vector<nnvm::NodePtr>;
using NgraphShape = std::unordered_map<std::string, nnvm::TShape>;
using NgraphDType = std::unordered_map<std::string, int>;
using NgraphType = std::unordered_map<std::string, int>;
using NDArrayMap = nnvm::NodeEntryMap<mxnet::NDArray>;
using StateMap = std::unordered_map<const nnvm::Node*, mxnet::OpStatePtr>;

Expand Down Expand Up @@ -71,14 +71,17 @@ struct BindArg : public BindArgBase {

// SimpleBind
struct SimpleBindArg : public BindArgBase {
SimpleBindArg(size_t numforward,
const std::unordered_map<std::string, nnvm::TShape>& shapes,
const std::unordered_map<std::string, int>& dtypes)
: BindArgBase(numforward), shape_map_(shapes), dtype_map_(dtypes) {}
SimpleBindArg(size_t numforward, const NgraphShape& shapes,
const NgraphType& dtypes, const NgraphType& stypes)
: BindArgBase(numforward),
shape_map_(shapes),
dtype_map_(dtypes),
stype_map_(stypes) {}

// simple bind arguments
const NgraphShape shape_map_;
const NgraphDType dtype_map_;
const NgraphType dtype_map_;
const NgraphType stype_map_;
};

// This is a compile-time hash map that contains information on
Expand Down Expand Up @@ -154,9 +157,10 @@ class Compiler {
void ParseNnvmGraph();

StateMap CopySavedStates(const StateMap& saved_states);
// Return maps of the shapes and dtypes for further analysis in graph_executor
// Return maps of the shapes and types for further analysis in graph_executor
const NgraphShape& GetNgraphShape() { return ngraph_shape_; }
const NgraphDType& GetNgraphDtype() { return ngraph_dtype_; }
const NgraphType& GetNgraphDtype() { return ngraph_dtype_; }
const NgraphType& GetNgraphStype() { return ngraph_stype_; }
// Return copies of the feed_dict and inputs to feed back into the
// graph executor inference engine
const NDArrayMap& GetFeedDict() { return feed_dict_; }
Expand Down Expand Up @@ -186,7 +190,8 @@ class Compiler {
ngraph_bridge::Graph ngraph_;
// shape and type maps to return to the graph executor
NgraphShape ngraph_shape_;
NgraphDType ngraph_dtype_;
NgraphType ngraph_dtype_;
NgraphType ngraph_stype_;
// copied feed dict and inputs
nnvm::NodeEntryMap<mxnet::NDArray> feed_dict_;
NNVMNodeVec inputs_;
Expand All @@ -200,6 +205,8 @@ class Compiler {
nnvm::ShapeVector shapes_;
// inferred nnvm::Graph dtype
nnvm::DTypeVector dtypes_;
// inferred nnvm::Graph storage type
nnvm::StorageVector stypes_;
};

} // namespace ngraph_bridge
Expand Down
2 changes: 2 additions & 0 deletions src/ngraph/ngraph_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define MXNET_NGRAPH_NGRAPH_GRAPH_H_

#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <nnvm/tuple.h>
Expand Down Expand Up @@ -80,6 +81,7 @@ class Node {
// mxnet type information
nnvm::TShape shape_;
int dtype_ = 0;
int stype_ = mxnet::kDefaultStorage;

// information to store graph parsing in
size_t multi_output_index_ = 0;
Expand Down
8 changes: 8 additions & 0 deletions src/ngraph/ngraph_imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,17 @@ NGImperative::NGImperative(const nnvm::NodeAttrs &attrs,
for (auto i : inputs) {
shapes_.push_back(i.shape_);
dtypes_.push_back(i.type_flag_);
stypes_.push_back(mxnet::kDefaultStorage);
}

stypes_.push_back(mxnet::kDefaultStorage);

// initialize ngraph
DeepCopy(g);

graph_.attrs["context"] = std::make_shared<nnvm::any>(
mxnet::exec::ContextVector(graph_.indexed_graph().num_nodes(), ctx));

MakeCopiedInputs(sym.ListInputs(nnvm::Symbol::kReadOnlyArgs));
}

Expand Down
8 changes: 5 additions & 3 deletions tests/cpp/ngraph/test_ngraph_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@ class NGRAPH_COMPILER : public ::testing::Test {
nnvm_graph.outputs.push_back(relu);

nnvm::TShape shape{2, 2};
std::unordered_map<std::string, int> dtypes;
std::unordered_map<std::string, nnvm::TShape> shapes;
std::unordered_map<std::string, int> dtypes;
std::unordered_map<std::string, int> stypes;

for (auto n : {A, B, C, D}) inputs.push_back(n.node);

for (auto n : {"A", "B", "C", "D"}) {
dtypes[n] = 0;
shapes[n] = shape;
dtypes[n] = 0;
stypes[n] = 0;
}
feed_dict[A] = mxnet::NDArray(shape, mxnet::Context());
bindarg = std::make_shared<ngraph_bridge::SimpleBindArg>(4, shapes, dtypes);
bindarg = std::make_shared<ngraph_bridge::SimpleBindArg>(4, shapes, dtypes, stypes);
}

virtual void TearDown() {}
Expand Down