This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-626] Add while_loop #11566

Merged (31 commits, Jul 19, 2018)
Changes from 10 commits
Commits (31)
6976b90
Add while_loop
junrushao Jul 5, 2018
249c8b4
Avoid input/output overlap for nnvm graph cut
junrushao Jul 6, 2018
cfa13b1
Add more testcases
junrushao Jul 6, 2018
9ca3dd5
Enhance test 4.2
junrushao Jul 6, 2018
6418065
Add more complicated testcases; Add testcase for nested loop
junrushao Jul 7, 2018
ad0accc
Check unused loop_vars in while_loop
junrushao Jul 7, 2018
8edb051
Add testcases for RNN
junrushao Jul 8, 2018
dc48a7f
Make lint happy
junrushao Jul 8, 2018
06d29cb
Make lint happy
junrushao Jul 8, 2018
316b0f7
Address TODOs
junrushao Jul 8, 2018
9572a87
Fix flaky test for while_loop
junrushao Jul 9, 2018
e603170
Address comments
junrushao Jul 9, 2018
5d298bb
Improve docstring
junrushao Jul 10, 2018
43128c0
Improve error message
junrushao Jul 10, 2018
f241e3c
Add benchmark code
junrushao Jul 10, 2018
e393bd0
Update benchmarks
junrushao Jul 10, 2018
1b11670
Allow sparse types
junrushao Jul 11, 2018
4e4f5f9
Make max_iterations default to None
junrushao Jul 11, 2018
6736e3d
Add while_loop to docs/api/python/{symbol|ndarray}/contrib.md
junrushao Jul 12, 2018
16e2823
Pad imperative while_loop so that it has the same shape with the symb…
junrushao Jul 12, 2018
93d8d0c
Add example result into the example section
junrushao Jul 12, 2018
ca4d7b0
Remove unused class member
junrushao Jul 12, 2018
e067d0b
Rename unittest to test_contrib_control_flow.py
junrushao Jul 12, 2018
c08b063
Update docstring
junrushao Jul 13, 2018
9b219d9
Update docstring
junrushao Jul 13, 2018
3ea7bda
Trigger CI
junrushao Jul 13, 2018
168bd27
Change threshold for assert_almost_equal
junrushao Jul 13, 2018
aa9722d
Trigger CI
junrushao Jul 13, 2018
e69b674
Address comments from szha
junrushao Jul 18, 2018
dfc1828
Rewrite benchmark code
junrushao Jul 18, 2018
bd48b77
Fix sphinx warning
junrushao Jul 18, 2018
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 6ab4da to 290226
125 changes: 125 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -191,3 +191,128 @@ def check_input(inputs, in_type, msg):
if not_data_list and len(outputs) == 1:
outputs = outputs[0]
return (outputs, states)


def while_loop(cond, func, loop_vars, max_iterations):
Contributor:

put loop_vars after func?

Member Author:

My bad. Fixed :-)

"""Run a while loop with user-defined computation and loop condition.

This operator simulates a while loop which iteratively performs customized computation
as long as the condition is satisfied.

`loop_vars` is a list of NDArrays that the computation uses.

`cond` is a user-defined function as the loop condition.
It consumes `loop_vars`, and produces a scalar MXNet NDArray,
indicating the termination of the loop.
The loop ends when `cond` returns false (zero).
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => NDArray`.

`func` is a user-defined function as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
The number of elements, as well as the shape and dtype of each element in `step_output`, should be consistent across steps.
Contributor:

what does consistent mean?

Member Author (@junrushao, Jul 10, 2018):

I changed this to:

"In each step, step_output should contain the same number of elements. Through all steps, the i-th element of step_output should have the same shape and dtype. Also, new_loop_vars should contain the same number of elements as loop_vars, and the corresponding elements should have the same shape and dtype."

Does this seem better?
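[Editor's note: to make the consistency requirement concrete, here is a sketch of a body that violates it. The loop body and values are made up for illustration; only the error message comes from the code below.]

import mxnet as mx

cond = lambda i: i < 3
# emits one step output on the first step but two on the second
func = lambda i: ([i] * (int(i.asscalar()) + 1), [i + 1])

try:
    mx.nd.contrib.while_loop(cond, func, [mx.nd.array([0])], max_iterations=3)
except ValueError as e:
    print(e)  # step_outputs are inconsistent across steps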

The `new_loop_vars` should be consistent with `loop_vars` on each step.
The `func` is variadic, and its signature should be
`func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
Contributor:

cond => func.

Member Author:

Fixed :-)

`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns two lists.
Contributor:

The return value has a different format from TF. Any specific reasons?

Member Author:

I updated the outdated docstring to the following:

"This function returns two lists as a tuple. The first list has the length of |step_output|, in which the i-th element is the stack of the i-th elements of step_output from all steps, along axis 0. The second list has the length of |loop_vars|, and represents the final states of the loop variables."

Currently we don't have dynamic shape inference, so we cannot support a TF-like dynamic-sized, per-time-step TensorArray. So we split our requirement into two parts: 1) a per-time-step array bounded by max_iterations; 2) loop variables, which keep the same shape throughout the loop.

@zheng-da

Contributor:

The interface of this while_loop operator is close to the ONNX definition: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
The interface of TF while_loop requires dynamic shape and also doesn't allow an efficient implementation. This interface is more flexible. If it returns ([], loop_vars), it's the same as the TF interface.

Contributor:

So does it return the two lists separately, or concatenated together?

Member Author:

@piiswrong Separately

Contributor:

"This function returns two lists as a tuple" -> "This function returns two lists."

"as a tuple" makes it sound like the two lists are concatenated into a tuple.

Member Author:

@piiswrong fixed

The first list has length `|step_output|`; its i-th element is the stack,
along axis 0, of the i-th element of `step_output` from all steps.
The second list has length `|loop_vars|` and contains
the final state of each loop variable.

Warning 1: when `cond` is never satisfied, we assume `step_output` is empty.
Warning 2: the output shape along axis 0 is currently the number of steps
actually executed, which is not consistent with the symbolic version,
where it is `max_iterations`.

Parameters
----------
cond: a Python function.
The loop condition.
func: a Python function.
The loop body.
loop_vars: list of NDArrays.
The initial values of the loop variables.
max_iterations: a Python int.
Maximum number of iterations.

Returns
-------
outputs: a tuple of two lists, each containing 0, 1, or more NDArrays.
The first list contains the stacked outputs from all steps;
the second list contains the final states.

Examples
--------
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
Member:

Can you show the output results of this example?

Member Author:

The results are

>>> outputs
[
 [[ 1]
  [ 2]
  [ 4]
  [ 7]
  [11]
  [16]]
 <NDArray 6x1 @cpu(0)>]
>>> states
[
 [6]
 <NDArray 1 @cpu(0)>,
 [16]
 <NDArray 1 @cpu(0)>]

Should I put this snippet into docstring?

Contributor:

i think so.

Member Author:

Fixed

"""
def _to_python_scalar(inputs, type_, name):
"""Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types,
to the given type
"""
if isinstance(inputs, ndarray.NDArray):
inputs = inputs.asscalar()
try:
inputs = type_(inputs)
except:
raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
return inputs

def _to_ndarray_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray,
a tuple of mxnet NDArray, into a tuple of NDArray
"""
if isinstance(inputs, list):
inputs = tuple(inputs)
if isinstance(inputs, ndarray.NDArray):
inputs = (inputs, )
if not isinstance(inputs, tuple):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
for item in inputs:
if not isinstance(item, ndarray.NDArray):
raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
return inputs
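[Editor's note: a sketch of the normalization this helper performs. It is local to while_loop, so the calls below are illustrative only; a and b stand for arbitrary NDArrays.]

# _to_ndarray_tuple(a, "loop_vars")        -> (a,)
# _to_ndarray_tuple([a, b], "loop_vars")   -> (a, b)
# _to_ndarray_tuple((a, b), "loop_vars")   -> (a, b)
# _to_ndarray_tuple("oops", "loop_vars")   -> ValueError: loop_vars must be an
#                                             NDArray, or a tuple or list of NDArrays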

def _func_wrapper(loop_vars):
"""This wrapper unifies
"func: loop_vars -> new_loop_vars"
and "func: loop_vars -> (step_output, new_loop_vars)"
into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)
"""
step_output, new_loop_vars = func(*loop_vars)
if step_output is None:
step_output = []
if new_loop_vars is None:
new_loop_vars = []
step_output = _to_ndarray_tuple(step_output, "step_output")
new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
if len(loop_vars) != len(new_loop_vars):
raise ValueError("The length of loop_vars should be consistent during the loop")
return step_output, new_loop_vars
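[Editor's note: as discussed in the review above, a body may produce no per-step outputs at all; a None step_output is normalized to an empty tuple, and the loop then returns ([], final_loop_vars), matching the TF-style usage mentioned earlier. A sketch with made-up values:]

import mxnet as mx

cond = lambda i: i < 5
func = lambda i: (None, [i + 1])     # no per-step outputs
outputs, states = mx.nd.contrib.while_loop(
    cond, func, [mx.nd.array([0])], max_iterations=5)
# outputs == []; states holds the final loop variable, i.e. [5.]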

max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
# This would arguably work even if loop_vars were empty,
# but supporting that case is semantically unnecessary.
if len(loop_vars) == 0:
raise ValueError("loop_vars should contain at least one element")

steps = 0
outputs = []
while steps < max_iterations and \
_to_python_scalar(cond(*loop_vars), bool, "Return value of cond"): # loop condition
Contributor:

So this could end before reaching max_iterations. Isn't this inconsistent with symbol?

Member Author:

Yes, they are not consistent, and I put a warning in the docstring. Should I do some padding stuff so that they look the same?

Contributor:

i think so.

Member Author (@junrushao, Jul 12, 2018):

@zheng-da So should I pad the arrays to make them consistent?

Contributor:

it's better to do so, in my opinion. what do you think? @piiswrong

Member Author:

Fixed

Contributor:

Yes, ndarray and symbol functions should give the same result for the same input. Otherwise hybridize may break.

step_output, loop_vars = _func_wrapper(loop_vars)
outputs.append(step_output)
steps += 1
if len(outputs) != steps or len(step_output) != len(outputs[0]):
raise ValueError("step_outputs are inconsistent across steps")
try:
outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
except ValueError:
raise ValueError("step_outputs are inconsistent across steps")
Contributor:

Be explicit about which value is inconsistent. Print out the inconsistent shapes if possible.

Member Author:

Fixed

return outputs, list(loop_vars)
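[Editor's note: re-running the docstring example shows the discrepancy the reviewers flag. This is a sketch of the behavior as of this revision, before commit 16e2823 adds padding.]

import mxnet as mx

cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
print(outputs[0].shape)   # (6, 1): rows for the six executed steps only,
                          # not max_iterations (10) as in the symbolic version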
202 changes: 202 additions & 0 deletions python/mxnet/symbol/contrib.py
@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
states = states[0]

return (outs, states)

def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
Contributor:

interface different from ndarray?

Member Author:

My bad. Fixed

"""Run a while loop with user-defined computation and loop condition.
Member:

Is max_iterations always required?

Contributor:

yes, without dynamic shape, a user has to provide max_iterations


This operator simulates a while loop which iteratively performs customized computation
as long as the condition is satisfied.

`loop_vars` is a list of Symbols that the computation uses.

`cond` is a user-defined function as the loop condition.
It consumes `loop_vars`, and produces a scalar MXNet symbol,
indicating the termination of the loop.
The loop ends when `cond` returns false (zero).
The `cond` is variadic, and its signature should be
`cond(*loop_vars) => Symbol`.

`func` is a user-defined function as the loop body.
It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
The number of elements, as well as the shape and dtype of each element in `step_output`, should be consistent across steps.
The `new_loop_vars` should be consistent with `loop_vars` on each step.
The `func` is variadic, and its signature should be
`func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.

`max_iterations` is a scalar that defines the maximum number of iterations allowed.

This function returns two lists.
Member:

I don't understand this sentence.

Member Author:

My bad, this is outdated. Does the following seem better?

"This function returns two lists as a tuple. The first list has the length of |step_output|, in which the i-th element is the stack of the i-th elements of step_output from all steps, along axis 0. The second list has the length of |loop_vars|, which represents the final states of the loop variables."

The first list has length `|step_output|`; its i-th element is the stack,
along axis 0, of the i-th element of `step_output` from all steps.
The second list has length `|loop_vars|` and contains the final state of each loop variable.

Parameters
----------
cond: a Python function.
The loop condition.
func: a Python function.
The loop body.
loop_vars: list of Symbol.
The initial values of the loop variables.
max_iterations: a Python int.
Maximum number of iterations.

Returns
-------
outputs: a tuple of two lists, each containing 0, 1, or more Symbols.
The first list contains the stacked outputs from all steps;
the second list contains the final states.

Examples
--------
>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
"""
def _to_python_scalar(inputs, type_, name):
"""Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types,
to the given type
"""
if hasattr(inputs, "asscalar"):
inputs = inputs.asscalar()
try:
inputs = type_(inputs)
except:
raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
return inputs

def _to_symbol_tuple(inputs, name):
"""Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol,
a tuple of mxnet Symbol, into a tuple of Symbol
"""
if isinstance(inputs, list):
inputs = tuple(inputs)
if isinstance(inputs, Symbol):
inputs = (inputs, )
if not isinstance(inputs, tuple):
raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
for item in inputs:
if not isinstance(item, Symbol):
raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, ))
return inputs

def _cond_wrapper(loop_vars):
result = cond(*loop_vars)
if not isinstance(result, Symbol):
raise ValueError("Return of cond must be a Symbol")
return [], [result]

def _func_wrapper(loop_vars):
"""This wrapper unifies
"func: loop_vars -> new_loop_vars"
and "func: loop_vars -> (step_output, new_loop_vars)"
into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)
"""
step_output, new_loop_vars = func(*loop_vars)
if step_output is None:
step_output = []
if new_loop_vars is None:
new_loop_vars = []
step_output = _to_symbol_tuple(step_output, "step_output")
new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
if len(loop_vars) != len(new_loop_vars):
raise ValueError("The number of loop_vars should be consistent during the loop")
return list(step_output), list(new_loop_vars)

def _create_subgraph(graph_vars, graph_func, subgraph_name):
with AttrScope(__subgraph_name__=subgraph_name):
Member:

Does it matter if the user doesn't provide a name and subgraph_name duplicates?

Contributor (@zheng-da, Jul 12, 2018):

probably not. the C code that cuts the subgraph looks for nodes with the attribute of __subgraph_name__

# create new variables with the same names,
# then feed them to the given func
new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
outputs, final_state = graph_func(new_graph_vars)
# first `num_out_data` elements belong to `outputs`
# other elements belong to `final_state`
num_out_data = len(outputs)
num_outputs = len(outputs) + len(final_state)
# the nnvm graph cut does not allow inputs and outputs to overlap,
# so we collect the input names and copy an output whenever it overlaps with an input
all_input_names = symbol.Group(outputs + final_state).list_inputs()
make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
graph = symbol.Group(list(map(make_identity, outputs + final_state)))
return graph, num_out_data, num_outputs
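[Editor's note: a sketch of the overlap case the identity copy guards against; the loop body here is made up.]

# Suppose the loop body passes a loop variable straight through:
#     cond = lambda i: i < 3
#     func = lambda i: ([i], [i + 1])   # the step output IS the input var `i`
# Without a copy, the subgraph's output for `[i]` would be one of its own
# input nodes, which the nnvm graph cut forbids; symbol.op.identity(i)
# creates a distinct output node that merely copies `i`.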

def _union_inputs(*graphs):
# Given a list of graphs, each of whose inputs is either a loop_var or another variable:
# 1) calculate a list `inputs`, the union of their inputs.
# 2) for each graph, determine at which indices its inputs reside in `inputs`.
# 3) for each loop variable, find its position among the graph's inputs.
inputs = [] # List[Symbol], result of 1)
locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples,
# where tuples are results of 2) and 3)
input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it
# to a `loc`, where inputs[loc] = sym
for graph in graphs:
# input_syms: all inputs to the `graph`
name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
# some loop_vars are inputs to `graph`, some are not
name_to_loop_vars = {sym.name: sym for sym in loop_vars}
# other inputs to `graph` created by cut_graph
name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
Contributor:

what is the difference between list_outputs()[0] and name?

Member Author:

I feel like they are equivalent. Just copied from this line in foreach.

# also we collect the mapping from var's name to var's loc in loop_vars
name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
# collect arguments for each subgraph
input_locs = [] # results from the second step
var_locs = [-1] * len(loop_vars) # results from the third step
for name in graph.list_inputs():
assert name in name_to_input_syms # it should obviously hold
# name -> sym
if name in name_to_loop_vars:
sym = name_to_loop_vars[name]
elif name in name_to_cut_g_syms:
sym = name_to_cut_g_syms[name]
else:
sym = copy.deepcopy(name_to_input_syms[name])
# do 2), and 1) is implicitly done
if id(sym) in input_id_to_loc:
Contributor:

why does id(sym) exist in input_id_to_loc before?

Member Author:

There are several subgraphs; more specifically, two subgraphs, cond and func, in while_loop. They may have common input symbols, so these symbols may already have been added to input_id_to_loc.

Member:

Why the id instead of checking for the symbol directly?

Member Author:

@szha mx.sym.Symbol.__eq__ has been overridden and returns an NDArray instead of bool. Thus directly using sym as keys of a dict won't work.

Member:

OK
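[Editor's note: a sketch of the pitfall being described, assuming MXNet 1.x semantics where Symbol.__eq__ builds an elementwise comparison op.]

import mxnet as mx

a = mx.sym.var('a')
print(a == a)   # a new Symbol (a comparison op), never a plain bool
# value-based lookups are therefore unreliable, so the code keys on id(sym)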

loc = input_id_to_loc[id(sym)]
else:
loc = len(input_id_to_loc)
inputs.append(sym)
input_id_to_loc[id(sym)] = loc
input_locs.append(loc)
# do 3)
if name in name_to_var_locs:
var_locs[name_to_var_locs[name]] = len(input_locs) - 1
locs.append((input_locs, var_locs))
return inputs, locs
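[Editor's note: a worked sketch of _union_inputs on a made-up case, to make the three steps concrete.]

# loop_vars = [i, s]; cond uses only i; func also captures an outer var w:
#   processing cond_g (inputs: "i"):
#     inputs = [i];        cond locs: input_locs=[0],       var_locs=[0, -1]
#   processing func_g (inputs: "i", "s", "w"):
#     inputs = [i, s, w];  func locs: input_locs=[0, 1, 2], var_locs=[0, 1]
# `i` is shared, so func_g reuses loc 0 found for cond_g; this is exactly why
# id(sym) can already be present in input_id_to_loc.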
max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
# This would arguably work even if loop_vars were empty,
# but supporting that case is semantically unnecessary.
if len(loop_vars) == 0:
raise ValueError("loop_vars should contain at least one element")
# create graph for `cond`
cond_g, num_out_data, num_outputs = \
_create_subgraph(loop_vars, _cond_wrapper, name + "_cond")
assert num_out_data == 0
assert num_outputs == 1
# create graph for `func`
func_g, num_out_data, num_outputs = \
_create_subgraph(loop_vars, _func_wrapper, name + "_func")
# find symbols used in either cond_g or func_g
input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
_union_inputs(cond_g, func_g)
for i_th, loc in enumerate(func_var_locs):
if loc == -1:
raise ValueError("The %d-th loop_var is not used in the computation" % i_th)
result = symbol._internal._while_loop(
# [cond, func_g, *input_syms]
cond_g,
func_g,
*input_syms,
max_iterations=max_iterations,
cond_input_locs=cond_input_locs,
func_input_locs=func_input_locs,
func_var_locs=func_var_locs,
num_out_data=num_out_data,
num_outputs=num_outputs
)
outputs = [result[i] for i in range(num_out_data)]
final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
return outputs, final_loop_vars
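[Editor's note: finally, a sketch of input that trips the unused-loop-var check above; the body is made up for illustration.]

import mxnet as mx

cond = lambda i, s: i < 3
func = lambda i, s: ([i], [i + 1, i])   # `s` is read by neither cond nor func
i, s = mx.sym.var('i'), mx.sym.var('s')
# raises ValueError: The 1-th loop_var is not used in the computation
mx.sym.contrib.while_loop(cond, func, [i, s], max_iterations=3)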