This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-626] Add while_loop #11566

Merged

merged 31 commits on Jul 19, 2018

Changes from 18 commits

Commits (31)
6976b90
Add while_loop
junrushao Jul 5, 2018
249c8b4
Avoid input/output overlap for nnvm graph cut
junrushao Jul 6, 2018
cfa13b1
Add more testcases
junrushao Jul 6, 2018
9ca3dd5
Enhance test 4.2
junrushao Jul 6, 2018
6418065
Add more complicated testcases; Add testcase for nested loop
junrushao Jul 7, 2018
ad0accc
Check unused loop_vars in while_loop
junrushao Jul 7, 2018
8edb051
Add testcases for RNN
junrushao Jul 8, 2018
dc48a7f
Make lint happy
junrushao Jul 8, 2018
06d29cb
Make lint happy
junrushao Jul 8, 2018
316b0f7
Address TODOs
junrushao Jul 8, 2018
9572a87
Fix flaky test for while_loop
junrushao Jul 9, 2018
e603170
Address comments
junrushao Jul 9, 2018
5d298bb
Improve docstring
junrushao Jul 10, 2018
43128c0
Improve error message
junrushao Jul 10, 2018
f241e3c
Add benchmark code
junrushao Jul 10, 2018
e393bd0
Update benchmarks
junrushao Jul 10, 2018
1b11670
Allow sparse types
junrushao Jul 11, 2018
4e4f5f9
Make max_iterations default to None
junrushao Jul 11, 2018
6736e3d
Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
junrushao Jul 12, 2018
16e2823
Pad imperative while_loop so that it has the same shape with the symb…
junrushao Jul 12, 2018
93d8d0c
Add example result into the example section
junrushao Jul 12, 2018
ca4d7b0
Remove unused class member
junrushao Jul 12, 2018
e067d0b
Rename unittest to test_contrib_control_flow.py
junrushao Jul 12, 2018
c08b063
Update docstring
junrushao Jul 13, 2018
9b219d9
Update docstring
junrushao Jul 13, 2018
3ea7bda
Trigger CI
junrushao Jul 13, 2018
168bd27
Change threshold for assert_almost_equal
junrushao Jul 13, 2018
aa9722d
Trigger CI
junrushao Jul 13, 2018
e69b674
Address comments from szha
junrushao Jul 18, 2018
dfc1828
Rewrite benchmark code
junrushao Jul 18, 2018
bd48b77
Fix sphinx warning
junrushao Jul 18, 2018
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 6ab4da to 290226
@@ -157,16 +157,22 @@ def benchmark_rnn(cell, rnn_data, states):
ndim = 512
seq_len = 100
batch_sizes = [1, 32]
-cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
+cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
+gluon.rnn.GRUCell(ndim, prefix='rnn_'),
gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
for ctx in ctxs:
for batch_size in batch_sizes:
if len(get_gpus()) == 0 and ctx == mx.gpu(0):
continue

-if isinstance(cell, gluon.rnn.GRUCell):
+if isinstance(cell, gluon.rnn.RNNCell):
Member:
nit: there's quite a bit of repetition in the below code.

Member Author:
I didn't touch this file. It is renamed from https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify this in this PR, or in a separate one?

Member:
you can coordinate with @zheng-da

Member Author:
@szha Da and I decided that I will rewrite these two files. Will push a commit later today.

rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
ctx=mx.cpu(0))
states = []
states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
ctx=mx.cpu(0)))
elif isinstance(cell, gluon.rnn.GRUCell):
rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
ctx=mx.cpu(0))
states = []
214 changes: 214 additions & 0 deletions benchmark/python/control_flow/while_loop_rnn.py
@@ -0,0 +1,214 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py

import subprocess
import mxnet as mx
from mxnet import gluon
import time
import copy

def get_gpus():
"""
return a list of GPUs
"""
try:
re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
except OSError:
return []
return range(len([i for i in re.split('\n') if 'GPU' in i]))

class TestRNNLayer(gluon.HybridBlock):
def __init__(self, cell, length, prefix=None, params=None):
super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
self.length = length
self.cell = cell

def hybrid_forward(self, F, inputs, states):
def _func(*states):
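# loop_vars layout: states[0] is the iteration counter; the rest are the RNN cell's states.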
i = states[0]
s = states[1: ]
data = inputs.take(i).squeeze(axis=0)
out, new_s = self.cell(data, s)
new_s = [i + 1] + new_s
return out, new_s
out, states = F.contrib.while_loop(
cond=lambda i, *_: i < self.length,
func=_func,
# lambda i, *s: [i + 1] + list(self.cell(s)),
Member:
remove commented code

Member Author (junrushao, Jul 18, 2018):
Removed. Sorry for that T_T

loop_vars=states,
max_iterations=self.length,
)
return out + states
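
For orientation, a minimal sketch (not part of the diff) of driving TestRNNLayer imperatively; the sizes below are illustrative:

cell = gluon.rnn.RNNCell(8, prefix='rnn_')
layer = TestRNNLayer(cell, length=5)
layer.initialize()
inputs = mx.nd.normal(loc=0, scale=1, shape=(5, 2, 8))  # (seq_len, batch, ndim)
states = [mx.nd.zeros((1, )), mx.nd.zeros((2, 8))]      # [loop counter, hidden state]
res = layer(inputs, states)  # [stacked step outputs] + [final counter, final hidden state]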

def benchmark_rnn(cell, rnn_data, states, length):
ctx = rnn_data.context
num_batches = 20

# Imperative
cell0 = copy.deepcopy(cell)
Member:
nit: there's quite a bit of repetition in this function

layer0 = TestRNNLayer(cell0, length)
layer0.initialize(ctx=ctx)

# Hybrid-cell
cell1 = copy.deepcopy(cell)
cell1.hybridize()
layer1 = TestRNNLayer(cell1, length)
layer1.initialize(ctx=ctx)

# Hybrid
cell2 = copy.deepcopy(cell)
layer2 = TestRNNLayer(cell2, length)
layer2.initialize(ctx=ctx)
layer2.hybridize()
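# run once so the hybridized graph is built (layer2 is also exported below)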
layer2(rnn_data, states)

# Static-hybrid-cell
cell3 = copy.deepcopy(cell)
cell3.hybridize(static_alloc=True)
layer3 = TestRNNLayer(cell3, length)
layer3.initialize(ctx=ctx)

tic = time.time()
for i in range(num_batches):
res0 = layer0(rnn_data, states)
mx.nd.waitall()
print("Imperative inference takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
res1 = layer1(rnn_data, states)
mx.nd.waitall()
print("Hybrid-cell inference takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
res3 = layer3(rnn_data, states)
mx.nd.waitall()
print("Static-hybrid-cell inference takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
res2 = layer2(rnn_data, states)
mx.nd.waitall()
print("Hybrid inference takes " + str(time.time() - tic))

layer2.export("while_loop_rnn")
symnet = mx.symbol.load('while_loop_rnn-symbol.json')
args1 = {}
params = layer2.collect_params()
for key in params.keys():
args1[key] = params[key].data()
args1['data0'] = rnn_data
for i in range(len(states)):
args1['data' + str(i + 1)] = states[i]
exe = symnet.bind(ctx=ctx, args=args1)
tic = time.time()
for i in range(num_batches):
exe.forward(is_train=False)
mx.nd.waitall()
print("Symbol inference takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
with mx.autograd.record():
res0 = layer0(rnn_data, states)
res0[0].backward()
mx.nd.waitall()
print("Imperative training takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
with mx.autograd.record():
res1 = layer1(rnn_data, states)
res1[0].backward()
mx.nd.waitall()
print("Hybrid-cell training takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
with mx.autograd.record():
res3 = layer3(rnn_data, states)
res3[0].backward()
mx.nd.waitall()
print("Static-hybrid-cell training takes " + str(time.time() - tic))

tic = time.time()
for i in range(num_batches):
with mx.autograd.record():
res2 = layer2(rnn_data, states)
res2[0].backward()
mx.nd.waitall()
print("Hybrid training takes " + str(time.time() - tic))

# gradients for the backward of the while_loop symbol
args_grad1 = {}
for key in args1.keys():
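# skip "data1", the loop-counter state (states[0]), which gets no gradient buffer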
if key != "data1":
args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
tic = time.time()
for i in range(num_batches):
exe.forward(is_train=True)
exe.backward(res2)
mx.nd.waitall()
print("Symbol training takes " + str(time.time() - tic))
print("")

if __name__ == '__main__':
def _zeros(shape):
return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
def _array(shape):
return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
ndim = 512
seq_len = 100
batch_sizes = [1, 32]
cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
gluon.rnn.GRUCell(ndim, prefix='rnn_'),
gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
ctxs = [mx.cpu(0), mx.gpu(0)]
for cell in cells:
for ctx in ctxs:
for batch_size in batch_sizes:
if len(get_gpus()) == 0 and ctx == mx.gpu(0):
continue
if isinstance(cell, gluon.rnn.RNNCell):
rnn_data = _array((seq_len, batch_size, ndim))
states = [
_zeros((1, )),
_array((batch_size, ndim)),
]
if isinstance(cell, gluon.rnn.GRUCell):
rnn_data = _array((seq_len, batch_size, ndim))
states = [
_zeros((1, )),
_array((batch_size, ndim)),
]
elif isinstance(cell, gluon.rnn.LSTMCell):
rnn_data = _array((seq_len, batch_size, ndim))
states = [
_zeros((1, )),
_array((batch_size, ndim)),
_array((batch_size, ndim)),
]
if ctx == mx.gpu(0):
dev = "GPU"
else:
dev = "CPU"
print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size))
benchmark_rnn(cell, rnn_data, states, seq_len)
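
As the reviewer notes, the three isinstance branches above repeat the same state construction; a possible consolidation (a sketch, not part of the PR; make_states is a hypothetical helper):

def make_states(cell, batch_size, ndim):
    # A loop counter plus the hidden state; LSTMCell carries an extra cell state.
    num_hidden = 2 if isinstance(cell, gluon.rnn.LSTMCell) else 1
    return [_zeros((1, ))] + [_array((batch_size, ndim)) for _ in range(num_hidden)]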
139 changes: 138 additions & 1 deletion python/mxnet/ndarray/contrib.py
@@ -28,7 +28,7 @@
except ImportError:
pass

__all__ = ["rand_zipfian"]
__all__ = ["rand_zipfian", "foreach", "while_loop"]
Member:
need to update docs/api/python/ndarray&symbol/contrib.md, too

Member Author:
Updated

# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
@@ -191,3 +191,140 @@ def check_input(inputs, in_type, msg):
if not_data_list and len(outputs) == 1:
outputs = outputs[0]
return (outputs, states)


def while_loop(cond, func, loop_vars, max_iterations=None):
"""Run a while loop with user-defined computation and loop condition.

This operator simulates a while loop, iteratively performing customized computation
as long as the condition is satisfied.

`loop_vars` is a list of NDArrays that the computation uses.

`cond` is a user-defined function, used as the loop condition.
It consumes `loop_vars` and produces a scalar MXNet NDArray
indicating whether the loop should continue.
The loop ends when `cond` returns false (zero).
`cond` is variadic; its signature should be
`cond(*loop_vars) => NDArray`.

`func` is a user-defined function, used as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
`step_output` should contain the same number of elements at every step,
and across all steps the i-th element of `step_output` should have the same shape and dtype.
Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
and each corresponding element should have the same shape and dtype.
`func` is variadic; its signature should be
`func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.

`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns two lists as a tuple.
The first list has length `|step_output|`; its i-th element stacks
the i-th elements of `step_output` from all steps along axis 0.
The second list has length `|loop_vars|` and holds the final states of the loop variables.

Warning 1: when `cond` is never satisfied, we assume `step_output` is empty,
because it cannot be inferred. This is different from the symbolic version.

Warning 2: The output shape along axis 0 is currently the actual number of iterations taken,
which is different from the symbolic version, where it is `max_iterations`.

Parameters
----------
cond: a Python function.
The loop condition.
func: a Python function.
The loop body.
loop_vars: list of NDArrays.
The initial values of the loop variables.
max_iterations: a python int.
Maximum number of iterations.

Returns
-------
outputs: a tuple of two lists, each of which contains 0, 1 or more NDArrays.
The first list contains the stacked output from each step;
the second list contains the final state.

Examples
--------
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
"""
def _to_python_scalar(inputs, type_, name):
"""Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types,
to the given type
"""
if isinstance(inputs, ndarray.NDArray):
inputs = inputs.asscalar()
try:
inputs = type_(inputs)
except (TypeError, ValueError):
raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
return inputs

def _to_ndarray_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray,
a tuple of mxnet NDArray, into a tuple of NDArray
"""
if isinstance(inputs, list):
inputs = tuple(inputs)
if isinstance(inputs, ndarray.NDArray):
inputs = (inputs, )
if not isinstance(inputs, tuple):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
for item in inputs:
if not isinstance(item, ndarray.NDArray):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
return inputs

def _func_wrapper(loop_vars):
"""This wrapper unifies
"func: loop_vars -> new_loop_vars"
and "func: loop_vars -> (step_output, new_loop_vars)"
into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)
"""
step_output, new_loop_vars = func(*loop_vars)
if step_output is None:
step_output = []
if new_loop_vars is None:
new_loop_vars = []
step_output = _to_ndarray_tuple(step_output, "step_output")
new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
if len(loop_vars) != len(new_loop_vars):
raise ValueError("The length of loop_vars should be consistent during the loop")
return step_output, new_loop_vars

if max_iterations is None:
raise ValueError("max_iterations should be specified")
max_iterations = _to_python_scalar(max_iterations, int, "max_iteration")
loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
# The loop would still work if loop_vars were empty,
# but that case is semantically meaningless, so we reject it.
if len(loop_vars) == 0:
raise ValueError("loop_vars should contain at least one element")

steps = 0
outputs = []
while steps < max_iterations and \
_to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition
Contributor:
So this could end before reaching max_iterations. Isn't this inconsistent with symbol?

Member Author:
Yes, they are not consistent, and I put a warning in the docstring. Should I do some padding stuff so that they look the same?

Contributor:
i think so.

Member Author (junrushao, Jul 12, 2018):
@zheng-da So should I pad the arrays to make them consistent?

Contributor:
it's better to do so, in my opinion. what do you think? @piiswrong

Member Author:
Fixed

Contributor:
Yes, ndarray and symbol functions should give the same result for the same input. Otherwise hybridize may break

step_output, loop_vars = _func_wrapper(loop_vars)
outputs.append(step_output)
steps += 1
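# every step must emit the same number of outputs as the first step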
if len(outputs) != steps or len(step_output) != len(outputs[0]):
raise ValueError("Number of elements in step_output should be the same in each step")
stacked_outputs = []
for i_th, items in enumerate(zip(*outputs), 1):
try:
stacked_outputs.append(ndarray.op.stack(*items))
except ValueError:
raise ValueError("\n".join(
["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th] +
[" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
))
return stacked_outputs, list(loop_vars)
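
The consistency issue discussed in the thread above was later addressed by padding the imperative result (commit 16e2823). A minimal sketch of the idea (`_pad_axis0` is a hypothetical helper, not the PR's code): pad each stacked output with zeros along axis 0 so its length matches `max_iterations`, as in the symbolic version.

def _pad_axis0(stacked, num_steps, max_iterations):
    # Pad a stacked output of shape (num_steps, ...) with zeros up to
    # (max_iterations, ...), matching the symbolic version's output shape.
    if num_steps == max_iterations:
        return stacked
    pad_shape = (max_iterations - num_steps, ) + stacked.shape[1:]
    pad = ndarray.zeros(pad_shape, ctx=stacked.context, dtype=stacked.dtype)
    return ndarray.op.concat(stacked, pad, dim=0)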