Check that the node is not null, add contains to OpMap (apache#3037)
larroy authored and Wei Chen committed May 13, 2019
1 parent 23e7e7d commit f2041c8
Showing 5 changed files with 47 additions and 20 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
10 changes: 7 additions & 3 deletions nnvm/include/nnvm/graph.h
@@ -315,12 +315,16 @@ inline void DFSVisit(const std::vector<NodeEntry>& heads,
});
PostOrderDFSVisit<GNode, Node*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->Node* { return n->get(); }, // HashFunc
[fvisit](GNode n) {
fvisit(*n);
}, // FVisit
[](GNode n)->Node* {
return n->get();
}, // HashFunc
[](GNode n)->uint32_t { // InDegree
if (!(*n)) return 0;
return (*n)->inputs.size() + (*n)->control_deps.size();
},
},
[](GNode n, uint32_t index)->GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).node;
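The effect of the new null guard in InDegree is easiest to see from a caller's side. A minimal sketch (function name and setup are illustrative, not part of this commit), assuming a Graph whose outputs may contain a null NodePtr; with the guard, the traversal no longer dereferences the null entry, and the visitor can decide how to handle it:

// Minimal sketch (not part of the diff): with the null guard in InDegree,
// DFSVisit can walk a graph whose outputs include a null NodePtr without
// dereferencing it; the visitor still receives the null entry and may skip it.
#include <nnvm/graph.h>
#include <dmlc/logging.h>

void CountReachableNodes(const nnvm::Graph& graph) {
  size_t non_null = 0;
  nnvm::DFSVisit(graph.outputs, [&non_null](const nnvm::NodePtr& n) {
    if (n) ++non_null;  // a null node is passed through, not dereferenced
  });
  LOG(INFO) << non_null << " non-null nodes reachable from the outputs";
}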
22 changes: 20 additions & 2 deletions nnvm/include/nnvm/op.h
@@ -368,6 +368,13 @@ class OpMap {
*/
inline int count(const Op* op) const;

/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return true if op is contained in map, false otherwise.
*/
inline bool contains(const Op* op) const;

private:
friend class Op;
// internal attribute name
@@ -578,9 +585,20 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) {  //
// member functions of OpMap
template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
if (contains(op)) {
return 1;
} else {
return 0;
}
}

template<typename ValueType>
inline bool OpMap<ValueType>::contains(const Op* op) const {
if (op == nullptr) {
return false;
}
const uint32_t idx = op->index_;
return idx < data_.size() ? (data_[idx].second != 0) : 0;
return idx < data_.size() ? (data_[idx].second != 0) : false;
}

template<typename ValueType>
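For context, a short sketch of how the new OpMap::contains() reads at a call site; the helper name and the use of the FGradient attribute map are illustrative, not part of this commit. count() now simply forwards to it and returns 0 or 1:

// Illustrative only: look up the FGradient attribute map and ask whether
// the given op registered a gradient function.
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

bool HasGradient(const nnvm::Op* op) {
  static const auto& grad_fun_map =
      nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
  // contains() is null-safe and answers the yes/no question directly.
  return grad_fun_map.contains(op);
}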
1 change: 1 addition & 0 deletions nnvm/src/core/graph.cc
@@ -78,6 +78,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
CHECK(n);
for (const auto &subgraph : n->attrs.subgraphs)
subgraphs.push_back(subgraph);
// nodes_
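The new CHECK(n) turns a null node into an immediate, readable check failure while the indexed graph is built, rather than a crash on a later dereference. A hypothetical repro sketch (the NodeEntry construction mirrors the brace-initialized form used elsewhere in this diff):

// Hypothetical sketch: a graph whose only output entry holds a null node.
#include <nnvm/graph.h>

void TriggerNullNodeCheck() {
  nnvm::Graph g;
  g.outputs.emplace_back(nnvm::NodeEntry{nullptr, 0, 0});
  // Building the indexed graph traverses the outputs and now fails fast
  // on CHECK(n) instead of dereferencing the null NodePtr.
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  (void)idx;
}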
32 changes: 18 additions & 14 deletions nnvm/src/pass/gradient.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -143,23 +143,23 @@ Graph Gradient(Graph src) {
<< "because it is unreachable from the outputs.";
}

// construct mirror reduece memory strategy if needed
// construct mirror as memory reduction strategy if needed
std::unordered_map<Node*, NodePtr> mirror_map;
if (mirror_fun != nullptr) {
for (const NodePtr& n : topo_order) {
if (mirror_fun(*n)) {
for (const NodePtr& node_ptr : topo_order) {
if (mirror_fun(*node_ptr)) {
NodePtr new_node = Node::Create();
*new_node = *n;
*new_node = *node_ptr;
new_node->attrs.name += "_mirror";
for (auto& e : new_node->inputs) {
e.node = mirror_map.at(e.node.get());
}
for (auto& n : new_node->control_deps) {
n = mirror_map.at(n.get());
}
mirror_map[n.get()] = std::move(new_node);
mirror_map[node_ptr.get()] = std::move(new_node);
} else {
mirror_map[n.get()] = n;
mirror_map[node_ptr.get()] = node_ptr;
}
}
}
@@ -185,7 +185,8 @@ Graph Gradient(Graph src) {
if ((*rit)->inputs.size() != 0) {
NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
std::vector<NodeEntry> input_grads;
if (grad_fun_map.count(ptr->op())) {
// Check for FGradient
if (grad_fun_map.contains(ptr->op())) {
input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
<< "Gradient function not returning enough gradient";
@@ -205,20 +206,23 @@ if (p->op()->attr_parser != nullptr) {
if (p->op()->attr_parser != nullptr) {
p->op()->attr_parser(&(p->attrs));
}
input_grads.emplace_back(nnvm::NodeEntry{p, 0, 0});
input_grads.emplace_back(p, 0, 0);
}
} else {
LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
<< "because it didn't register FGradient attribute.";
}
for (const auto& nodeEntry : input_grads)
CHECK(nodeEntry.node);
auto git = input_grads.begin();
CHECK((*rit)->inputs.size() <= input_grads.size());
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
auto& ge = output_grads[it->node.get()][it->index];
auto& output_grad_entry = output_grads[it->node.get()][it->index];
// if any of the backward op can do shape inference, the hint is not necessary.
if (finfer_shape.count(git->node->op())) {
ge.need_attr_hint = false;
if (finfer_shape.contains(git->node->op())) {
output_grad_entry.need_attr_hint = false;
}
ge.grads.emplace_back(std::move(*git));
output_grad_entry.grads.emplace_back(std::move(*git));
}
}
}
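The added checks require an op's FGradient to return one non-null entry per forward input. A hedged sketch of a registration that satisfies them; the op name is made up for illustration and is not part of this commit:

// Illustrative registration only; not part of this commit.
#include <nnvm/node.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>

NNVM_REGISTER_OP(identity_like)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // One gradient entry per input, each with a non-null node, so both the
    // CHECK(nodeEntry.node) and the size check in gradient.cc pass.
    return std::vector<nnvm::NodeEntry>{ograds[0]};
  });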
