Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support backward of backward #23

Merged
merged 3 commits into from
Aug 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ class OperatorProperty {
for (size_t i = 0; i < ret_index.size(); ++i) {
ret[i] = all_data[ret_index[i]];
}
return std::move(ret);
return ret;
}
/*!
* \brief create OperatorProperty
Expand Down
29 changes: 19 additions & 10 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,27 +85,36 @@ class StaticGraph {
/*! \brief inputs (node_id, index) of the node */
std::vector<DataEntry> inputs;
/*!
* \brief If this field is nonnegative, this indicates this
* Node is corresponds to a Backward Operation of Operator.
* backward_source_id will points to the corresponding Forward Node.
* \brief source node id; if this field is negative, it means this
* Node is a forward node. If this field is nonnegative, it
* means this Node is the gradient of the source node.
*/
int32_t source_id;
/*!
* \brief backward; if this field is true, this node represents
* the backward function of the op; otherwise it represents the
* forward function. A backward node has no op of its own and
* uses the op of its source node, because the backward function
* shares state with the forward function, so they need to
* share the op.
*
* For normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr
* Since we support gradient of gradient, a forward node can also
* be the gradient of another node. See source id.
*/
int32_t backward_source_id;
bool backward;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}
Node() : source_id(-1), backward(false) {}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
return !backward && !is_variable();
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
return backward;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
return op == nullptr && source_id == -1;
}
};
/*! \brief all nodes in the graph */
Expand Down
16 changes: 8 additions & 8 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ GraphExecutor::GetResource(uint32_t node_id) const {
return node.op->ForwardResource();
} else {
CHECK(node.is_backward());
return graph_.nodes[node.backward_source_id].op->BackwardResource();
return graph_.nodes[node.source_id].op->BackwardResource();
}
}

Expand All @@ -90,7 +90,7 @@ inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const {
return node.op->NumReturns();
} else if (node.is_backward()) {
return static_cast<int>(
graph_.nodes[node.backward_source_id].op->ListArguments().size());
graph_.nodes[node.source_id].op->ListArguments().size());
} else {
CHECK(node.is_variable());
return 1;
Expand Down Expand Up @@ -121,11 +121,11 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
remap[i].first = in_data[rmap_index[i].first];
remap[i].second = *static_cast<const T*>(rmap_index[i].second);
}
return std::move(remap);
return remap;
} else {
CHECK(node.is_backward());
// forward property
const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get();
const OperatorProperty *fwd = graph_.nodes[node.source_id].op.get();

std::vector<int> out_grad_index(fwd->NumVisibleReturns());
std::vector<int> in_data_index(fwd->ListArguments().size());
Expand Down Expand Up @@ -161,7 +161,7 @@ inline std::vector<std::pair<T, T> > GraphExecutor::GetInplaceOption(
remap[i].first = *args_array[remap_index[i].first];
remap[i].second = *static_cast<T*>(remap_index[i].second);
}
return std::move(remap);
return remap;
}
}

Expand Down Expand Up @@ -196,7 +196,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
op_ctx_ptr->run_ctx = ctx;
op->Forward(*op_ctx_ptr, in_data, req, out_data);
};
return std::move(exec);
return exec;
}

void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
Expand Down Expand Up @@ -406,8 +406,8 @@ void GraphExecutor::InitOpNodes() {
} else {
CHECK(graph_.nodes[nid].is_backward());
op_node.op.reset(new BackwardOpWrapper(
graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(),
op_nodes_[graph_.nodes[nid].backward_source_id].op));
graph_.nodes[graph_.nodes[nid].source_id].op.get(),
op_nodes_[graph_.nodes[nid].source_id].op));
}
bool allow_cache = true;
for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) {
Expand Down
25 changes: 15 additions & 10 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
++out_degree[e.source_id];
}
if (n.is_backward()) {
++out_degree[n.backward_source_id];
++out_degree[n.source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
Expand All @@ -41,12 +41,12 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
}
}
if (n.is_backward()) {
if (--out_degree[n.backward_source_id] == 0) {
queue.push(n.backward_source_id);
if (--out_degree[n.source_id] == 0) {
queue.push(n.source_id);
}
}
}
return std::move(ret);
return ret;
}

bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
Expand Down Expand Up @@ -79,7 +79,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
} else if (nodes[nid].is_backward()) {
// simply use shapes from forward pass to assign backward shape
const Node& forward = nodes[node.backward_source_id];
const Node& forward = nodes[node.source_id];
CHECK(forward.is_forward());
std::vector<TShape>& in_grad_shapes = (*node_out_shapes)[nid];
CHECK(in_grad_shapes.size() == forward.inputs.size());
Expand All @@ -99,7 +99,7 @@ bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
}
}
// consistent check for input shapes
auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id];
auto& out_data_shapes = (*node_out_shapes)[node.source_id];
// use BackwardInputs to select entries corresponding to node.inputs
auto in_shape = forward.op->BackwardInputs(
out_data_shapes, in_grad_shapes, out_data_shapes);
Expand Down Expand Up @@ -130,7 +130,7 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
if (nodes[i].is_forward()) {
nout = nodes[i].op->NumReturns();
} else if (nodes[i].is_backward()) {
nout = static_cast<int>(nodes[nodes[i].backward_source_id].inputs.size());
nout = static_cast<int>(nodes[nodes[i].source_id].inputs.size());
}
node_out_shapes[i].resize(nout);
}
Expand Down Expand Up @@ -161,7 +161,7 @@ StaticGraph::Node StaticGraph::CreateSumNode(
os_size << grad_source.size();
agg_node.op->Init({{"size", os_size.str()}});
agg_node.inputs = grad_source;
return std::move(agg_node);
return agg_node;
}

void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
Expand Down Expand Up @@ -198,7 +198,6 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
uint32_t nid = *it;
// skip variables
if (nodes[nid].is_variable()) continue;
CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward";
// get out_grad and out_data entry
std::vector<DataEntry> out_grad, out_data;
// nvisible is out_grad.size()
Expand Down Expand Up @@ -229,7 +228,13 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
// Create a gradient backward node
Node grad_node;
// Point to the corresponding source
grad_node.backward_source_id = nid;
grad_node.source_id = nid;
// the gradient node takes the opposite direction of its source node
grad_node.backward = !(nodes[grad_node.source_id].backward);
// if grad node is a forward node, needs to have its own OpProperty
if (!grad_node.backward) {
grad_node.op.reset(nodes[nodes[nid].source_id].op->Copy());
}
// select out the dependent inputs
grad_node.inputs = nodes[nid].op->BackwardInputs(
out_grad, nodes[nid].inputs, out_data);
Expand Down
24 changes: 12 additions & 12 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2015 by Contributors
*\file symbol.cc
*\brief symbol of mxnet
* Copyright (c) 2015 by Contributors
* \file symbol.cc
* \brief symbol of mxnet
*/
#include <dmlc/logging.h>
#include <mxnet/symbolic.h>
Expand All @@ -12,13 +12,13 @@

namespace mxnet {
/*!
*\brief Node is represents node of an operator in the symbolic graph.
* \brief Node represents a node of an operator in the symbolic graph.
*
*It stores connection to the inputs to function represented by OperatorProperty
*NOTE on data structure: there are three types of node:
*- Normal node: contains all the necessary elements of a graph.
*- OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied.
*- Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed.
* It stores the connections to the inputs of the function represented by OperatorProperty
* NOTE on data structure: there are three types of node:
* - Normal node: contains all the necessary elements of a graph.
* - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied.
* - Variable: the sym_ is nullptr, represents a named Variable of tensors that can be composed.
*/
struct Symbol::Node {
/*! \brief Operator of this node */
Expand Down Expand Up @@ -201,7 +201,7 @@ std::vector<std::string> Symbol::ListReturns() const {
}
}
}
return std::move(ret);
return ret;
}

Symbol Symbol::operator[] (size_t index) const {
Expand Down Expand Up @@ -415,13 +415,13 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
for (const auto &s : symbols) {
ret.heads_.insert(ret.heads_.end(), s.heads_.begin(), s.heads_.end());
}
return std::move(ret);
return ret;
}

Symbol Symbol::CreateVariable(const std::string &name) {
Symbol s;
s.heads_.push_back(DataEntry(std::make_shared<Node>(nullptr, name), 0));
return std::move(s);
return s;
}

void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
Expand Down