[MXNET-626] Add while_loop #11566
Changes from 18 commits
@@ -0,0 +1,214 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py

import subprocess
import mxnet as mx
from mxnet import gluon
import time
import copy


def get_gpus():
    """
    Return a list of indices of the GPUs available on this machine.
    """
    try:
        re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except OSError:
        return []
    return range(len([i for i in re.split('\n') if 'GPU' in i]))


class TestRNNLayer(gluon.HybridBlock):
    def __init__(self, cell, length, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.length = length
        self.cell = cell

    def hybrid_forward(self, F, inputs, states):
        def _func(*states):
            i = states[0]
            s = states[1:]
            data = inputs.take(i).squeeze(axis=0)
            out, new_s = self.cell(data, s)
            new_s = [i + 1] + new_s
            return out, new_s
        out, states = F.contrib.while_loop(
            cond=lambda i, *_: i < self.length,
            func=_func,
            # lambda i, *s: [i + 1] + list(self.cell(s)),

[review] remove commented code
[author] Removed. Sorry for that T_T

            loop_vars=states,
            max_iterations=self.length,
        )
        return out + states


def benchmark_rnn(cell, rnn_data, states, length):
    ctx = rnn_data.context
    num_batches = 20

    # Imperative
    cell0 = copy.deepcopy(cell)

[review] nit: there's quite a bit of repetition in this function

    layer0 = TestRNNLayer(cell0, length)
    layer0.initialize(ctx=ctx)

    # Hybrid-cell
    cell1 = copy.deepcopy(cell)
    cell1.hybridize()
    layer1 = TestRNNLayer(cell1, length)
    layer1.initialize(ctx=ctx)

    # Hybrid
    cell2 = copy.deepcopy(cell)
    layer2 = TestRNNLayer(cell2, length)
    layer2.initialize(ctx=ctx)
    layer2.hybridize()
    layer2(rnn_data, states)

    # Static-hybrid-cell
    cell3 = copy.deepcopy(cell)
    cell3.hybridize(static_alloc=True)
    layer3 = TestRNNLayer(cell3, length)
    layer3.initialize(ctx=ctx)

    tic = time.time()
    for i in range(num_batches):
        res0 = layer0(rnn_data, states)
    mx.nd.waitall()
    print("Imperative inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res1 = layer1(rnn_data, states)
    mx.nd.waitall()
    print("Hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res3 = layer3(rnn_data, states)
    mx.nd.waitall()
    print("Static-hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res2 = layer2(rnn_data, states)
    mx.nd.waitall()
    print("Hybrid inference takes " + str(time.time() - tic))

    layer2.export("while_loop_rnn")
    symnet = mx.symbol.load('while_loop_rnn-symbol.json')
    args1 = {}
    params = layer2.collect_params()
    for key in params.keys():
        args1[key] = params[key].data()
    args1['data0'] = rnn_data
    for i in range(len(states)):
        args1['data' + str(i + 1)] = states[i]
    exe = symnet.bind(ctx=ctx, args=args1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=False)
    mx.nd.waitall()
    print("Symbol inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res0 = layer0(rnn_data, states)
        res0[0].backward()
    mx.nd.waitall()
    print("Imperative training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res1 = layer1(rnn_data, states)
        res1[0].backward()
    mx.nd.waitall()
    print("Hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res3 = layer3(rnn_data, states)
        res3[0].backward()
    mx.nd.waitall()
    print("Static-hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res2 = layer2(rnn_data, states)
        res2[0].backward()
    mx.nd.waitall()
    print("Hybrid training takes " + str(time.time() - tic))

    # gradients for the backward of the while_loop symbol
    args_grad1 = {}
    for key in args1.keys():
        if key != "data1":
            args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=True)
        exe.backward(res2)
    mx.nd.waitall()
    print("Symbol training takes " + str(time.time() - tic))
    print("")

if __name__ == '__main__':
    def _zeros(shape):
        return mx.nd.zeros(shape=shape, ctx=mx.cpu(0))
    def _array(shape):
        return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0))
    ndim = 512
    seq_len = 100
    batch_sizes = [1, 32]
    cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'),
             gluon.rnn.GRUCell(ndim, prefix='rnn_'),
             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
    ctxs = [mx.cpu(0), mx.gpu(0)]
    for cell in cells:
        for ctx in ctxs:
            for batch_size in batch_sizes:
                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                    continue
                if isinstance(cell, gluon.rnn.RNNCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.GRUCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                    ]
                elif isinstance(cell, gluon.rnn.LSTMCell):
                    rnn_data = _array((seq_len, batch_size, ndim))
                    states = [
                        _zeros((1, )),
                        _array((batch_size, ndim)),
                        _array((batch_size, ndim)),
                    ]
                dev = "GPU" if ctx == mx.gpu(0) else "CPU"
                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size))
                benchmark_rnn(cell, rnn_data, states, seq_len)
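
A note for readers skimming the benchmark above: TestRNNLayer.hybrid_forward threads an iteration counter through the loop variables, with states[0] serving as the counter i and the remaining entries holding the actual RNN cell states, so that cond can stop the loop after length steps. Below is a minimal standalone sketch of the same pattern using the plain NDArray API; the sizes and variable names are hypothetical, not taken from this PR.

import mxnet as mx

seq_len, batch, ndim = 4, 2, 8          # hypothetical sizes
inputs = mx.nd.random.normal(shape=(seq_len, batch, ndim))
cell = mx.gluon.rnn.RNNCell(ndim)
cell.initialize()

i0 = mx.nd.zeros((1,))                  # loop counter, carried as loop_vars[0]
h0 = mx.nd.zeros((batch, ndim))         # RNN hidden state

def body(i, h):
    data = inputs.take(i).squeeze(axis=0)   # select the i-th time step
    out, new_states = cell(data, [h])
    return [out], [i + 1] + new_states      # prepend the incremented counter

outputs, final_states = mx.nd.contrib.while_loop(
    cond=lambda i, h: i < seq_len,          # stop after seq_len steps
    func=body,
    loop_vars=[i0, h0],
    max_iterations=seq_len,
)
# outputs[0] has shape (seq_len, batch, ndim): the per-step cell outputs stacked along axis 0.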
@@ -28,7 +28,7 @@
 except ImportError:
     pass

-__all__ = ["rand_zipfian"]
+__all__ = ["rand_zipfian", "foreach", "while_loop"]

[review] need to update docs/api/python/ndarray&symbol/contrib.md, too
[author] Updated

# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):

@@ -191,3 +191,140 @@ def check_input(inputs, in_type, msg):
    if not_data_list and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)


def while_loop(cond, func, loop_vars, max_iterations=None):
    """Run a while loop with user-defined computation and loop condition.

    This operator simulates a while loop, iteratively running the customized
    computation as long as the condition is satisfied.

    `loop_vars` is a list of NDArrays on which the computation depends.

    `cond` is a user-defined function, used as the loop condition.
    It consumes `loop_vars`, and produces a scalar MXNet NDArray
    indicating whether the loop should terminate.
    The loop ends when `cond` returns false (zero).
    `cond` is variadic, and its signature should be
    `cond(*loop_vars) => NDArray`.

    `func` is a user-defined function, used as the loop body.
    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
    In each step, `step_output` should contain the same number of elements.
    Through all steps, the i-th element of `step_output` should have the same shape and dtype.
    Also, `new_loop_vars` should contain the same number of elements as `loop_vars`,
    and the corresponding elements should have the same shape and dtype.
    `func` is variadic, and its signature should be
    `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.

    `max_iterations` is a scalar that defines the maximum number of iterations allowed.

    This function returns two lists as a tuple.
    The first list has length `|step_output|`; its i-th element is the stack
    (along axis 0) of the i-th elements of `step_output` from all steps.
    The second list has length `|loop_vars|` and holds the final states of the
    loop variables.

    Warning 1: when `cond` is never satisfied, we assume `step_output` is empty,
    because it cannot be inferred. This is different from the symbolic version.

    Warning 2: the output shape along axis 0 is currently the actual number of
    iterations taken, which is different from the symbolic version, where it is
    `max_iterations`.

    Parameters
    ----------
    cond: a Python function.
        The loop condition.
    func: a Python function.
        The loop body.
    loop_vars: list of NDArrays.
        The initial values of the loop variables.
    max_iterations: a Python int.
        Maximum number of iterations.

    Returns
    -------
    outputs: a tuple of two lists, each of which contains 0, 1 or more NDArrays.
        The first list contains the stacked outputs from each step,
        the second list contains the final state.

    Examples
    --------
    >>> cond = lambda i, s: i <= 5
    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
    >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
    >>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
    """
    def _to_python_scalar(inputs, type_, name):
        """Converts "inputs", possibly an MXNet NDArray, a numpy ndarray, or
        another Python type, to the given type
        """
        if isinstance(inputs, ndarray.NDArray):
            inputs = inputs.asscalar()
        try:
            inputs = type_(inputs)
        except (TypeError, ValueError):
            raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
        return inputs

    def _to_ndarray_tuple(inputs, name):
        """Converts "inputs", possibly a single MXNet NDArray, a list of MXNet NDArrays,
        or a tuple of MXNet NDArrays, into a tuple of NDArrays
        """
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(inputs, ndarray.NDArray):
            inputs = (inputs, )
        if not isinstance(inputs, tuple):
            raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
        for item in inputs:
            if not isinstance(item, ndarray.NDArray):
                raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
        return inputs

    def _func_wrapper(loop_vars):
        """This wrapper unifies
        "func: loop_vars -> new_loop_vars"
        and "func: loop_vars -> (step_output, new_loop_vars)"
        into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)"
        """
        step_output, new_loop_vars = func(*loop_vars)
        if step_output is None:
            step_output = []
        if new_loop_vars is None:
            new_loop_vars = []
        step_output = _to_ndarray_tuple(step_output, "step_output")
        new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
        if len(loop_vars) != len(new_loop_vars):
            raise ValueError("The length of loop_vars should be consistent during the loop")
        return step_output, new_loop_vars

    if max_iterations is None:
        raise ValueError("max_iterations should be specified")
    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
    loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
    # This would probably work even if loop_vars were empty,
    # but supporting that case is semantically unnecessary.
    if len(loop_vars) == 0:
        raise ValueError("loop_vars should contain at least one element")

    steps = 0
    outputs = []
    while steps < max_iterations and \
            _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):  # loop condition

[review] So this could end before reaching max_iterations. Isn't this inconsistent with symbol?
[author] Yes, they are not consistent, and I put a warning in the docstring. Should I do some padding stuff so that they look the same?
[review] I think so.
[author] @zheng-da So should I pad the arrays to make them consistent?
[review] It's better to do so, in my opinion. What do you think? @piiswrong
[author] Fixed
[review] Yes, ndarray and symbol functions should give the same result for the same input. Otherwise hybridize may break

        step_output, loop_vars = _func_wrapper(loop_vars)
        outputs.append(step_output)
        steps += 1
        if len(outputs) != steps or len(step_output) != len(outputs[0]):
            raise ValueError("Number of elements in step_output should be the same in each step")
    stacked_outputs = []
    for i_th, items in enumerate(zip(*outputs), 1):
        try:
            stacked_outputs.append(ndarray.op.stack(*items))
        except ValueError:
            raise ValueError("\n".join(
                ["Shapes of %d-th elements in step_outputs are inconsistent, which are:" % i_th] +
                ["  Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)]
            ))
    return stacked_outputs, list(loop_vars)
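
On the padding question raised in the thread above: one way to make the imperative result match the symbolic shape is to zero-pad each stacked output along axis 0 up to max_iterations. The sketch below only illustrates the idea; pad_to_max_iterations is a hypothetical helper, not the code merged in this PR.

import mxnet as mx

def pad_to_max_iterations(stacked_outputs, steps, max_iterations):
    # Zero-pad each stacked output along axis 0 so that its first dimension
    # equals max_iterations, matching what the symbolic while_loop produces.
    padded = []
    for out in stacked_outputs:
        if steps < max_iterations:
            pad_shape = (max_iterations - steps, ) + out.shape[1:]
            pad = mx.nd.zeros(pad_shape, ctx=out.context, dtype=out.dtype)
            out = mx.nd.concat(out, pad, dim=0)
        padded.append(out)
    return padded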

[review] nit: there's quite a bit of repetition in the below code.
[author] I didn't touch this file. It is renamed from https://github.com/apache/incubator-mxnet/blob/master/benchmark/python/control_flow/rnn.py. Should I simplify this in this PR, or in a separate one?
[review] you can coordinate with @zheng-da
[author] @szha Da and I decided that I will rewrite these two files. Will push a commit later today.
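
To close, tracing the docstring example by hand may help in verifying the semantics at this commit (i.e. before the padding change discussed above); the expected values below follow directly from the pure-Python loop in the implementation.

import mxnet as mx

cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
# (i, s) evolves as (0,1) -> (1,1) -> (2,2) -> (3,4) -> (4,7) -> (5,11) -> (6,16);
# the loop stops after 6 iterations, once i > 5 (before hitting max_iterations).
# outputs[0] stacks i + s from each step: [[1], [2], [4], [7], [11], [16]]
# states holds the final loop variables: ([6], [16])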