Skip to content

Commit

Permalink
[Pass] Enable BackwardOp
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent e008ea8 commit 8a39682
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 22 deletions.
6 changes: 3 additions & 3 deletions nnvm/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# NNVM: Build deep learning system by parts

NNVM is not a deep learning library. It is a modular, lightweight library to
NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
help build deep learning libraries efficiently.

## What is it

While most deep learning systems offer end to end solutions,
it is interesting to ask if we can actually assemble a deep learning system by parts.
The goal is to enable hackers to customize the optimizations, target platforms and set of operators they care about.
We believe that the modular system is an interesting direction.
We believe that the decentralized modular system is an interesting direction.
The hope is that effective parts can be assembled together just like you assemble your own desktops.
So the customized deep learning solution can be minimax, minimum in terms of dependencies,
while maximizing the users' needs.
Expand All @@ -18,7 +18,7 @@ computation graph optimization such as memory reduction, device allocation,
operator fusion while being agnostic to the operator
interface definition and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library
for neural nets and computation graphs in general.
for the generation and optimization of neural nets and computation graphs in general.

## Deep learning system by parts

Expand Down
13 changes: 13 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
*/
using FInferType = FInferNodeEntryAttr<int>;

/*!
 * \brief Whether this op is an explicit backward (gradient) operator.
 *
 * If TIsBackwardOp is set to true:
 * - The first entry of the node's control_deps points to the corresponding forward operator.
 * - The outputs of this operator correspond one-to-one to the inputs of that forward op.
 *
 * \note Register under "TIsBackwardOp"; defaults to false.
 *
 * This enables easier shape/type inference for backward operators of ops such as
 * slice and reduction, whose output shapes can be copied from the forward inputs.
 */
using TIsBackwardOp = bool;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
2 changes: 1 addition & 1 deletion nnvm/include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class Symbol {
* \return Symbol that can be used to call compose further.
*/
static Symbol CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string>&& attrs);
std::unordered_map<std::string, std::string> attrs);
/*!
* \brief create variable symbol node
* \param name name of the variable
Expand Down
2 changes: 1 addition & 1 deletion nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
}

Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string>&& attrs) {
std::unordered_map<std::string, std::string> attrs) {
Symbol s;
NodePtr n = Node::Create();
n->op = op;
Expand Down
42 changes: 26 additions & 16 deletions nnvm/src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,21 @@
namespace nnvm {
namespace pass {

template<typename AttrType>
template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
const AttrType def_value,
const char* infer_name,
const char* arg_name,
const char* attr_key_name,
const char* attr_name,
const char* known_name) {
const char* known_name,
IsNone fis_none) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& is_backward =
Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value);

Expand Down Expand Up @@ -66,6 +69,19 @@ Graph InferAttr(Graph &&ret,
if (finfer_shape.count(inode.source->op)) {
num_known +=
finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
} else if (is_backward.get(inode.source->op, false)) {
// backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]];
CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs())
<< "BackwardOp need to correspond to the forward node";
bool known = true;
for (size_t i = 0; i < fnode.inputs.size(); ++i) {
*oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
if (fis_none(*oshape[i])) known = false;
}
num_known += known;
}
}
// set the shapes
Expand All @@ -79,13 +95,10 @@ NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret),
TShape(),
"FInferShape",
"shape_args",
"shape_attr_key",
"shape",
"shape_num_known_nodes");
std::move(ret), TShape(),
"FInferShape", "shape_args", "shape_attr_key",
"shape", "shape_num_known_nodes",
[](const TShape& s) { return s.ndim() == 0; });
})
.set_change_graph(false)
.provide_graph_attr("shape");
Expand All @@ -94,13 +107,10 @@ NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) {
return InferAttr<int>(
std::move(ret),
0,
"FInferType",
"dtype_args",
"dtype_attr_key",
"dtype",
"dtype_num_known_nodes");
std::move(ret), 0,
"FInferType", "dtype_args", "dtype_attr_key",
"dtype", "dtype_num_known_nodes",
[](const int t) { return t == -1; });
})
.set_change_graph(false)
.provide_graph_attr("dtype");
Expand Down
3 changes: 2 additions & 1 deletion nnvm/src/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ void test_speed() {
size_t rep = 1000;
size_t n = 1000;
std::unordered_map<std::string, const nnvm::Symbol*> tmp;
std::unordered_map<std::string, std::string> kwargs;
std::vector<const nnvm::Symbol*> vec{2};
std::string name = "xx";
for (size_t t = 0; t < rep; ++t) {
nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
for (size_t i = 0; i < n; ++i) {
nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, {});
nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, kwargs);
vec[0] = &s;
vec[1] =&s;
tmp.clear();
Expand Down

0 comments on commit 8a39682

Please sign in to comment.