From f679b04fb7deb9285fea1f961b45e04d21a98093 Mon Sep 17 00:00:00 2001 From: adstraw Date: Mon, 12 Feb 2018 12:45:43 -0800 Subject: [PATCH 1/6] make ngraph_bridge aware of storage type do not compile non-default storage type with nGraph --- src/executor/graph_executor.cc | 14 +++++++- src/ngraph/ngraph_compiler.cc | 64 +++++++++++++++++++--------------- src/ngraph/ngraph_compiler.h | 42 ++++++++++++---------- src/ngraph/ngraph_graph.h | 3 ++ 4 files changed, 76 insertions(+), 47 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 98b57716d..b9b007179 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -571,6 +571,11 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_grad_ctxes, aux_state_ctxes); ngraph_bridge::BindArg bind(num_forward_inputs_, in_args, aux_states); + + // assign default context + g.attrs["context"] = std::make_shared( + ContextVector(g.indexed_graph().num_nodes(), default_ctx)); + ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), bind, default_ctx); @@ -1045,8 +1050,14 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // TODO(mbrookhart): Remove this when hetr can handle multiple contexts auto multi_context = multi_context_check(default_ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes); + ngraph_bridge::SimpleBindArg simplebind(num_forward_inputs_, arg_shape_map, - arg_dtype_map); + arg_dtype_map, arg_stype_map); + + // assign default context + g.attrs["context"] = std::make_shared( + ContextVector(g.indexed_graph().num_nodes(), default_ctx)); + ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), simplebind, default_ctx); @@ -1738,3 +1749,4 @@ Executor *Executor::Bind(nnvm::Symbol symbol, return exec; } } // namespace mxnet + diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index f268d8d58..eb325e4cd 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -1,18 +1,18 @@ /******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ #include @@ -91,17 +91,20 @@ void Compiler::Infer(const BindArg* bind) { if (mutable_nodes.count(nid)) { shapes_.push_back(bind->aux_states_[aux_top].shape()); dtypes_.push_back(bind->aux_states_[aux_top].dtype()); + stypes_.push_back(bind->aux_states_[aux_top].storage_type()); ++aux_top; } else { shapes_.push_back(bind->in_args_[arg_top].shape()); dtypes_.push_back(bind->in_args_[arg_top].dtype()); + stypes_.push_back(bind->in_args_[arg_top].storage_type()); ++arg_top; } } - // append default shapes / dtypes so that vector size = graph size + // append default shapes / types so that vector size = graph size shapes_.resize(idx.input_nodes().size(), nnvm::TShape()); dtypes_.resize(idx.input_nodes().size(), -1); + stypes_.resize(idx.input_nodes().size(), mxnet::kUndefinedStorage); } // infer nnvm::Graph shape and dtype for simple bind case @@ -110,6 +113,7 @@ void Compiler::Infer(const SimpleBindArg* simplebind) { const auto& idx = graph_.indexed_graph(); shapes_.resize(idx.input_nodes().size(), nnvm::TShape()); dtypes_.resize(idx.input_nodes().size(), -1); + stypes_.resize(idx.input_nodes().size(), mxnet::kUndefinedStorage); size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < simplebind->kNumForwardInputs; ++i) { const uint32_t nid = idx.input_nodes().at(i); @@ -122,6 +126,10 @@ void Compiler::Infer(const SimpleBindArg* simplebind) { if (simplebind->dtype_map_.end() != it2) { dtypes_[i] = it2->second; } + auto it3 = simplebind->stype_map_.find(name); + if (simplebind->stype_map_.end() != it3) { + dtypes_[i] = it3->second; + } } } @@ -151,19 +159,13 @@ Compiler::Compiler(const nnvm::Graph& graph, const NDArrayMap& feed_dict, void Compiler::ProcessGraph(const NDArrayMap& feed_dict) { graph_ = mxnet::exec::InferShape(std::move(graph_), std::move(shapes_), "__shape__"); - // TODO(adstraw): may or may not need error checking - // if (g.GetAttr("shape_num_unknown_nodes") != 0U) { - // HandleInferShapeError(num_forward_inputs, g.indexed_graph(), - // g.GetAttr("shape")); - //} - graph_ = mxnet::exec::InferType(std::move(graph_), std::move(dtypes_), "__dtype__"); - // TODO(adstraw): may or may not need error checking - // if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { - // HandleInferTypeError(num_forward_inputs, g.indexed_graph(), - // g.GetAttr("dtype")); - //} + + graph_.attrs["storage_type"] = + std::make_shared(std::move(stypes_)); + mxnet::StorageTypeVector stv; + graph_ = mxnet::exec::InferStorageType(std::move(graph_), std::move(stv), ""); MakeCopiedFeedDict(feed_dict); ParseNnvmGraph(); @@ -298,6 +300,8 @@ void Compiler::DeepCopy(const nnvm::Graph& graph) { for (auto& input : kv.second->inputs) input.node = node_map_[input.node.get()]; + graph_.attrs = graph.attrs; + // set the output graph to use the copied nodes graph_.outputs = graph.outputs; for (auto& out : graph_.outputs) out.node = node_map_[out.node.get()]; @@ -309,11 +313,13 @@ void Compiler::CheckInNgraph() { if (node->type_ == NodeType::kOp) { if (compiler_.ngraph_op_funcs_.count(node->operation_)) { node->in_ngraph_ = true; - if (node->dtype_ == mshadow::kFloat16) { + if (node->dtype_ == mshadow::kFloat16 || + node->stype_ != mxnet::kDefaultStorage) { node->in_ngraph_ = false; } else { for (auto input : node->inputs_) { - if (input->dtype_ == mshadow::kFloat16) { + if (input->dtype_ == mshadow::kFloat16 || + input->stype_ != mxnet::kDefaultStorage) { node->in_ngraph_ = false; } } @@ -392,11 +398,13 @@ void Compiler::ParseNnvmGraph() { const auto inferred_shapes = graph_.GetAttr>("shape"); const auto inferred_dtypes = graph_.GetAttr>("dtype"); + const auto inferred_stypes = graph_.GetAttr>("storage_type"); for (auto node : this->ngraph_.nodes_) { const uint32_t nid = idx.node_id(node->orig_node_.get()); const uint32_t eid = idx.entry_id(nid, 0); node->shape_ = inferred_shapes[eid]; node->dtype_ = inferred_dtypes[eid]; + node->stype_ = inferred_stypes[eid]; } } diff --git a/src/ngraph/ngraph_compiler.h b/src/ngraph/ngraph_compiler.h index f0527511f..6e38f98f0 100644 --- a/src/ngraph/ngraph_compiler.h +++ b/src/ngraph/ngraph_compiler.h @@ -1,18 +1,18 @@ /******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ #ifndef MXNET_NGRAPH_NGRAPH_COMPILER_H_ #define MXNET_NGRAPH_NGRAPH_COMPILER_H_ @@ -41,6 +41,7 @@ using NodeMap = std::map>; using NNVMNodeVec = std::vector; using NgraphShape = std::unordered_map; using NgraphDType = std::unordered_map; +using NgraphSType = std::unordered_map; using NDArrayMap = nnvm::NodeEntryMap; using StateMap = std::unordered_map; @@ -71,14 +72,17 @@ struct BindArg : public BindArgBase { // SimpleBind struct SimpleBindArg : public BindArgBase { - SimpleBindArg(size_t numforward, - const std::unordered_map& shapes, - const std::unordered_map& dtypes) - : BindArgBase(numforward), shape_map_(shapes), dtype_map_(dtypes) {} + SimpleBindArg(size_t numforward, const NgraphShape& shapes, + const NgraphDType& dtypes, const NgraphSType& stypes) + : BindArgBase(numforward), + shape_map_(shapes), + dtype_map_(dtypes), + stype_map_(stypes) {} // simple bind arguments const NgraphShape shape_map_; const NgraphDType dtype_map_; + const NgraphSType stype_map_; }; // This is a compile-time hash map that contains information on @@ -200,6 +204,8 @@ class Compiler { nnvm::ShapeVector shapes_; // inferred nnvm::Graph dtype nnvm::DTypeVector dtypes_; + // inferred nnvm::Graph stype + nnvm::StorageVector stypes_; }; } // namespace ngraph_bridge diff --git a/src/ngraph/ngraph_graph.h b/src/ngraph/ngraph_graph.h index a94989c9d..e94b9df80 100644 --- a/src/ngraph/ngraph_graph.h +++ b/src/ngraph/ngraph_graph.h @@ -18,6 +18,7 @@ #define MXNET_NGRAPH_NGRAPH_GRAPH_H_ #include +#include #include #include #include @@ -80,6 +81,7 @@ class Node { // mxnet type information nnvm::TShape shape_; int dtype_ = 0; + int stype_ = mxnet::kUndefinedStorage; // information to store graph parsing in size_t multi_output_index_ = 0; @@ -291,3 +293,4 @@ void GraphTraverse(NodePtr node, const GraphVisitor &visitor); } // namespace ngraph_bridge #endif // MXNET_NGRAPH_NGRAPH_GRAPH_H_ + From f6a0cede3a11d3095860dc53a5a5eff577946669 Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 14 Feb 2018 15:27:09 -0800 Subject: [PATCH 2/6] fix unit test compile and non-imperative test cases --- src/executor/graph_executor.cc | 8 ------ src/ngraph/ngraph_compiler.cc | 17 ++++++++----- src/ngraph/ngraph_graph.h | 2 +- src/ngraph/ngraph_imperative.cc | 33 ++++++++++++++----------- tests/cpp/ngraph/test_ngraph_compiler.h | 8 +++--- 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index b9b007179..9a568461c 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -572,10 +572,6 @@ void GraphExecutor::Init(nnvm::Symbol symbol, ngraph_bridge::BindArg bind(num_forward_inputs_, in_args, aux_states); - // assign default context - g.attrs["context"] = std::make_shared( - ContextVector(g.indexed_graph().num_nodes(), default_ctx)); - ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), bind, default_ctx); @@ -1054,10 +1050,6 @@ void GraphExecutor::Init(nnvm::Symbol symbol, ngraph_bridge::SimpleBindArg simplebind(num_forward_inputs_, arg_shape_map, arg_dtype_map, arg_stype_map); - // assign default context - g.attrs["context"] = std::make_shared( - ContextVector(g.indexed_graph().num_nodes(), default_ctx)); - ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), simplebind, default_ctx); diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index eb325e4cd..2e95b358c 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -104,7 +104,7 @@ void Compiler::Infer(const BindArg* bind) { // append default shapes / types so that vector size = graph size shapes_.resize(idx.input_nodes().size(), nnvm::TShape()); dtypes_.resize(idx.input_nodes().size(), -1); - stypes_.resize(idx.input_nodes().size(), mxnet::kUndefinedStorage); + stypes_.resize(idx.num_node_entries(), mxnet::kDefaultStorage); } // infer nnvm::Graph shape and dtype for simple bind case @@ -113,7 +113,7 @@ void Compiler::Infer(const SimpleBindArg* simplebind) { const auto& idx = graph_.indexed_graph(); shapes_.resize(idx.input_nodes().size(), nnvm::TShape()); dtypes_.resize(idx.input_nodes().size(), -1); - stypes_.resize(idx.input_nodes().size(), mxnet::kUndefinedStorage); + stypes_.resize(idx.num_node_entries(), mxnet::kDefaultStorage); size_t arg_top = 0, aux_top = 0; for (size_t i = 0; i < simplebind->kNumForwardInputs; ++i) { const uint32_t nid = idx.input_nodes().at(i); @@ -128,7 +128,7 @@ void Compiler::Infer(const SimpleBindArg* simplebind) { } auto it3 = simplebind->stype_map_.find(name); if (simplebind->stype_map_.end() != it3) { - dtypes_[i] = it3->second; + stypes_[i] = it3->second; } } } @@ -144,6 +144,9 @@ Compiler::Compiler(const nnvm::Graph& graph, const NDArrayMap& feed_dict, : ngraph_("ngraph_" + randomString(6), context) { DeepCopy(graph); + graph_.attrs["context"] = std::make_shared( + mxnet::exec::ContextVector(graph_.indexed_graph().num_nodes(), context)); + // infer nnvm::Graph shape and type auto bind = dynamic_cast(&bindbase); auto simplebind = dynamic_cast(&bindbase); @@ -161,10 +164,12 @@ void Compiler::ProcessGraph(const NDArrayMap& feed_dict) { "__shape__"); graph_ = mxnet::exec::InferType(std::move(graph_), std::move(dtypes_), "__dtype__"); - + + mxnet::StorageTypeVector stv = stypes_; graph_.attrs["storage_type"] = std::make_shared(std::move(stypes_)); - mxnet::StorageTypeVector stv; + + stv = stypes_; graph_ = mxnet::exec::InferStorageType(std::move(graph_), std::move(stv), ""); MakeCopiedFeedDict(feed_dict); @@ -404,7 +409,7 @@ void Compiler::ParseNnvmGraph() { const uint32_t eid = idx.entry_id(nid, 0); node->shape_ = inferred_shapes[eid]; node->dtype_ = inferred_dtypes[eid]; - node->stype_ = inferred_stypes[eid]; + node->stype_ = inferred_stypes[nid]; // <- TODO: nid or eid? } } diff --git a/src/ngraph/ngraph_graph.h b/src/ngraph/ngraph_graph.h index e94b9df80..57fd8878b 100644 --- a/src/ngraph/ngraph_graph.h +++ b/src/ngraph/ngraph_graph.h @@ -81,7 +81,7 @@ class Node { // mxnet type information nnvm::TShape shape_; int dtype_ = 0; - int stype_ = mxnet::kUndefinedStorage; + int stype_ = mxnet::kDefaultStorage; // information to store graph parsing in size_t multi_output_index_ = 0; diff --git a/src/ngraph/ngraph_imperative.cc b/src/ngraph/ngraph_imperative.cc index bab5e3020..611efc59e 100644 --- a/src/ngraph/ngraph_imperative.cc +++ b/src/ngraph/ngraph_imperative.cc @@ -1,18 +1,18 @@ /******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ + * Copyright 2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ #include #include @@ -65,9 +65,14 @@ NGImperative::NGImperative(const nnvm::NodeAttrs &attrs, for (auto i : inputs) { shapes_.push_back(i.shape_); dtypes_.push_back(i.type_flag_); + stypes_.push_back(mxnet::kDefaultStorage); // <- TODO } // initialize ngraph DeepCopy(g); + + graph_.attrs["context"] = std::make_shared( + mxnet::exec::ContextVector(graph_.indexed_graph().num_nodes(), ctx)); + MakeCopiedInputs(sym.ListInputs(nnvm::Symbol::kReadOnlyArgs)); } diff --git a/tests/cpp/ngraph/test_ngraph_compiler.h b/tests/cpp/ngraph/test_ngraph_compiler.h index 3ea3a95ab..ecd564280 100644 --- a/tests/cpp/ngraph/test_ngraph_compiler.h +++ b/tests/cpp/ngraph/test_ngraph_compiler.h @@ -61,17 +61,19 @@ class NGRAPH_COMPILER : public ::testing::Test { nnvm_graph.outputs.push_back(relu); nnvm::TShape shape{2, 2}; - std::unordered_map dtypes; std::unordered_map shapes; + std::unordered_map dtypes; + std::unordered_map stypes; for (auto n : {A, B, C, D}) inputs.push_back(n.node); for (auto n : {"A", "B", "C", "D"}) { - dtypes[n] = 0; shapes[n] = shape; + dtypes[n] = 0; + stypes[n] = 0; } feed_dict[A] = mxnet::NDArray(shape, mxnet::Context()); - bindarg = std::make_shared(4, shapes, dtypes); + bindarg = std::make_shared(4, shapes, dtypes, stypes); } virtual void TearDown() {} From e1927909997e9245c096c2a27f30ee32c21891c1 Mon Sep 17 00:00:00 2001 From: adstraw Date: Thu, 15 Feb 2018 17:19:47 -0800 Subject: [PATCH 3/6] return modified storage type to executor imperative tests passing with a hack --- src/executor/graph_executor.cc | 4 +++- src/ngraph/ngraph_compiler.cc | 1 + src/ngraph/ngraph_compiler.h | 6 ++++-- src/ngraph/ngraph_imperative.cc | 5 ++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 9a568461c..13ddf3dfe 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1026,7 +1026,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_mapRef, const std::unordered_map& arg_dtype_mapRef, - const std::unordered_map& arg_stype_map, + const std::unordered_map& arg_stype_mapRef, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -1041,6 +1041,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // make copies so that ngraph compilation can modify shape / dtype std::unordered_map arg_shape_map = arg_shape_mapRef; std::unordered_map arg_dtype_map = arg_dtype_mapRef; + std::unordered_map arg_stype_map = arg_stype_mapRef; #if MXNET_USE_NGRAPH == 1 // TODO(mbrookhart): Remove this when hetr can handle multiple contexts @@ -1060,6 +1061,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // modify shape / dtype with ngraph version arg_shape_map = compiler.GetNgraphShape(); arg_dtype_map = compiler.GetNgraphDtype(); + arg_stype_map = compiler.GetNgraphStype(); // create "device" and "context" attrs for the graph g = InitFullGraph(g, compiler.GetInputs(), grad_req_types); diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index 2e95b358c..27eac5db2 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -208,6 +208,7 @@ nnvm::Graph Compiler::Compile() { ngraph_shape_[node->name_] = node->shape_; ngraph_dtype_[node->name_] = node->dtype_; } + ngraph_stype_[node->name_] = node->stype_; } // find the subgraphs diff --git a/src/ngraph/ngraph_compiler.h b/src/ngraph/ngraph_compiler.h index 6e38f98f0..ce0bf1add 100644 --- a/src/ngraph/ngraph_compiler.h +++ b/src/ngraph/ngraph_compiler.h @@ -158,9 +158,10 @@ class Compiler { void ParseNnvmGraph(); StateMap CopySavedStates(const StateMap& saved_states); - // Return maps of the shapes and dtypes for further analysis in graph_executor + // Return maps of the shapes and types for further analysis in graph_executor const NgraphShape& GetNgraphShape() { return ngraph_shape_; } const NgraphDType& GetNgraphDtype() { return ngraph_dtype_; } + const NgraphDType& GetNgraphStype() { return ngraph_stype_; } // Return copies of the feed_dict and inputs to feed back into the // graph executor inference engine const NDArrayMap& GetFeedDict() { return feed_dict_; } @@ -191,6 +192,7 @@ class Compiler { // shape and type maps to return to the graph executor NgraphShape ngraph_shape_; NgraphDType ngraph_dtype_; + NgraphDType ngraph_stype_; // copied feed dict and inputs nnvm::NodeEntryMap feed_dict_; NNVMNodeVec inputs_; @@ -204,7 +206,7 @@ class Compiler { nnvm::ShapeVector shapes_; // inferred nnvm::Graph dtype nnvm::DTypeVector dtypes_; - // inferred nnvm::Graph stype + // inferred nnvm::Graph storage type nnvm::StorageVector stypes_; }; diff --git a/src/ngraph/ngraph_imperative.cc b/src/ngraph/ngraph_imperative.cc index 611efc59e..fe52b55ca 100644 --- a/src/ngraph/ngraph_imperative.cc +++ b/src/ngraph/ngraph_imperative.cc @@ -65,8 +65,11 @@ NGImperative::NGImperative(const nnvm::NodeAttrs &attrs, for (auto i : inputs) { shapes_.push_back(i.shape_); dtypes_.push_back(i.type_flag_); - stypes_.push_back(mxnet::kDefaultStorage); // <- TODO + stypes_.push_back(mxnet::kDefaultStorage); } + + stypes_.resize(100, mxnet::kDefaultStorage); // TODO: HACK + // initialize ngraph DeepCopy(g); From 1010dd1efea7bf1b19da1834d1d47539f1cf9955 Mon Sep 17 00:00:00 2001 From: adstraw Date: Fri, 16 Feb 2018 12:58:52 -0800 Subject: [PATCH 4/6] take care of some TODO's + cleanup --- src/executor/graph_executor.cc | 4 ---- src/ngraph/ngraph_compiler.cc | 4 ++-- src/ngraph/ngraph_graph.h | 1 - src/ngraph/ngraph_imperative.cc | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 13ddf3dfe..a61e7e6af 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -571,7 +571,6 @@ void GraphExecutor::Init(nnvm::Symbol symbol, arg_grad_ctxes, aux_state_ctxes); ngraph_bridge::BindArg bind(num_forward_inputs_, in_args, aux_states); - ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), bind, default_ctx); @@ -1047,10 +1046,8 @@ void GraphExecutor::Init(nnvm::Symbol symbol, // TODO(mbrookhart): Remove this when hetr can handle multiple contexts auto multi_context = multi_context_check(default_ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes); - ngraph_bridge::SimpleBindArg simplebind(num_forward_inputs_, arg_shape_map, arg_dtype_map, arg_stype_map); - ngraph_bridge::Compiler compiler( g, feed_dict, symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs), simplebind, default_ctx); @@ -1743,4 +1740,3 @@ Executor *Executor::Bind(nnvm::Symbol symbol, return exec; } } // namespace mxnet - diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index 27eac5db2..7175899e6 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -165,6 +165,7 @@ void Compiler::ProcessGraph(const NDArrayMap& feed_dict) { graph_ = mxnet::exec::InferType(std::move(graph_), std::move(dtypes_), "__dtype__"); + // TODO: this method does not match executor exactly mxnet::StorageTypeVector stv = stypes_; graph_.attrs["storage_type"] = std::make_shared(std::move(stypes_)); @@ -208,6 +209,7 @@ nnvm::Graph Compiler::Compile() { ngraph_shape_[node->name_] = node->shape_; ngraph_dtype_[node->name_] = node->dtype_; } + // TODO: all nodes, right? ngraph_stype_[node->name_] = node->stype_; } @@ -306,8 +308,6 @@ void Compiler::DeepCopy(const nnvm::Graph& graph) { for (auto& input : kv.second->inputs) input.node = node_map_[input.node.get()]; - graph_.attrs = graph.attrs; - // set the output graph to use the copied nodes graph_.outputs = graph.outputs; for (auto& out : graph_.outputs) out.node = node_map_[out.node.get()]; diff --git a/src/ngraph/ngraph_graph.h b/src/ngraph/ngraph_graph.h index 57fd8878b..fadb9d3af 100644 --- a/src/ngraph/ngraph_graph.h +++ b/src/ngraph/ngraph_graph.h @@ -293,4 +293,3 @@ void GraphTraverse(NodePtr node, const GraphVisitor &visitor); } // namespace ngraph_bridge #endif // MXNET_NGRAPH_NGRAPH_GRAPH_H_ - diff --git a/src/ngraph/ngraph_imperative.cc b/src/ngraph/ngraph_imperative.cc index fe52b55ca..a2342bfc3 100644 --- a/src/ngraph/ngraph_imperative.cc +++ b/src/ngraph/ngraph_imperative.cc @@ -68,7 +68,7 @@ NGImperative::NGImperative(const nnvm::NodeAttrs &attrs, stypes_.push_back(mxnet::kDefaultStorage); } - stypes_.resize(100, mxnet::kDefaultStorage); // TODO: HACK + stypes_.push_back(mxnet::kDefaultStorage); // initialize ngraph DeepCopy(g); From 5a9ac9bfe0898dac26b64286b5597741577583a9 Mon Sep 17 00:00:00 2001 From: adstraw Date: Fri, 16 Feb 2018 13:12:31 -0800 Subject: [PATCH 5/6] more cleanup --- src/ngraph/ngraph_compiler.cc | 28 ++++++++++++++-------------- src/ngraph/ngraph_compiler.h | 28 ++++++++++++++-------------- src/ngraph/ngraph_imperative.cc | 28 ++++++++++++++-------------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index 7175899e6..7843db654 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -1,18 +1,18 @@ /******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #include diff --git a/src/ngraph/ngraph_compiler.h b/src/ngraph/ngraph_compiler.h index ce0bf1add..b84ac9fb6 100644 --- a/src/ngraph/ngraph_compiler.h +++ b/src/ngraph/ngraph_compiler.h @@ -1,18 +1,18 @@ /******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #ifndef MXNET_NGRAPH_NGRAPH_COMPILER_H_ #define MXNET_NGRAPH_NGRAPH_COMPILER_H_ diff --git a/src/ngraph/ngraph_imperative.cc b/src/ngraph/ngraph_imperative.cc index a2342bfc3..802ad244c 100644 --- a/src/ngraph/ngraph_imperative.cc +++ b/src/ngraph/ngraph_imperative.cc @@ -1,18 +1,18 @@ /******************************************************************************* - * Copyright 2018 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #include #include From 85fe2b981449cbc79b15fbc2072e22524398dfdf Mon Sep 17 00:00:00 2001 From: adstraw Date: Fri, 16 Feb 2018 13:24:03 -0800 Subject: [PATCH 6/6] more cleanup --- src/ngraph/ngraph_compiler.cc | 9 ++------- src/ngraph/ngraph_compiler.h | 17 ++++++++--------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/ngraph/ngraph_compiler.cc b/src/ngraph/ngraph_compiler.cc index 7843db654..ee82d561c 100644 --- a/src/ngraph/ngraph_compiler.cc +++ b/src/ngraph/ngraph_compiler.cc @@ -165,13 +165,8 @@ void Compiler::ProcessGraph(const NDArrayMap& feed_dict) { graph_ = mxnet::exec::InferType(std::move(graph_), std::move(dtypes_), "__dtype__"); - // TODO: this method does not match executor exactly - mxnet::StorageTypeVector stv = stypes_; - graph_.attrs["storage_type"] = - std::make_shared(std::move(stypes_)); - - stv = stypes_; - graph_ = mxnet::exec::InferStorageType(std::move(graph_), std::move(stv), ""); + graph_.attrs["storage_type"] = std::make_shared(std::move(stypes_)); + graph_ = mxnet::exec::InferStorageType(std::move(graph_), std::move(mxnet::StorageTypeVector()), ""); MakeCopiedFeedDict(feed_dict); ParseNnvmGraph(); diff --git a/src/ngraph/ngraph_compiler.h b/src/ngraph/ngraph_compiler.h index b84ac9fb6..592448e9a 100644 --- a/src/ngraph/ngraph_compiler.h +++ b/src/ngraph/ngraph_compiler.h @@ -40,8 +40,7 @@ using LayerGraphs = std::map>; using NodeMap = std::map>; using NNVMNodeVec = std::vector; using NgraphShape = std::unordered_map; -using NgraphDType = std::unordered_map; -using NgraphSType = std::unordered_map; +using NgraphType = std::unordered_map; using NDArrayMap = nnvm::NodeEntryMap; using StateMap = std::unordered_map; @@ -73,7 +72,7 @@ struct BindArg : public BindArgBase { // SimpleBind struct SimpleBindArg : public BindArgBase { SimpleBindArg(size_t numforward, const NgraphShape& shapes, - const NgraphDType& dtypes, const NgraphSType& stypes) + const NgraphType& dtypes, const NgraphType& stypes) : BindArgBase(numforward), shape_map_(shapes), dtype_map_(dtypes), @@ -81,8 +80,8 @@ struct SimpleBindArg : public BindArgBase { // simple bind arguments const NgraphShape shape_map_; - const NgraphDType dtype_map_; - const NgraphSType stype_map_; + const NgraphType dtype_map_; + const NgraphType stype_map_; }; // This is a compile-time hash map that contains information on @@ -160,8 +159,8 @@ class Compiler { StateMap CopySavedStates(const StateMap& saved_states); // Return maps of the shapes and types for further analysis in graph_executor const NgraphShape& GetNgraphShape() { return ngraph_shape_; } - const NgraphDType& GetNgraphDtype() { return ngraph_dtype_; } - const NgraphDType& GetNgraphStype() { return ngraph_stype_; } + const NgraphType& GetNgraphDtype() { return ngraph_dtype_; } + const NgraphType& GetNgraphStype() { return ngraph_stype_; } // Return copies of the feed_dict and inputs to feed back into the // graph executor inference engine const NDArrayMap& GetFeedDict() { return feed_dict_; } @@ -191,8 +190,8 @@ class Compiler { ngraph_bridge::Graph ngraph_; // shape and type maps to return to the graph executor NgraphShape ngraph_shape_; - NgraphDType ngraph_dtype_; - NgraphDType ngraph_stype_; + NgraphType ngraph_dtype_; + NgraphType ngraph_stype_; // copied feed dict and inputs nnvm::NodeEntryMap feed_dict_; NNVMNodeVec inputs_;