[MXNET-432] Add Foreach (#11531)
* Test input a graph.

* Update foreach to execute the subgraph.

* print inputs/outputs in foreach.

* Remove print.

* add test code for foreach.

* exec foreach outside the engine.

* Implements forward of foreach.

* Add support for variable numbers of inputs and outputs.

* Add a python wrapper for foreach.

* Fix the order of inputs.

* add test with lstm.

* hide C version of foreach.

* fix a bug temporarily.

* Test free variables.

* change for the new interface of InputGraph attribute.

* Add attribute to the subgraph.

* Handle free variables.

* Get all input symbols of a subgraph.

* Fix shape, dtype and storage inference.

* reorganize the output of foreach.

* Add a gluon RNN unroll with symbol foreach.

* remove unnecessary print.

* have imperative and symbolic foreach.

* Fix an error after moving foreach.

* Fix imperative foreach

* Fix a minor problem.

* Use CachedOp to execute subgraph.

* update TODO.

* make foreach op use FStatefulComputeEx.

TODO we need to change stateful executor to handle subgraph.

* Add backward.

* Fix bugs.

* enable backward test in lstm.

* Fix a bug in foreach backward for free variables.

* change for the new CachedOp.

* Detect the backward computation.

* Fix bugs in foreach.

* fix tests.

* update tests.

* check state shape.

* enable nested foreach.

* remove print.

* fix a bug in test.

* handle infer storage type for backward.

* address comments.

* address comments.

* move some common functions out.

* address comments.

* fix lint.

* Fix lint.

* add doc.

* undo modification in imperative.h

* add doc and remove example code.

* fix lint.

* fix lint.

* Fix lint.

* make nd.foreach and sym.foreach consistent.

* fix compile error.

* address comments.

* update.

* check that the for loop only works for dense arrays.

* move control flow op out of nn/

* fix include.

* add a test in gluon.

* work for GPU.

* small fix.

* remove subgraph_name

* create loop state for reuse in the future.

* move code.

* Revert "remove subgraph_name"

This reverts commit 977f562.

* cut graph.

* rename new var nodes.

* Fix tests.

* Fix bugs caused by ctypes (#29)

* Add save/load json in testcases for foreach (#30)

* support subgraph in stateful executor.

* Fix compilation.

* fix a bug when a subgraph has variable nodes.

* Fix a bug of getting symbols.

* copy var nodes.

* Fix getting op states.

* fix lint error.

* address comments.

* fix lint error.

* simplify the execution of subgraph in the main thread.

* fix lint error.

* avoid waiting for computation in each iteration.

* reuse cached op for inference.

* share memory across mini-batches.

* reuse memory.

reuse memory between iterations in inference.
reuse memory between mini-batches in training.

* add tests for multiple batches.

* remove entry.

* add benchmark for foreach.

* benchmark large batch size.

* Fix the benchmark for GPU.

* address comments.

* update shape/dtype/storage inference.

* update contrib API docs.

* support nested foreach.

* use a single CachedOp for all iterations.

* use large dim.

* update benchmark.

* update benchmark.

* update benchmark.

* update benchmark.

* return symbol arrays correctly in MXSymbolCutSubgraph.

* return symbol arrays in MXSymbolGetInputSymbols.

* fix lint error.

* use cachedop to infer storage in backward.

* fix scala API.

* update comments.

* fix scala.

* fix test.

* fix attribute name.

* move benchmark.

* fix the mapping of operator inputs/outputs and subgraph inputs/outputs.

* add tests for dtype/shape inference.

* reorganize tests.

* fix a bug of cutting NodeEntry.

When two node entries refer to the same output of a node, we should
create only one var node for these two node entries.

* fix lint error.

* handle the case that outputs are inputs.

* handle the case that inputs aren't used.

* handle the case without output data.

* fix a bug in foreach backward.

* fix a bug when there isn't output data.

* Fix lint error.

* test diff Gluon RNN cells.

* test all symbol RNN cells.

* adjust the test precision.

* Fix a bug in getting a list of variable names.

We can't get a list of variable names from a hashtable. The order can't
be guaranteed. Python2 and Python3 output different orders.

* fix lint error.

* Test 1D array.

* fix a bug when subgraph inputs and outputs share NDArray.

* fix.

* fix

* add comments.
zheng-da authored and piiswrong committed Jul 2, 2018
1 parent 7c74d1f commit 030fbc3
Showing 21 changed files with 2,217 additions and 21 deletions.
189 changes: 189 additions & 0 deletions benchmark/python/control_flow/rnn.py
@@ -0,0 +1,189 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import subprocess
import mxnet as mx
from mxnet import gluon
import time
import copy

def get_gpus():
    """
    return a list of GPUs
    """
    try:
        re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except OSError:
        return []
    return range(len([i for i in re.split('\n') if 'GPU' in i]))

class TestRNNLayer(gluon.HybridBlock):
    def __init__(self, cell, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.cell = cell

    def hybrid_forward(self, F, inputs, states):
        out, states = F.contrib.foreach(self.cell, inputs, states)
        return out

def benchmark_rnn(cell, rnn_data, states):
    ctx = rnn_data.context
    num_batches = 20

    # Imperative layer with an unmodified cell
    cell0 = copy.deepcopy(cell)
    layer0 = TestRNNLayer(cell0)
    layer0.initialize(ctx=ctx)

    # Hybridize the cell only
    cell1 = copy.deepcopy(cell)
    cell1.hybridize()
    layer1 = TestRNNLayer(cell1)
    layer1.initialize(ctx=ctx)

    # Hybridize the whole layer
    cell2 = copy.deepcopy(cell)
    layer2 = TestRNNLayer(cell2)
    layer2.initialize(ctx=ctx)
    layer2.hybridize()
    layer2(rnn_data, states)

    # Hybridize the cell with static memory allocation
    cell3 = copy.deepcopy(cell)
    cell3.hybridize(static_alloc=True)
    layer3 = TestRNNLayer(cell3)
    layer3.initialize(ctx=ctx)

    tic = time.time()
    for i in range(num_batches):
        res0 = layer0(rnn_data, states)
    mx.nd.waitall()
    print("Imperative inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res1 = layer1(rnn_data, states)
    mx.nd.waitall()
    print("Hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res3 = layer3(rnn_data, states)
    mx.nd.waitall()
    print("Static-hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res2 = layer2(rnn_data, states)
    mx.nd.waitall()
    print("Hybrid inference takes " + str(time.time() - tic))

    layer2.export("foreach_rnn")
    symnet = mx.symbol.load('foreach_rnn-symbol.json')
    args1 = {}
    params = layer2.collect_params()
    for key in params.keys():
        args1[key] = params[key].data()
    args1['data0'] = rnn_data
    for i in range(len(states)):
        args1['data' + str(i + 1)] = states[i]
    exe = symnet.bind(ctx=ctx, args=args1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=False)
    mx.nd.waitall()
    print("Symbol inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res0 = layer0(rnn_data, states)
            res0.backward()
    mx.nd.waitall()
    print("Imperative training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res1 = layer1(rnn_data, states)
            res1.backward()
    mx.nd.waitall()
    print("Hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res3 = layer3(rnn_data, states)
            res3.backward()
    mx.nd.waitall()
    print("Static-hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res2 = layer2(rnn_data, states)
            res2.backward()
    mx.nd.waitall()
    print("Hybrid training takes " + str(time.time() - tic))

    # gradients for the backward of the foreach symbol
    args_grad1 = {}
    for key in args1.keys():
        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=True)
        exe.backward(res2)
    mx.nd.waitall()
    print("Symbol training takes " + str(time.time() - tic))
    print("")

if __name__ == '__main__':
    ndim = 512
    seq_len = 100
    batch_sizes = [1, 32]
    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
    ctxs = [mx.cpu(0), mx.gpu(0)]
    for cell in cells:
        for ctx in ctxs:
            for batch_size in batch_sizes:
                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                    continue

                if isinstance(cell, gluon.rnn.GRUCell):
                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
                                            ctx=ctx)
                    states = []
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=ctx))
                elif isinstance(cell, gluon.rnn.LSTMCell):
                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
                                            ctx=ctx)
                    states = []
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=ctx))
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=ctx))
                if ctx == mx.gpu(0):
                    dev = "GPU"
                else:
                    dev = "CPU"
                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev,
                                                                   batch_size))
                benchmark_rnn(cell, rnn_data, states)
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
    fft
    ifft
    quantize
    foreach
```

## API Reference
1 change: 1 addition & 0 deletions docs/api/python/symbol/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
    fft
    ifft
    quantize
    foreach
```

## API Reference
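The PR adds the same `foreach` entry to the symbol API docs. As a rough illustration only (the step function, variable names, and shapes below are made up, not taken from the PR), a symbolic cumulative-sum loop built with `mx.sym.contrib.foreach` could look like this:

```python
import mxnet as mx

def step(data, states):
    # one loop iteration: add the current slice to the running state
    out = data + states[0]
    return out, [out]

data = mx.sym.var('data')   # bound below to shape (seq_len, batch, dim)
init = mx.sym.var('init')   # bound below to shape (batch, dim)
outs, states = mx.sym.contrib.foreach(step, data, [init])

# bind the unrolled graph to concrete arrays and run one forward pass
exe = outs.bind(mx.cpu(), args={'data': mx.nd.ones((3, 2, 4)),
                                'init': mx.nd.zeros((2, 4))})
print(exe.forward()[0].shape)   # (3, 2, 4): per-step outputs stacked on dimension 0
```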
22 changes: 22 additions & 0 deletions include/mxnet/c_api.h
@@ -1051,6 +1051,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
 */
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
                                          const char **name);

/*!
 * \brief Get the input symbols of the graph.
 * \param sym The graph.
 * \param inputs The input symbols of the graph.
 * \param input_size the number of input symbols returned.
 */
MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
                                      int *input_size);

/*!
 * \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
 * The input graph will be modified. A variable node will be created for each
 * edge that connects to nodes outside the subgraph. The outside nodes that
 * connect to the subgraph will be returned.
 * \param sym The graph.
 * \param inputs The nodes that connect to the subgraph.
 * \param input_size The number of such nodes.
 */
MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
                                  int *input_size);

/*!
 * \brief Get the detailed information about atomic symbol.
 * \param creator the AtomicSymbolCreator.
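For orientation, the new C entry point could be exercised from Python through ctypes roughly as sketched below. This is a hedged sketch under assumptions, not the wrapper the PR adds; the argtypes handling and the ownership of the returned handles are guesses based on the declaration above.

```python
import ctypes
import mxnet as mx
from mxnet.base import _LIB, check_call, SymbolHandle
from mxnet.symbol import Symbol

composed = mx.sym.var('a') + mx.sym.var('b')

handles = ctypes.POINTER(SymbolHandle)()   # receives the array of input handles
size = ctypes.c_int(0)
check_call(_LIB.MXSymbolGetInputSymbols(
    composed.handle, ctypes.byref(handles), ctypes.byref(size)))

inputs = [Symbol(SymbolHandle(handles[i])) for i in range(size.value)]
print([s.name for s in inputs])   # expected: the variable symbols 'a' and 'b'
```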
11 changes: 9 additions & 2 deletions include/mxnet/op_attr_types.h
@@ -64,8 +64,10 @@ enum OpReqType {
  * \sa Resource
  */
 struct OpContext {
+  /*! \brief whether there is a backward phase to compute gradients. */
+  bool need_grad;
   /*! \brief whether it is training phase */
-  int is_train;
+  bool is_train;
   /*! \brief RunContext related resources */
   RunContext run_ctx;
   /*! \brief the callback when operation completes, used by asynchronize ops */
@@ -98,7 +100,12 @@ enum class ExecType {
    * In current implementation, copy operator is specially handled by executor.
    * This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * \brief A subgraph execution should happen in the main thread, instead of
+   * in the execution engine.
+   */
+  kSubgraphExec,
 };

/*! \brief the dispatch mode of the operator */
96 changes: 96 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -21,6 +21,8 @@
import math
from ..context import current_context
from ..random import uniform
from ..base import _as_list
from . import ndarray
try:
    from .gen_contrib import *
except ImportError:
@@ -95,3 +97,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
    expected_count_sampled = expected_prob_sampled * num_sampled
    return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(body, data, init_states):
    """Run a for loop with user-defined computation over NDArrays on dimension 0.

    This operator simulates a for loop and body has the computation for an iteration
    of the for loop. It runs the computation in body on each slice from the input
    NDArrays.

    body takes two arguments as input and outputs a tuple of two elements,
    as illustrated below:

        out, states = body(data1, states)

    data1 can be either an NDArray or a list of NDArrays. If data is an NDArray,
    data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same
    size as data. states is a list of NDArrays and has the same size as init_states.
    Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated
    as the first output of foreach; states from the last execution of body
    are the second output of foreach.

    The computation done by this operator is equivalent to the pseudo code below
    when the input data is an NDArray:

        states = init_states
        outs = []
        for i in data.shape[0]:
            s = data[i]
            out, states = body(s, states)
            outs.append(out)
        outs = stack(*outs)

    Parameters
    ----------
    body : a Python function.
        Define computation in an iteration.
    data : an NDArray or a list of NDArrays.
        The input data.
    init_states : an NDArray or a list of NDArrays.
        The initial values of the loop states.
    name : string.
        The name of the operator.

    Returns
    -------
    outputs : an NDArray or a list of NDArrays.
        The output data concatenated from the output of all iterations.
    states : a list of NDArrays.
        The loop states in the last iteration.

    Examples
    --------
    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
    >>> data = mx.nd.random.uniform(shape=(2, 10))
    >>> states = [mx.nd.random.uniform(shape=(10))]
    >>> outs, states = mx.nd.contrib.foreach(step, data, states)
    """

    def check_input(inputs, in_type, msg):
        is_NDArray_or_list = True
        if isinstance(inputs, list):
            for i in inputs:
                if not isinstance(i, in_type):
                    is_NDArray_or_list = False
                    break
        else:
            is_NDArray_or_list = isinstance(inputs, in_type)
        assert is_NDArray_or_list, msg

    check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays")
    check_input(init_states, ndarray.NDArray,
                "init_states should be an NDArray or a list of NDArrays")

    not_data_list = isinstance(data, ndarray.NDArray)
    num_iters = data.shape[0] if not_data_list else data[0].shape[0]
    states = init_states
    outputs = []
    for i in range(num_iters):
        if not_data_list:
            eles = data[i]
        else:
            eles = [d[i] for d in data]
        outs, states = body(eles, states)
        outs = _as_list(outs)
        outputs.append(outs)
    outputs = zip(*outputs)
    tmp_outputs = []
    for out in outputs:
        tmp_outputs.append(ndarray.op.stack(*out))
    outputs = tmp_outputs

    if not_data_list and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)
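As a quick usage note for the imperative version, a loop with two states (in the spirit of an LSTM's hidden and cell state) can be driven as in the sketch below; the step function and shapes are illustrative only, not part of the PR:

```python
import mxnet as mx

def step(data, states):
    # one iteration: consume a (batch, dim) slice and update two running states
    running_sum, count = states
    running_sum = running_sum + data
    count = count + 1
    return running_sum, [running_sum, count]

seq = mx.nd.ones((5, 2, 4))                      # (seq_len, batch, dim)
init = [mx.nd.zeros((2, 4)), mx.nd.zeros((1,))]  # two loop states
outs, final_states = mx.nd.contrib.foreach(step, seq, init)

print(outs.shape)        # (5, 2, 4): per-iteration outputs stacked along dimension 0
print(final_states[1])   # [5.]: the counter state after five iterations
```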
