[MXNET-626] Add while_loop #11566
Changes from 10 commits
@@ -191,3 +191,128 @@ def check_input(inputs, in_type, msg):

```python
    if not_data_list and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)
```
```python
def while_loop(loop_vars, cond, func, max_iterations):
    """Run a while loop with user-defined computation and loop condition.

    This operator simulates a while loop which iteratively does customized
    computation as long as the condition is satisfied.

    `loop_vars` is a list of NDArrays used in the computation.

    `cond` is a user-defined function used as the loop condition.
    It consumes `loop_vars`, and produces a scalar MXNet NDArray
    indicating the termination of the loop.
    The loop ends when `cond` returns false (zero).
    `cond` is variadic, and its signature should be
    `cond(*loop_vars) => NDArray`.

    `func` is a user-defined function used as the loop body.
    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
    The number of elements, shape and dtype of each element in `step_output` should be consistent.
```
> **Reviewer:** what does consistent mean?
>
> **Author:** I change this to […] Does this seem better?
```python
    The `new_loop_vars` should be consistent with `loop_vars` on each step.
    `func` is variadic, and its signature should be
    `func(*loop_vars) => (List[NDArray] step_output, List[NDArray] new_loop_vars)`.
```
> **Reviewer:** […]
>
> **Author:** Fixed :-)
```python
    `max_iterations` is a scalar that defines the maximum number of iterations allowed.

    This function returns a list of NDArrays of length `|step_output| + |loop_vars|`.
```
> **Reviewer:** The return value has a different format from TF. Any specific reasons?
>
> **Author:** I updated the outdated docstring to the following: […]
>
> **Author:** Currently we don't have dynamic shape inference, so we could not support TF-like dynamic-sized, per-time-step […]
>
> **Reviewer:** the interface of this while_loop operator is close to the ONNX definition: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop
>
> **Reviewer:** So does it return two lists separately or concated together?
>
> **Author:** @piiswrong Separately
>
> **Reviewer:** "This function returns two lists as a tuple" -> "This function returns two lists." "as a tuple" makes it sound like the two lists are concated into a tuple.
>
> **Author:** @piiswrong fixed
```python
    The i-th element in the first `|step_output|` entries of the list represents
    the i-th `step_output` at all steps, stacked along axis 0.
    The i-th element in the last `|loop_vars|` entries of the list
    represents the final state of each loop variable.

    Warning 1: when `cond` is never satisfied, we assume `step_output` is empty.
    Warning 2: The output shape along axis 0 is currently `max_iterations`,
    which is not consistent with the symbolic version.
```
```python
    Parameters
    ----------
    loop_vars: list of NDArrays.
        The initial values of the loop variables.
    cond: a Python function.
        The loop condition.
    func: a Python function.
        The loop body.
    max_iterations: a python int.
        Maximum number of iterations.
```
```python
    Returns
    -------
    outputs: a tuple of two lists, both of which contain 0, 1 or more NDArrays.
        The first list contains the stacked outputs from each step.
        The second list contains the final state.

    Examples
    --------
    >>> cond = lambda i, s: i <= 5
    >>> func = lambda i, s: ([i + s], [i + 1, s + i])
    >>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
    >>> outputs, states = mx.nd.contrib.while_loop(loop_vars, cond, func, max_iterations=10)
```
> **Reviewer:** Can you show the output results of this example?
>
> **Author:** The results are […] Should I put this snippet into the docstring?
>
> **Reviewer:** i think so.
>
> **Author:** Fixed
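The example in this thread can be traced with a small MXNet-free sketch that emulates `while_loop`'s imperative semantics using plain Python ints in place of 1-element NDArrays (`while_loop_py` is an illustrative name, not part of the PR):

```python
# MXNet-free emulation of the imperative while_loop semantics: run `func`
# while `cond` holds (up to max_iterations), collect per-step outputs, then
# "stack" each output across steps by transposing the list of step outputs.
def while_loop_py(loop_vars, cond, func, max_iterations):
    outputs = []
    steps = 0
    while steps < max_iterations and cond(*loop_vars):
        step_output, loop_vars = func(*loop_vars)
        outputs.append(step_output)
        steps += 1
    # zip(*outputs) plays the role of stacking along axis 0
    stacked = [list(col) for col in zip(*outputs)]
    return stacked, list(loop_vars)

cond = lambda i, s: i <= 5
func = lambda i, s: ([i + s], [i + 1, s + i])
outputs, states = while_loop_py([0, 1], cond, func, max_iterations=10)
print(outputs)  # [[1, 2, 4, 7, 11, 16]]
print(states)   # [6, 16]
```

The loop runs six steps (i = 0..5) and then stops, so the stacked output has six entries even though `max_iterations` is 10 — exactly the imperative/symbolic mismatch discussed further below.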
```python
    """
    def _to_python_scalar(inputs, type_, name):
        """Converts "inputs", possibly a typed mxnet NDArray, a numpy ndarray,
        or another python type, to the given type
        """
        if isinstance(inputs, ndarray.NDArray):
            inputs = inputs.asscalar()
        try:
            inputs = type_(inputs)
        except:
            raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))
        return inputs
```
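A minimal, MXNet-free sketch of what `_to_python_scalar` does (`to_python_scalar` and `FakeScalarArray` are made-up illustrative names): anything exposing `.asscalar()` — such as a 1-element NDArray — is collapsed to a scalar first, then cast to the target type. A narrower `except (TypeError, ValueError)` stands in for the bare `except` above:

```python
class FakeScalarArray:
    """Stand-in for a 1-element NDArray."""
    def __init__(self, value):
        self.value = value

    def asscalar(self):
        return self.value

def to_python_scalar(inputs, type_, name):
    # collapse array-like inputs to a scalar first
    if hasattr(inputs, "asscalar"):
        inputs = inputs.asscalar()
    try:
        return type_(inputs)
    except (TypeError, ValueError):
        raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__))

print(to_python_scalar(FakeScalarArray(3.0), int, "max_iterations"))  # 3
print(to_python_scalar("7", int, "max_iterations"))                   # 7
```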
```python
    def _to_ndarray_tuple(inputs, name):
        """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArrays,
        or a tuple of mxnet NDArrays, into a tuple of NDArrays
        """
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(inputs, ndarray.NDArray):
            inputs = (inputs, )
        if not isinstance(inputs, tuple):
            raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
        for item in inputs:
            if not isinstance(item, ndarray.NDArray):
                raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, ))
        return inputs
```
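The same normalization, sketched without MXNet (`to_tuple` and `item_type` are illustrative names): a single item, a list, or a tuple all come out as a tuple, and every element is type-checked:

```python
def to_tuple(inputs, item_type, name):
    # lists become tuples; a lone item becomes a 1-tuple
    if isinstance(inputs, list):
        inputs = tuple(inputs)
    if isinstance(inputs, item_type):
        inputs = (inputs, )
    # anything else, or a tuple with a wrong-typed element, is rejected
    if not isinstance(inputs, tuple) or \
            not all(isinstance(item, item_type) for item in inputs):
        raise ValueError("%s must be an item, or a tuple or list of items" % (name, ))
    return inputs

print(to_tuple(1, int, "loop_vars"))       # (1,)
print(to_tuple([1, 2], int, "loop_vars"))  # (1, 2)
```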
```python
    def _func_wrapper(loop_vars):
        """This wrapper unifies
             "func: loop_vars -> new_loop_vars"
         and "func: loop_vars -> (step_output, new_loop_vars)"
        into "func: loop_vars -> (None or tuple of step_outputs, tuple of new_loop_vars)"
        """
        step_output, new_loop_vars = func(*loop_vars)
        if step_output is None:
            step_output = []
        if new_loop_vars is None:
            new_loop_vars = []
        step_output = _to_ndarray_tuple(step_output, "step_output")
        new_loop_vars = _to_ndarray_tuple(new_loop_vars, "new_loop_vars")
        if len(loop_vars) != len(new_loop_vars):
            raise ValueError("The length of loop_vars should be consistent during the loop")
        return step_output, new_loop_vars
```
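The wrapper's normalization can be sketched with plain ints in place of NDArrays (`func_wrapper` here is an illustrative stand-in, not the PR's code): a `None` step output or loop-var list becomes an empty tuple, and the number of loop variables must stay fixed across iterations:

```python
def func_wrapper(func, loop_vars):
    step_output, new_loop_vars = func(*loop_vars)
    # None means "no step output" / "no loop vars"
    step_output = tuple(step_output) if step_output is not None else ()
    new_loop_vars = tuple(new_loop_vars) if new_loop_vars is not None else ()
    # the loop-variable count must not change between iterations
    if len(loop_vars) != len(new_loop_vars):
        raise ValueError("The length of loop_vars should be consistent during the loop")
    return step_output, new_loop_vars

body = lambda i: (None, [i + 1])  # loop body that produces no step output
print(func_wrapper(body, (3,)))   # ((), (4,))
```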
```python
    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
    loop_vars = _to_ndarray_tuple(loop_vars, "loop_vars")
    # It should work fine if loop_vars were empty, I guess,
    # but it is semantically unnecessary to include this case.
    if len(loop_vars) == 0:
        raise ValueError("loop_vars should contain at least one element")

    steps = 0
    outputs = []
    while steps < max_iterations and \
            _to_python_scalar(cond(*loop_vars), bool, "Return value of cond"):  # loop condition
```
> **Reviewer:** So this could end before reaching max_iterations. Isn't this inconsistent with symbol?
>
> **Author:** Yes, they are not consistent, and I put a warning in the docstring. Should I do some padding stuff so that they look the same?
>
> **Reviewer:** i think so.
>
> **Author:** @zheng-da So should I pad the arrays to make them consistent?
>
> **Reviewer:** it's better to do so, in my opinion. what do you think? @piiswrong
>
> **Author:** Fixed
>
> **Reviewer:** Yes, ndarray and symbol functions should give the same result for the same input. Otherwise hybridize may break.
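The padding agreed on in this thread can be illustrated with a small list-based sketch (`pad_axis0` is a hypothetical helper, not the PR's actual fix): the stacked outputs are padded along axis 0 up to `max_iterations`, so the imperative result matches the symbolic one in shape:

```python
def pad_axis0(stacked, max_iterations, pad_value=0):
    # `stacked` is a plain list standing in for one output's axis 0;
    # extend it with pad_value until it spans max_iterations entries
    return stacked + [pad_value] * (max_iterations - len(stacked))

out = [1, 2, 4, 7, 11, 16]  # the loop ran only 6 of the allowed 10 steps
print(pad_axis0(out, 10))   # [1, 2, 4, 7, 11, 16, 0, 0, 0, 0]
```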
```python
        step_output, loop_vars = _func_wrapper(loop_vars)
        outputs.append(step_output)
        steps += 1
        if len(outputs) != steps or len(step_output) != len(outputs[0]):
            raise ValueError("step_outputs are inconsistent on each step")
    try:
        outputs = list(ndarray.op.stack(*item) for item in zip(*outputs))
    except ValueError:
        raise ValueError("step_outputs are inconsistent on each step")
```
> **Reviewer:** be explicit about which value is inconsistent. Print out the inconsistent shapes if possible.
>
> **Author:** Fixed
```python
    return outputs, list(loop_vars)
```
@@ -336,3 +336,205 @@ def check_data(inputs, in_type, msg):
```python
        states = states[0]

    return (outs, states)


def while_loop(cond, func, loop_vars, max_iterations, name="while_loop"):
```
> **Reviewer:** interface different from ndarray?
>
> **Author:** My bad. Fixed
```python
    """Run a while loop with user-defined computation and loop condition.
```
> **Reviewer:** Is max_iterations always required?
>
> **Author:** yes, without dynamic shape, a user has to provide max_iterations
```python
    This operator simulates a while loop which iteratively does customized
    computation as long as the condition is satisfied.

    `loop_vars` is a list of Symbols used in the computation.

    `cond` is a user-defined function used as the loop condition.
    It consumes `loop_vars`, and produces a scalar MXNet symbol
    indicating the termination of the loop.
    The loop ends when `cond` returns false (zero).
    `cond` is variadic, and its signature should be
    `cond(*loop_vars) => Symbol`.

    `func` is a user-defined function used as the loop body.
    It also consumes `loop_vars`, and produces `step_output` and `new_loop_vars` at each step.
    The number of elements, shape and dtype of each element in `step_output` should be consistent.
    The `new_loop_vars` should be consistent with `loop_vars` on each step.
    `func` is variadic, and its signature should be
    `func(*loop_vars) => (List[Symbol] step_output, List[Symbol] new_loop_vars)`.

    `max_iterations` is a scalar that defines the maximum number of iterations allowed.

    This function returns a list of Symbols of length `|step_output| + |loop_vars|`.
```
```python
    The i-th element in the first `|step_output|` entries of the list represents
    the i-th `step_output` at all steps, stacked along axis 0.
    The i-th element in the last `|loop_vars|` entries of the list
    represents the final state of each loop variable.
```

> **Reviewer:** I don't understand this sentence.
>
> **Author:** My bad, this is outdated. Does the following seem better? […]
Parameters | ||
---------- | ||
loop_vars: list of Symbol. | ||
The initial values of the loop variables. | ||
cond: a Python function. | ||
The loop condition. | ||
func: a Python function. | ||
The loop body. | ||
max_iteration: a python int. | ||
Maximum number of iterations. | ||
|
||
Returns | ||
------- | ||
outputs: a tuple of two lists, which both contains 0, 1 or more Symbols. | ||
The first list contains the stacked output from each step, | ||
The second list contains the final state. | ||
|
||
Examples | ||
-------- | ||
>>> cond = lambda i, s: i <= 5 | ||
>>> func = lambda i, s: ([i + s], [i + 1, s + i]) | ||
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s')) | ||
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10) | ||
""" | ||
def _to_python_scalar(inputs, type_, name): | ||
"""Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, | ||
to the given type | ||
""" | ||
if hasattr(inputs, "asscalar"): | ||
inputs = inputs.asscalar() | ||
try: | ||
inputs = type_(inputs) | ||
except: | ||
raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) | ||
return inputs | ||
```python
    def _to_symbol_tuple(inputs, name):
        """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbols,
        or a tuple of mxnet Symbols, into a tuple of Symbols
        """
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if isinstance(inputs, Symbol):
            inputs = (inputs, )
        if not isinstance(inputs, tuple):
            raise ValueError("%s must be a Symbol, or a tuple or list of Symbols" % (name, ))
        for item in inputs:
            if not isinstance(item, Symbol):
                raise ValueError("%s must be a Symbol, or a tuple or list of Symbols" % (name, ))
        return inputs
```
|
||
def _cond_wrapper(loop_vars): | ||
result = cond(*loop_vars) | ||
if not isinstance(result, Symbol): | ||
raise ValueError("Return of cond must be a Symbol") | ||
return [], [result] | ||
```python
    def _func_wrapper(loop_vars):
        """This wrapper unifies
             "func: loop_vars -> new_loop_vars"
         and "func: loop_vars -> (step_output, new_loop_vars)"
        into "func: loop_vars -> (list of step_outputs, tuple of new_loop_vars)"
        """
        step_output, new_loop_vars = func(*loop_vars)
        if step_output is None:
            step_output = []
        if new_loop_vars is None:
            new_loop_vars = []
        step_output = _to_symbol_tuple(step_output, "step_output")
        new_loop_vars = _to_symbol_tuple(new_loop_vars, "new_loop_vars")
        if len(loop_vars) != len(new_loop_vars):
            raise ValueError("The number of loop_vars should be consistent during the loop")
        return list(step_output), list(new_loop_vars)
```
```python
    def _create_subgraph(graph_vars, graph_func, subgraph_name):
        with AttrScope(__subgraph_name__=subgraph_name):
```
> **Reviewer:** Does it matter if the user doesn't provide a name and subgraph_name duplicates?
>
> **Author:** probably not. the C code that cuts the subgraph looks for nodes with the attribute of […]
```python
            # create new variables with the same name,
            # then feed them to the given func
            new_graph_vars = [symbol.var(sym.name) for sym in graph_vars]
            outputs, final_state = graph_func(new_graph_vars)
            # first `num_out_data` elements belong to `outputs`,
            # the other elements belong to `final_state`
            num_out_data = len(outputs)
            num_outputs = len(outputs) + len(final_state)
            # nnvm cut-graph does not allow inputs and outputs to overlap,
            # so we collect the names of the inputs, and copy an output
            # once it overlaps with the inputs
            all_input_names = symbol.Group(outputs + final_state).list_inputs()
            make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
            # group all outputs of graph_func
            graph = symbol.Group(list(map(make_identity, outputs + final_state)))
        return graph, num_out_data, num_outputs
```
```python
    def _union_inputs(*graphs):
        # Given a list of graphs, each of whose inputs are either from loop_vars or other variables:
        # 1) calculate a list `inputs`, the union of their inputs.
        # 2) for each graph, determine at which indices of `inputs` its inputs reside
        # 3) for each variable in the input of `graph`, find which index it is
        inputs = []             # List[Symbol], result of 1)
        locs = []               # List[Tuple(List[Int], List[Int])], a list of tuples,
                                # where tuples are results of 2) and 3)
        input_id_to_loc = {}    # Dict[int, int]: given id(sym), input_id_to_loc maps it
                                # to a `loc`, where inputs[loc] = sym
        for graph in graphs:
            # input_syms: all inputs to the `graph`
            name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)}
            # some loop_vars are inputs to `graph`, some are not
            name_to_loop_vars = {sym.name: sym for sym in loop_vars}
            # other inputs to `graph` created by cut_graph
            name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)}
```
> **Reviewer:** what is the difference between […]?
>
> **Author:** I feel like they are equivalent. Just copied from this line in foreach.
```python
            # also we collect the mapping from a var's name to its location in loop_vars
            name_to_var_locs = {sym.name: i for i, sym in enumerate(loop_vars)}
            # collect arguments for each subgraph
            input_locs = []                     # results from the second step
            var_locs = [-1] * len(loop_vars)    # results from the third step
            for name in graph.list_inputs():
                assert name in name_to_input_syms   # it should obviously hold
                # name -> sym
                if name in name_to_loop_vars:
                    sym = name_to_loop_vars[name]
                elif name in name_to_cut_g_syms:
                    sym = name_to_cut_g_syms[name]
                else:
                    sym = copy.deepcopy(name_to_input_syms[name])
                # do 2), and 1) is implicitly done
                if id(sym) in input_id_to_loc:
```
> **Reviewer:** why does […]
>
> **Author:** There are several subgraphs, more specifically, two subgraphs […]
>
> **Reviewer:** Why the id instead of checking for the symbol directly?
>
> **Author:** @szha […]
>
> **Reviewer:** OK
```python
                    loc = input_id_to_loc[id(sym)]
                else:
                    loc = len(input_id_to_loc)
                    inputs.append(sym)
                    input_id_to_loc[id(sym)] = loc
                input_locs.append(loc)
                # do 3)
                if name in name_to_var_locs:
                    var_locs[name_to_var_locs[name]] = len(input_locs) - 1
            locs.append((input_locs, var_locs))
        return inputs, locs
```
```python
    max_iterations = _to_python_scalar(max_iterations, int, "max_iterations")
    loop_vars = _to_symbol_tuple(loop_vars, "loop_vars")
    # It should work fine if loop_vars were empty, I guess,
    # but it is semantically unnecessary to include this case.
    if len(loop_vars) == 0:
        raise ValueError("loop_vars should contain at least one element")
    # create graph for `cond`
    cond_g, num_out_data, num_outputs = \
        _create_subgraph(loop_vars, _cond_wrapper, name + "_cond")
    assert num_out_data == 0
    assert num_outputs == 1
    # create graph for `func`
    func_g, num_out_data, num_outputs = \
        _create_subgraph(loop_vars, _func_wrapper, name + "_func")
    # find symbols used in either cond_g or func_g
    input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
        _union_inputs(cond_g, func_g)
    for i_th, loc in enumerate(func_var_locs):
        if loc == -1:
            raise ValueError("The %d-th loop_var is not involved in the computation" % i_th)
    result = symbol._internal._while_loop(
        # [cond_g, func_g, *input_syms]
        cond_g,
        func_g,
        *input_syms,
        max_iterations=max_iterations,
        cond_input_locs=cond_input_locs,
        func_input_locs=func_input_locs,
        func_var_locs=func_var_locs,
        num_out_data=num_out_data,
        num_outputs=num_outputs
    )
    outputs = [result[i] for i in range(num_out_data)]
    final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)]
    return outputs, final_loop_vars
```
> **Reviewer:** put loop_vars after func?
>
> **Author:** My bad. Fixed :-)