From 030fbc3c7baec4a0f0cce78a45aa1da00eea0b48 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 2 Jul 2018 15:26:51 -0700 Subject: [PATCH] [MXNET-432] Add Foreach (#11531) * Test input a graph. * Update foreach to execute the subgraph. * print inputs/outputs in foreach. * Remove print. * add test code for foreach. * exec foreach outside the engine. * Implements forward of foreach. * Add support for variable numbers of inputs and outputs. * Add a python wrapper for foreach. * Fix the order of inputs. * add test with lstm. * hide C version of foreach. * fix a bug temporarily. * Test free variables. * change for the new interface of InputGraph attribute. * Add attribute to the subgraph. * Handle free variables. * Get all input symbols of a subgraph. * Fix shape, dtype and storage inference. * reorganize the output of foreach. * Add a gluon RNN unroll with symbol foreach. * remove unnecessary print. * have imperative and symbolic foreach. * Fix an error after moving foreach. * Fix imperative foreach * Fix a minor problem. * Use CachedOp to execute subgraph. * update TODO. * make foreach op use FStatefulComputeEx. TODO we need to change stateful executor to handle subgraph. * Add backward. * Fix bugs. * enable backward test in lstm. * Fix a bug in foreach backward for free variables. * change for the new CachedOp. * Detect the backward computation. * Fix bugs in foreach. * fix tests. * update tests. * check state shape. * enable nested foreach. * remove print. * fix a bug in test. * handle infer storage type for backward. * address comments. * address comments. * move some common functions out. * address comments. * fix lint. * Fix lint. * add doc. * undo modification in imperative.h * add doc and remove example code. * fix lint. * fix lint. * Fix lint. * make nd.foreach and sym.foreach consistent. * fix compile error. * address comments. * update. * check for loop only works for dense arrays. * move control flow op out of nn/ * fix include. * add a test in gluon. * work for GPU. * small fix. * remove subgraph_name * create loop state for reuse in the future. * move code. * Revert "remove subgraph_name" This reverts commit 977f5624ad0b0dedb9dcb8629f975afc56bb1e1a. * cut graph. * rename new var nodes. * Fix tests. * Fix bugs caused by ctypes (#29) * Add save/load json in testcases for foreach (#30) * support subgraph in stateful executor. * Fix compilation. * fix a bug when a subgraph has variable nodes. * Fix a bug of getting symbols. * copy var nodes. * Fix getting op states. * fix lint error. * address comments. * fix lint error. * simplify the execution of subgraph in the main thread. * fix lint error. * avoid waiting for computation in each iteration. * reuse cached op for inference. * share memory across mini-batches. * reuse memory. reuse memory between iterations in inference. reuse memory between mini-batches in training. * add tests for multiple batches. * remove entry. * add benchmark for foreach. * benchmark large batch size. * Fix the benchmark for GPU. * address comments. * update shape/dtype/storage inference. * update contrib API docs. * support nested foreach. * use a single CachedOp for all iterations. * use large dim. * update benchmark. * update benchmark. * update benchmark. * update benchmark. * return symbol arrays correctly in MXSymbolCutSubgraph. * return symbol arrays in MXSymbolGetInputSymbols. * fix lint error. * use cachedop to infer storage in backward. * fix scala API. * update comments. * fix scala. * fix test. * fix attribute name. * move benchmark. 
* fix the mapping of operator inputs/outputs and subgraph inputs/outputs. * add tests for dtype/shape inference. * reorganize tests. * fix a bug of cutting NodeEntry. When two node entries refer to the same output of a node, we should create only one var node for these two node entries. * fix lint error. * handle the case that outputs are inputs. * handle the case that inputs aren't used. * handle the case without output data. * fix a bug in foreach backward. * fix a bug when there isn't output data. * Fix lint error. * test diff Gluon RNN cells. * test all symbol RNN cells. * adjust the test precision. * Fix a bug in getting a list of variable names. We can't get a list of variable names from a hashtable. The order can't be guaranteed. Python2 and Python3 output different orders. * fix lint error. * Test 1D array. * fix a bug when subgraph inputs and outputs share NDArray. * fix. * fix * add comments. --- benchmark/python/control_flow/rnn.py | 189 ++++++ docs/api/python/ndarray/contrib.md | 1 + docs/api/python/symbol/contrib.md | 1 + include/mxnet/c_api.h | 22 + include/mxnet/op_attr_types.h | 11 +- python/mxnet/ndarray/contrib.py | 96 +++ python/mxnet/symbol/contrib.py | 247 +++++++- .../apache/mxnet/utils/CToScalaUtils.scala | 2 +- src/c_api/c_api_symbolic.cc | 71 +++ src/executor/graph_executor.cc | 12 +- src/executor/graph_executor.h | 2 + src/imperative/cached_op.cc | 3 +- src/imperative/cached_op.h | 2 +- src/imperative/imperative_utils.h | 29 +- src/ndarray/ndarray.cc | 3 +- src/nnvm/graph_editor.cc | 108 ++++ src/operator/control_flow.cc | 545 ++++++++++++++++++ src/operator/subgraph_op_common.cc | 260 +++++++++ src/operator/subgraph_op_common.h | 99 ++++ tests/python/unittest/test_gluon_rnn.py | 61 +- tests/python/unittest/test_operator.py | 474 ++++++++++++++- 21 files changed, 2217 insertions(+), 21 deletions(-) create mode 100644 benchmark/python/control_flow/rnn.py create mode 100644 src/nnvm/graph_editor.cc create mode 100644 src/operator/control_flow.cc create mode 100644 src/operator/subgraph_op_common.cc create mode 100644 src/operator/subgraph_op_common.h diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py new file mode 100644 index 000000000000..5e41b7508b66 --- /dev/null +++ b/benchmark/python/control_flow/rnn.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import subprocess +import mxnet as mx +from mxnet import gluon +import time +import copy + +def get_gpus(): + """ + return a list of GPUs + """ + try: + re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) + except OSError: + return [] + return range(len([i for i in re.split('\n') if 'GPU' in i])) + +class TestRNNLayer(gluon.HybridBlock): + def __init__(self, cell, prefix=None, params=None): + super(TestRNNLayer, self).__init__(prefix=prefix, params=params) + self.cell = cell + + def hybrid_forward(self, F, inputs, states): + out, states = F.contrib.foreach(self.cell, inputs, states) + return out + +def benchmark_rnn(cell, rnn_data, states): + ctx = rnn_data.context + num_batches = 20 + + # Imperative + cell0 = copy.deepcopy(cell) + layer0 = TestRNNLayer(cell0) + layer0.initialize(ctx=ctx) + + # Hybridize the cell + cell1 = copy.deepcopy(cell) + cell1.hybridize() + layer1 = TestRNNLayer(cell1) + layer1.initialize(ctx=ctx) + + # Hybridize the whole layer + cell2 = copy.deepcopy(cell) + layer2 = TestRNNLayer(cell2) + layer2.initialize(ctx=ctx) + layer2.hybridize() + layer2(rnn_data, states) + + # Hybridize the cell with static memory allocation + cell3 = copy.deepcopy(cell) + cell3.hybridize(static_alloc=True) + layer3 = TestRNNLayer(cell3) + layer3.initialize(ctx=ctx) + + tic = time.time() + for i in range(num_batches): + res0 = layer0(rnn_data, states) + mx.nd.waitall() + print("Imperative inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res1 = layer1(rnn_data, states) + mx.nd.waitall() + print("Hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res3 = layer3(rnn_data, states) + mx.nd.waitall() + print("Static-hybrid-cell inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + res2 = layer2(rnn_data, states) + mx.nd.waitall() + print("Hybrid inference takes " + str(time.time() - tic)) + + layer2.export("foreach_rnn") + symnet = mx.symbol.load('foreach_rnn-symbol.json') + args1 = {} + params = layer2.collect_params() + for key in params.keys(): + args1[key] = params[key].data() + args1['data0'] = rnn_data + for i in range(len(states)): + args1['data' + str(i + 1)] = states[i] + exe = symnet.bind(ctx=ctx, args=args1) + tic = time.time() + for i in range(num_batches): + exe.forward(is_train=False) + mx.nd.waitall() + print("Symbol inference takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res0 = layer0(rnn_data, states) + res0.backward() + mx.nd.waitall() + print("Imperative training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res1 = layer1(rnn_data, states) + res1.backward() + mx.nd.waitall() + print("Hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res3 = layer3(rnn_data, states) + res3.backward() + mx.nd.waitall() + print("Static-hybrid-cell training takes " + str(time.time() - tic)) + + tic = time.time() + for i in range(num_batches): + with mx.autograd.record(): + res2 = layer2(rnn_data, states) + res2.backward() + mx.nd.waitall() + print("Hybrid training takes " + str(time.time() - tic)) + + # gradients for the backward of the foreach symbol + args_grad1 = {} + for key in args1.keys(): + args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) + exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) + tic = time.time() + for i in 
range(num_batches): + exe.forward(is_train=True) + exe.backward(res2) + mx.nd.waitall() + print("Symbol training takes " + str(time.time() - tic)) + print("") + +if __name__ == '__main__': + ndim = 512 + seq_len = 100 + batch_sizes = [1, 32] + cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'), + gluon.rnn.LSTMCell(ndim, prefix='rnn_')] + ctxs = [mx.cpu(0), mx.gpu(0)] + for cell in cells: + for ctx in ctxs: + for batch_size in batch_sizes: + if len(get_gpus()) == 0 and ctx == mx.gpu(0): + continue + + if isinstance(cell, gluon.rnn.GRUCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + elif isinstance(cell, gluon.rnn.LSTMCell): + rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), + ctx=mx.cpu(0)) + states = [] + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), + ctx=mx.cpu(0))) + if ctx == mx.gpu(0): + dev = "GPU" + else: + dev = "CPU" + print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, + batch_size)) + benchmark_rnn(cell, rnn_data, states) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index b017c601208e..36a2c151e859 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` fft ifft quantize + foreach ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index f2bb3f15deed..664716560506 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` fft ifft quantize + foreach ``` ## API Reference diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4dd858a51c4b..6c7626b917a4 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1051,6 +1051,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, */ MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char **name); + +/*! + * \brief Get the input symbols of the graph. + * \param sym The graph. + * \param inputs The input symbols of the graph. + * \param input_size the number of input symbols returned. + */ +MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs, + int *input_size); + +/*! + * \brief Cut a subgraph whose nodes are marked with a subgraph attribute. + * The input graph will be modified. A variable node will be created for each + * edge that connects to nodes outside the subgraph. The outside nodes that + * connect to the subgraph will be returned. + * \param sym The graph. + * \param inputs The nodes that connect to the subgraph. + * \param input_size The number of such nodes. + */ +MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs, + int *input_size); + /*! * \brief Get the detailed information about atomic symbol. * \param creator the AtomicSymbolCreator. diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index f4694efad297..2bb2462d4869 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -64,8 +64,10 @@ enum OpReqType { * \sa Resource */ struct OpContext { + /*! \brief whether there is a backward phase to compute gradients. 
*/ + bool need_grad; /*! \brief whether it is training phase */ - int is_train; + bool is_train; /*! \brief RunContext related resources */ RunContext run_ctx; /*! \brief the callback when operation completes, used by asynchronize ops */ @@ -98,7 +100,12 @@ enum class ExecType { * In current implementation, copy operator is specially handled by executor. * This flag is used for special case treatment and future extension of different copy ops. */ - kCrossDeviceCopy + kCrossDeviceCopy, + /*! + * \brief A subgraph execution should happen in the main thread, instead of + * in the execution engine. + */ + kSubgraphExec, }; /*! \brief the dispatch mode of the operator */ diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index cc66483f00b3..b1f065e9f822 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -21,6 +21,8 @@ import math from ..context import current_context from ..random import uniform +from ..base import _as_list +from . import ndarray try: from .gen_contrib import * except ImportError: @@ -95,3 +97,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): expected_count_sampled = expected_prob_sampled * num_sampled return sampled_classes, expected_count_true, expected_count_sampled # pylint: enable=line-too-long + +def foreach(body, data, init_states): + """Run a for loop with user-defined computation over NDArrays on dimension 0. + + This operator simulates a for loop; body defines the computation for one iteration + of the loop. It runs the computation in body on each slice from the input + NDArrays. + + body takes two arguments as input and outputs a tuple of two elements, + as illustrated below: + + out, states = body(data1, states) + + data1 can be either an NDArray or a list of NDArrays. If data is an NDArray, + data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same + size as data. states is a list of NDArrays and has the same size as init_states. + Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated + as the first output of foreach; states from the last execution of body + are the second output of foreach. + + The computation done by this operator is equivalent to the pseudo code below + when the input data is NDArray: + + states = init_states + outs = [] + for i in range(data.shape[0]): + s = data[i] + out, states = body(s, states) + outs.append(out) + outs = stack(*outs) + + + Parameters + ---------- + body : a Python function. + Defines the computation of an iteration. + data: an NDArray or a list of NDArrays. + The input data. + init_states: an NDArray or a list of NDArrays. + The initial values of the loop states. + + Returns + ------- + outputs: an NDArray or a list of NDArrays. + The output data concatenated from the output of all iterations. + states: a list of NDArrays. + The loop states in the last iteration. 
+ + Examples + -------- + >>> step = lambda data, states: (data + states[0], [states[0] * 2]) + >>> data = mx.nd.random.uniform(shape=(2, 10)) + >>> states = [mx.nd.random.uniform(shape=(10))] + >>> outs, states = mx.nd.contrib.foreach(step, data, states) + """ + + def check_input(inputs, in_type, msg): + is_NDArray_or_list = True + if isinstance(inputs, list): + for i in inputs: + if not isinstance(i, in_type): + is_NDArray_or_list = False + break + else: + is_NDArray_or_list = isinstance(inputs, in_type) + assert is_NDArray_or_list, msg + + check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays") + check_input(init_states, ndarray.NDArray, + "init_states should be an NDArray or a list of NDArrays") + + not_data_list = isinstance(data, ndarray.NDArray) + num_iters = data.shape[0] if not_data_list else data[0].shape[0] + states = init_states + outputs = [] + for i in range(num_iters): + if not_data_list: + eles = data[i] + else: + eles = [d[i] for d in data] + outs, states = body(eles, states) + outs = _as_list(outs) + outputs.append(outs) + outputs = zip(*outputs) + tmp_outputs = [] + for out in outputs: + tmp_outputs.append(ndarray.op.stack(*out)) + outputs = tmp_outputs + + if not_data_list and len(outputs) == 1: + outputs = outputs[0] + return (outputs, states) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 83e90e687327..28bb507dd13d 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -19,6 +19,9 @@ # pylint: disable=wildcard-import, unused-wildcard-import """Contrib Symbol API of MXNet.""" import math +import ctypes +import copy + from .random import uniform from .symbol import Symbol try: @@ -26,7 +29,12 @@ except ImportError: pass -__all__ = ["rand_zipfian"] +from . import symbol +from ..base import _LIB, check_call +from ..base import SymbolHandle, _as_list +from ..attribute import AttrScope + +__all__ = ["rand_zipfian", "foreach"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -91,3 +99,240 @@ def rand_zipfian(true_classes, num_sampled, range_max): expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range expected_count_sampled = expected_prob_sampled * num_sampled return sampled_classes, expected_count_true, expected_count_sampled + +def _get_graph_inputs(subg): + num_handles = ctypes.c_int(0) + handles = ctypes.POINTER(SymbolHandle)() + check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, ctypes.byref(handles), + ctypes.byref(num_handles))) + + syms = [] + for i in range(num_handles.value): + s = Symbol(ctypes.cast(handles[i], SymbolHandle)) + syms.append(s) + return syms + +def _cut_subgraph(subg): + num_handles = ctypes.c_int(0) + handles = ctypes.POINTER(SymbolHandle)() + check_call(_LIB.MXSymbolCutSubgraph(subg.handle, ctypes.byref(handles), + ctypes.byref(num_handles))) + + syms = [] + for i in range(num_handles.value): + s = Symbol(ctypes.cast(handles[i], SymbolHandle)) + syms.append(s) + return syms + +# This constructs a subgraph for the given output nodes. +# If an output node is one of the input nodes, we call identity to make sure +# that output nodes are different from input nodes, as the sketch below illustrates. 
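The identity trick implemented in _construct_subgraph below matters when the loop body passes an input through unchanged: without it, an output of the subgraph would be the very same symbol as one of its inputs (or as another output). A minimal sketch of the case being guarded against, using the API this patch adds (the variable names are illustrative only):

    import mxnet as mx

    # The body returns its state unchanged, so the returned state symbol
    # *is* the input state symbol. foreach wraps such an output in
    # identity() internally so that the subgraph's outputs and inputs
    # refer to distinct nodes (and hence different NDArrays at runtime).
    step = lambda data, states: (data + states[0], states)

    data = mx.sym.var('data')
    states = [mx.sym.var('state')]
    outs, out_states = mx.sym.contrib.foreach(step, data, states)
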
+def _construct_subgraph(sym_out, sym_states): + sym_out = _as_list(sym_out) + sym_states = _as_list(sym_states) + all_outputs = [] + all_outputs.extend(sym_out) + all_outputs.extend(sym_states) + g = symbol.Group(all_outputs) + + flat_out = [] + all_input_names = g.list_inputs() + output_names = [o.name for o in sym_out] + for o in sym_out: + if o.name in all_input_names: + flat_out.append(symbol.op.identity(o)) + else: + flat_out.append(o) + + for s in sym_states: + if s.name in all_input_names or s.name in output_names: + # There is a problem if the outputs are the same as the inputs + # or the first output. By calling identity, we can make sure that + # all symbols will refer to different NDArrays. + flat_out.append(symbol.op.identity(s)) + else: + flat_out.append(s) + return symbol.Group(flat_out) + +def foreach(body, data, init_states, name="foreach"): + """Run a for loop with user-defined computation over Symbols on dimension 0. + + This operator simulates a for loop; body defines the computation for one iteration + of the loop. It runs the computation in body on each slice from the input + NDArrays. + + body takes two arguments as input and outputs a tuple of two elements, + as illustrated below: + + out, states = body(data1, states) + + data1 can be either a symbol or a list of symbols. If data is a symbol, + data1 is a symbol. Otherwise, data1 is a list of symbols and has the same + size as data. states is a list of symbols and has the same size as init_states. + Similarly, out can be either a symbol or a list of symbols, which are concatenated + as the first output of foreach; states from the last execution of body + are the second output of foreach. + + foreach can also output only the output data or only the states. If a user only + wants states, the body function can return ([], states). Similarly, if a user only + wants output data, the body function can return (out, []). + + The computation done by this operator is equivalent to the pseudo code below + when the input data is NDArray: + + states = init_states + outs = [] + for i in range(data.shape[0]): + s = data[i] + out, states = body(s, states) + outs.append(out) + outs = stack(*outs) + + + Parameters + ---------- + body : a Python function. + Defines the computation of an iteration. + data: a symbol or a list of symbols. + The input data. + init_states: a symbol or a list of symbols. + The initial values of the loop states. + name: string. + The name of the operator. + + Returns + ------- + outputs: a Symbol or a list of Symbols. + The output data concatenated from the output of all iterations. + states: a list of Symbols. + The loop states in the last iteration. 
+ + Examples + -------- + >>> step = lambda data, states: (data + states[0], [states[0] * 2]) + >>> data = mx.sym.var('data') + >>> states = [mx.sym.var('state')] + >>> outs, states = mx.sym.contrib.foreach(step, data, states) + """ + + def check_data(inputs, in_type, msg): + is_NDArray_or_list = True + if isinstance(inputs, list): + for i in inputs: + if not isinstance(i, in_type): + is_NDArray_or_list = False + break + else: + is_NDArray_or_list = isinstance(inputs, in_type) + assert is_NDArray_or_list, msg + + check_data(data, symbol.Symbol, "data should be a symbol or a list of symbols") + check_data(init_states, symbol.Symbol, "init_states should be a symbol or a list of symbols") + not_state_list = isinstance(init_states, symbol.Symbol) + + # If the input Python function references symbols outside the function, + # we need to prune the computation graph constructed from + # the function. One way of doing it is to mark the nodes in the computation graph + # with AttrScope and prune the nodes without the special attribute. + with AttrScope(__subgraph_name__=name): + if isinstance(data, list): + in_eles = [symbol.var(sym.name) for sym in data] + else: + in_eles = symbol.var(data.name) + if isinstance(init_states, list): + states = [symbol.var(s.name) for s in init_states] + else: + states = symbol.var(init_states.name) + sym_out, sym_states = body(in_eles, states) + + check_data(sym_out, symbol.Symbol, + "the output should be a symbol or a list of symbols") + check_data(sym_states, symbol.Symbol, + "the output states should be a symbol or a list of symbols") + if isinstance(sym_states, list): + assert isinstance(init_states, list) and len(sym_states) == len(init_states), \ + "the number of output states (%d) should be the same as input states (%d)" \ + % (len(sym_states), len(init_states)) + num_out_data = len(sym_out) + num_states = len(sym_states) + num_outputs = num_out_data + num_states + g = _construct_subgraph(sym_out, sym_states) + + # _cut_subgraph rewrites the graph, so we collect the graph inputs afterwards. + cut_syms = _cut_subgraph(g) + input_syms = _get_graph_inputs(g) + + # Here we need to find out how the input symbols are ordered as well as + # where the loop states are located in the list of inputs. + + # This dict contains the symbols of the subgraph. + input_syms = {sym.name:sym for sym in input_syms} + gin_names = input_syms.keys() + # This array contains the symbols for the inputs of foreach. + # They are ordered according to the inputs of the subgraph. + init_states = _as_list(init_states) + state_names = [sym.name for sym in init_states] + data_syms = _as_list(data) + data_names = [sym.name for sym in data_syms] + cut_var_map = {sym.list_outputs()[0]:sym for sym in cut_syms} + cut_var_names = cut_var_map.keys() + + subg_input_names = g.list_inputs() + # ordered_ins contains input symbols in the following order: + # data_syms, state_syms, followed by cut_vars and vars in the closure. + ordered_ins = data_syms + # this defines the location of data_syms in the list of subgraph inputs + in_data_locs = [] + for dname in data_names: + # Every data array has to be used in the loop body. + if dname in subg_input_names: + in_data_locs.append(subg_input_names.index(dname)) + else: + raise AssertionError("the data arrays have to be used in the loop body") + + ordered_ins.extend(init_states) + # this defines the location of state_syms in the list of subgraph inputs. + in_state_locs = [] + for sname in state_names: + # Every state array has to be used in the loop body. 
+ if sname in subg_input_names: + in_state_locs.append(subg_input_names.index(sname)) + else: + raise AssertionError("the state arrays have to be used in the loop body") + + remain_locs = [] + for in_name in subg_input_names: + assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \ + % (in_name, str(gin_names)) + if in_name in cut_var_names: + ordered_ins.append(cut_var_map[in_name]) + remain_locs.append(subg_input_names.index(in_name)) + elif in_name not in data_names and in_name not in state_names: + # The remaining inputs are the variable nodes created inside the UDF. + # The subgraph can't have nodes shared with the main graph. As such, + # we need to make a copy of these variable nodes. + assert in_name in gin_names + ordered_ins.append(copy.deepcopy(input_syms[in_name])) + remain_locs.append(subg_input_names.index(in_name)) + + ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs, + num_out_data=num_out_data, in_state_locs=in_state_locs, + in_data_locs=in_data_locs, remain_locs=remain_locs) + if num_outputs - num_states > 1: + outs = [] + for i in range(num_outputs - num_states): + outs.append(ret[i]) + elif num_outputs - num_states == 1: + outs = ret[0] + else: + outs = [] + states = [] + for i in range(num_states): + states.append(ret[num_outputs - num_states + i]) + + if not_state_list: + # If there is only one input state, there should be only one output state. + assert len(states) == 1 + states = states[0] + + return (outs, states) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index 9d51ddcb674a..ca50a741012b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -33,7 +33,7 @@ private[mxnet] object CToScalaUtils { case "double" | "doubleorNone" => "Double" case "string" => "String" case "boolean" | "booleanorNone" => "Boolean" - case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any" + case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default, $argType") } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index e5e9b522890b..c27a59a67c6e 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -344,6 +344,77 @@ int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, API_END(); } +namespace mxnet { + +extern std::vector GetInputSymbols(const nnvm::Symbol &sym); +extern bool CutGraphInputs(const std::vector &input_entries, + bool skip_var, std::vector *orig_entries); + +} + +int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) { + API_BEGIN(); + nnvm::Symbol *s = static_cast(sym); + std::vector input_syms = mxnet::GetInputSymbols(*s); + *input_size = input_syms.size(); + + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + ret->ret_handles.clear(); + ret->ret_handles.reserve(*input_size); + for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); + *input_arr = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); + API_END_HANDLE_ERROR(); +} + +int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, + int *input_size) { + // Given a graph, we want to fetch the nodes that have been marked as part of + // a subgraph. 
+ API_BEGIN(); + nnvm::Symbol *s = static_cast(sym); + std::string subg_attr = "__subgraph_name__"; + auto out_node = s->outputs[0].node; + auto it = out_node->attrs.dict.find(subg_attr); + if (it != out_node->attrs.dict.end()) { + std::string subg_name = it->second; + std::vector input_entries; + DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries] + (nnvm::NodePtr n) { + // If the node itself isn't in the subgraph, we ignore it. + auto it = n->attrs.dict.find(subg_attr); + if (it == n->attrs.dict.end() || it->second != subg_name) + return; + + // We search for nodes whose node entries aren't in the subgraph. + for (size_t j = 0; j < n->inputs.size(); j++) { + auto in_node = n->inputs[j].node; + auto it = in_node->attrs.dict.find(subg_attr); + if (it == in_node->attrs.dict.end() || it->second != subg_name) + input_entries.push_back(&n->inputs[j]); + } + }); + + std::vector orig_entries; + CutGraphInputs(input_entries, false, &orig_entries); + std::vector input_syms(orig_entries.size()); + for (size_t i = 0; i < input_syms.size(); i++) { + input_syms[i] = new nnvm::Symbol(); + input_syms[i]->outputs.push_back(orig_entries[i]); + } + *input_size = input_syms.size(); + + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + ret->ret_handles.clear(); + ret->ret_handles.reserve(*input_size); + for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); + *input_symbols = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); + } else { + *input_size = 0; + } + + API_END_HANDLE_ERROR(); +} + int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { nnvm::Symbol *s = new nnvm::Symbol(); API_BEGIN(); diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 831b5f900237..7386de4d12e3 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -39,6 +39,7 @@ namespace exec { GraphExecutor::GraphExecutor() { log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false); + need_grad_ = false; } GraphExecutor::~GraphExecutor() { @@ -257,11 +258,11 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, nnvm::Graph g; g.outputs = symbol.outputs; - bool need_grad = false; + need_grad_ = false; for (OpReqType req : grad_req_types) { - if (req != kNullOp) need_grad = true; + if (req != kNullOp) need_grad_ = true; } - if (!need_grad) return g; + if (!need_grad_) return g; for (size_t i = 0; i < g.outputs.size(); ++i) { NodeEntry ngrad{nnvm::Node::Create(), 0, 0}; head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i])); @@ -1591,6 +1592,7 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; opnode.exec->op_ctx.is_train = is_train; + opnode.exec->op_ctx.need_grad = need_grad_; } // Push Ops @@ -1609,11 +1611,15 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { OpNode& opnode = op_nodes_[nid]; if (op_nodes_[nid].skip_exec_node) continue; opnode.exec->op_ctx.is_train = is_train; + opnode.exec->op_ctx.need_grad = need_grad_; if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) { CHECK_EQ(inode.inputs.size(), 1U); CHECK_EQ(opnode.exec->in_array.size(), 1U); CHECK_EQ(opnode.exec->out_array.size(), 1U); CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0])); + } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) { + // If the node contains a subgraph, we can't execute it in the engine. 
+ opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false); } else if (opnode.cached_opr != nullptr) { bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 24f98894912b..bfc415b4526a 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -213,6 +213,8 @@ class GraphExecutor : public Executor { // perform bulking and segmentation on a training graph void BulkTrainingOpSegs(size_t total_num_nodes); + // indicate whether there is a backward graph for gradients. + bool need_grad_; // internal graph nnvm::Graph graph_; // operator node diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 2181c5cab871..5e48c5a26f7b 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -818,7 +818,7 @@ OpStatePtr CachedOp::DynamicForward( return op_state; } -void CachedOp::Forward( +OpStatePtr CachedOp::Forward( const std::shared_ptr& op_ptr, const std::vector& inputs, const std::vector& outputs) { @@ -858,6 +858,7 @@ void CachedOp::Forward( std::move(attrs), inputs, outputs, op_state, &save_inputs(), &save_outputs()); } + return op_state; } diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 370ef02b5f25..4f4dfdcc14dd 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -92,7 +92,7 @@ class CachedOp { std::vector Gradient( const nnvm::NodePtr& node, const std::vector& ograds) const; - void Forward( + OpStatePtr Forward( const std::shared_ptr& op_ptr, const std::vector& inputs, const std::vector& outputs); diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index faff5f173fe1..2331d7be155c 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -373,6 +373,7 @@ inline void PushFCompute(const FCompute& fn, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; CHECK(exec_type == ExecType::kSync); std::vector inputs, outputs; @@ -393,7 +394,7 @@ inline void PushFCompute(const FCompute& fn, &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); // setup context - OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; bool is_gpu = ctx.dev_mask() == gpu::kDevMask; // pre-fcompute fallback, cast to default storage type CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); @@ -420,11 +421,12 @@ inline void PushFComputeEx(const FComputeEx& fn, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { - OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(outputs, req); #endif @@ -459,6 +461,7 @@ inline void PushOperator(const OpStatePtr& state, static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); @@ -470,17 +473,23 @@ inline void PushOperator(const OpStatePtr& state, if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) { const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { - OpContext opctx{is_train, rctx, on_complete, requested}; + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(outputs, req); #endif fcompute_ex(state, opctx, inputs, req, outputs); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) { + if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync + && rctx.get_stream()) { rctx.get_stream()->Wait(); } }; - if (exec_type == ExecType::kSync) { + // For operators with subgraphs, we need to invoke them in the main thread + // instead of the threaded engine. + if (exec_type == ExecType::kSubgraphExec) { + RunContext rctx{ctx, nullptr}; + run(rctx, engine::CallbackOnComplete()); + } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, ctx, read_vars, write_vars, FnProperty::kNormal, 0, @@ -497,7 +506,7 @@ inline void PushOperator(const OpStatePtr& state, << "for stateful operator " << op->name; const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { - OpContext opctx{is_train, rctx, on_complete, requested}; + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; std::vector input_blobs, output_blobs; // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays @@ -519,12 +528,16 @@ inline void PushOperator(const OpStatePtr& state, fcompute(state, opctx, input_blobs, tmp_req, output_blobs); // post-fcompute fallback, cast to original storage type, if necessary CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - if (is_gpu && exec_type == ExecType::kSync) { + if (is_gpu && exec_type == ExecType::kSync + && rctx.get_stream()) { rctx.get_stream()->Wait(); } }; - if (exec_type == ExecType::kSync) { + if (exec_type == ExecType::kSubgraphExec) { + RunContext rctx{ctx, nullptr}; + run(rctx, engine::CallbackOnComplete()); + } else if (exec_type == ExecType::kSync) { Engine::Get()->PushSync( [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e90fb6319d77..0b2beed3391b 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1109,7 +1109,8 @@ void CopyFromToImpl(const NDArray& from, const NDArray& to, const Context to_ctx = to.ctx(); bool is_train = Imperative::Get()->is_training(); - OpContext opctx{is_train, + OpContext opctx{Imperative::Get()->is_recording(), + is_train, rctx, engine::CallbackOnComplete(), 
requested}; diff --git a/src/nnvm/graph_editor.cc b/src/nnvm/graph_editor.cc new file mode 100644 index 000000000000..1dee3c14ee44 --- /dev/null +++ b/src/nnvm/graph_editor.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_editor.cc + * The functions in this file edit an NNVM graph. Potentially, + * these functions should be moved to NNVM in the future. + */ + +#include +#include +#include + +namespace nnvm { +NodePtr CreateVariableNode(const std::string& name); +} + +namespace mxnet { + +/* + * Given a computation graph, this function finds the input nodes of the graph + * and creates symbols for the input nodes. It returns the input symbols. + */ +std::vector GetInputSymbols(const nnvm::Symbol &sym) { + nnvm::Graph g; + std::vector input_syms; + g.outputs = sym.outputs; + const nnvm::IndexedGraph& idx = g.indexed_graph(); + // Go through all nodes and return the ones representing variables. + for (size_t i = 0; i < idx.num_nodes(); i++) { + const nnvm::Node &n = *idx[i].source; + for (const nnvm::NodeEntry &e : n.inputs) { + auto p = e.node; + if (p->is_variable()) { + nnvm::Symbol *s = new nnvm::Symbol(); + s->outputs.push_back(e); + input_syms.push_back(s); + } + } + } + return input_syms; +} + +/* + * Given a computation graph and a set of input node entries, this function cuts + * the node entries and creates new variable nodes as the input nodes of the + * subgraph. The original node entries that connect to the subgraph are + * returned through orig_entries. + */ +bool CutGraphInputs(const std::vector &input_entries, + bool skip_var, std::vector *orig_entries) { + struct pred_entry { + nnvm::NodeEntry e; + explicit pred_entry(const nnvm::NodeEntry &_e): e(_e) {} + bool operator()(const nnvm::NodeEntry e1) { + return e.node == e1.node && e.index == e1.index; + } + }; + + std::vector var_nodes; + orig_entries->clear(); + orig_entries->reserve(input_entries.size()); + for (size_t i = 0; i < input_entries.size(); i++) { + nnvm::NodeEntry *e = input_entries[i]; + // If the node is a variable itself, we may want to skip the node. + if (e->node->is_variable() && skip_var) + continue; + + auto it = std::find_if(orig_entries->begin(), orig_entries->end(), + pred_entry(*e)); + bool exist = (it != orig_entries->end()); + orig_entries->push_back(*e); + nnvm::NodePtr n; + // If we haven't seen the entry before, we need to create a new var node + // for the node entry. + if (!exist) { + nnvm::Symbol sym; + sym.outputs.push_back(*e); + n = nnvm::CreateVariableNode(sym.ListOutputNames()[0]); + } else { + // Otherwise, we use the var node created before. 
+ size_t idx = it - orig_entries->begin(); + CHECK_LT(idx, var_nodes.size()); + n = var_nodes[idx]; + } + var_nodes.push_back(n); + *e = nnvm::NodeEntry{n, 0, 0}; + } + return true; +} + +} // namespace mxnet diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc new file mode 100644 index 000000000000..c091fdb67e0f --- /dev/null +++ b/src/operator/control_flow.cc @@ -0,0 +1,545 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./elemwise_op_common.h" +#include "../imperative/imperative_utils.h" +#include "./subgraph_op_common.h" + +namespace mxnet { +namespace op { + +struct ForeachParam : public dmlc::Parameter { + int num_args; + int num_outputs; + int num_out_data; + // The location of states in the subgraph inputs. + nnvm::Tuple in_state_locs; + // The location of data arrays in the subgraph inputs. + nnvm::Tuple in_data_locs; + // The location of remaining arrays in the subgraph inputs. 
nnvm::Tuple remain_locs; + DMLC_DECLARE_PARAMETER(ForeachParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs."); + DMLC_DECLARE_FIELD(num_outputs) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(num_out_data) + .describe("The number of output data of the subgraph."); + DMLC_DECLARE_FIELD(in_state_locs) + .describe("The locations of loop states among the inputs."); + DMLC_DECLARE_FIELD(in_data_locs) + .describe("The locations of input data among the inputs."); + DMLC_DECLARE_FIELD(remain_locs) + .describe("The locations of remaining data among the inputs."); + } +}; // struct ForeachParam + +DMLC_REGISTER_PARAMETER(ForeachParam); + +class ForeachState: public LoopState { + public: + ForeachParam params; + int num_iterations; + + ForeachState(const Symbol &g, const ForeachParam &params) : LoopState(g) { + this->params = params; + } +}; + +static void ForeachComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ForeachState &state = state_ptr.get_state(); + const ForeachParam& params = state.params; + const size_t iter_dim = 0; + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_GT(params.in_data_locs.ndim(), 0); + size_t len = inputs[0].shape()[iter_dim]; + state.num_iterations = len; + for (size_t i = 1; i < params.in_data_locs.ndim(); i++) + CHECK_EQ(inputs[i].shape()[iter_dim], len); + for (size_t i = 0; i < (size_t) params.num_out_data; i++) + CHECK_EQ(len, outputs[i].shape()[iter_dim]); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) + << "The foreach operator doesn't support the sparse format"; + + // Initializing the outputs of the subgraph is a little tricky. + // The states from the previous iteration are used as the inputs of the next + // iteration, so we have to maintain two sets of output arrays so that the + // inputs and outputs of the subgraph can share memory. + std::vector subg_outputs1(outputs.size()); + std::vector subg_outputs2(outputs.size()); + std::vector *subg_outputs[2]{&subg_outputs1, &subg_outputs2}; + // If the length is an odd number, the last iteration will use the first set + // of outputs. In this way, we don't need to copy the results from the + // subgraph to the final outputs of the loop. + if (len % 2 == 1) { + for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) { + subg_outputs1[i] = outputs[i]; + subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, + outputs[i].dtype()); + } + } else { + // Otherwise, we'll use the second set of outputs. + for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) { + subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true, + outputs[i].dtype()); + subg_outputs2[i] = outputs[i]; + } + } + + // Initialize the inputs for the subgraph. + // In each iteration, we need to update the subgraph inputs for input data + // and the loop states. + std::vector subg_inputs(inputs.size()); + // The remaining arrays (other than input data and states) only need to be set once. + for (size_t j = 0; j < params.remain_locs.ndim(); j++) { + CHECK_LT(params.remain_locs[j], subg_inputs.size()); + subg_inputs[params.remain_locs[j]] = inputs[j + params.in_data_locs.ndim() + + params.in_state_locs.ndim()]; + } + + // Here we iterate over the first dimension of the first input array. + for (size_t i = 0; i < len; i++) { + // Initialize outputs for the subgraph. 
+ std::vector *subg_out_curr = subg_outputs[i % 2]; + std::vector *subg_out_prev = subg_outputs[(i + 1) % 2]; + for (int j = 0; j < params.num_out_data; j++) + (*subg_out_curr)[j] = outputs[j].At(i); + // When recording for backward computation, we should make sure + // that output arrays are actually different in each iteration. + if (ctx.need_grad && i < len - 1) { + for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++) + (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(), + true, outputs[j].dtype()); + } else if (ctx.need_grad && i == len - 1) { + // For the last iteration, we need to write data to the output array + // directly. + for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++) + (*subg_out_curr)[j] = outputs[j]; + } + + // Initialize inputs for the subgraph. + // Get a slice from the input data arrays. + for (size_t j = 0; j < params.in_data_locs.ndim(); j++) { + size_t loc = params.in_data_locs[j]; + subg_inputs[loc] = inputs[j].At(i); + } + // For the rest of the iterations, the states are the outputs + // from the previous iteration. + if (i > 0) { + for (size_t j = params.num_out_data; j < subg_out_prev->size(); j++) { + size_t idx = j - params.num_out_data; + CHECK_LT(params.in_state_locs[idx], subg_inputs.size()); + subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j]; + } + } else { + for (size_t j = 0; j < params.in_state_locs.ndim(); j++) { + CHECK_LT(params.in_state_locs[j], subg_inputs.size()); + subg_inputs[params.in_state_locs[j]] = inputs[j + params.in_data_locs.ndim()]; + } + } + + state.Forward(i, subg_inputs, req, *subg_out_curr, ctx.need_grad); + } +} + +static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + ForeachState &state = state_ptr.get_state(); + const ForeachParam& params = state.params; + CHECK_EQ(outputs.size(), (size_t) params.num_args - 1); + CHECK_GT(params.in_data_locs.ndim(), 0); + for (const auto &arr : outputs) + CHECK_EQ(arr.storage_type(), kDefaultStorage) + << "The foreach operator doesn't support the sparse format"; + int len = state.num_iterations; + size_t num_output_data = params.num_out_data; + + // In backward computation, we need to run the iterations backwards. + std::vector subg_ograds(params.num_outputs); + std::vector subg_igrads(outputs.size()); + for (size_t i = num_output_data; i < subg_ograds.size(); i++) + subg_ograds[i] = inputs[i]; + std::vector subg_req(req.size()); + for (auto r : req) + CHECK_NE(r, kWriteInplace); + + // There are three types of arrays in igrads. + // * data gradients. + // * loop variable gradients. + // * remaining variable gradients. + // They are in the following order: + // [data vars], [loop vars], [remaining vars] + + // [remaining vars] + for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + size_t loc = params.remain_locs[i]; + size_t orig_loc = i + params.in_data_locs.ndim() + params.in_state_locs.ndim(); + subg_igrads[loc] = outputs[orig_loc]; + subg_req[loc] = req[orig_loc]; + } + + for (int iter_num = len - 1; iter_num >= 0; iter_num--) { + for (int i = 0; i < params.num_out_data; i++) + subg_ograds[i] = inputs[i].At(iter_num); + if (iter_num < len - 1) { + // For the rest of the iterations, we should add gradients to the + // remaining vars. 
+ for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + size_t loc = params.remain_locs[i]; + subg_req[loc] = kAddTo; + } + } + + // [data vars] + for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + subg_igrads[loc] = outputs[i].At(iter_num); + subg_req[loc] = req[i]; + } + // [loop vars] + for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + size_t loc = params.in_state_locs[i]; + const NDArray &output = outputs[i + params.in_data_locs.ndim()]; + if (iter_num != 0) { + // For state gradients, we need to allocate new NDArrays + // because intermediate state gradients won't be returned to the users. + subg_igrads[loc] = NDArray(output.shape(), output.ctx(), true, output.dtype()); + } else { + subg_igrads[loc] = output; + } + // For the first iteration, we need to use the request provided by + // the user to write state gradients to the outputs. + subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()]; + } + + state.Backward(iter_num, subg_ograds, subg_req, subg_igrads); + + size_t num_states = subg_ograds.size() - num_output_data; + for (size_t i = 0; i < num_states; i++) { + size_t loc = params.in_state_locs[i]; + CHECK_LT(loc, subg_igrads.size()); + subg_ograds[i + num_output_data] = subg_igrads[loc]; + } + } + state.Cleanup(); +} + +template +static void remap(const std::vector &op_in, size_t start, + const nnvm::Tuple &locs, std::vector *subg_in) { + auto op_in_it = op_in.begin() + start; + for (size_t i = 0; i < locs.ndim(); i++) { + dim_t loc = locs[i]; + subg_in->at(loc) = *(op_in_it + i); + } +} + +static inline TShape SliceFirstDim(const TShape &s) { + if (s.ndim() > 1) { + return TShape(s.begin() + 1, s.end()); + } else { + return TShape(mshadow::Shape1(1)); + } +} + +static bool ForeachShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 1U); + + std::vector subg_in_shape(in_shape->size()); + // data shape + std::vector data_1d(params.in_data_locs.ndim(), false); + for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + if (in_shape->at(i).ndim() == 1) + data_1d[i] = true; + subg_in_shape[loc] = SliceFirstDim(in_shape->at(i)); + } + // state shape + remap(*in_shape, params.in_data_locs.ndim(), params.in_state_locs, + &subg_in_shape); + // remaining shape + remap(*in_shape, params.in_data_locs.ndim() + params.in_state_locs.ndim(), + params.remain_locs, &subg_in_shape); + + std::vector subg_out_shape = *out_shape; + for (int i = 0; i < params.num_out_data; i++) { + TShape shape = subg_out_shape[i]; + // If we don't have shape info, we don't need to do anything. + if (shape.ndim() == 0) + continue; + subg_out_shape[i] = SliceFirstDim(shape); + } + + bool infer_success = InferSubgraphShape(*attrs.subgraphs[0], + &subg_in_shape, &subg_out_shape); + + // After inference, we need to move inferred information back to in_shape and + // out_shape. + + // For the shape of output data. + size_t len = in_shape->at(0)[0]; + CHECK_GT(len, 0); + for (int i = 0; i < params.num_out_data; i++) { + // If the output shape isn't inferred, we don't need to propagate the info. 
+ const auto& g_out_shape = subg_out_shape[i]; + if (g_out_shape.ndim() == 0) + continue; + + auto out = TShape(g_out_shape.ndim() + 1); + out[0] = len; + for (size_t j = 1; j < out.ndim(); j++) + out[j] = g_out_shape[j - 1]; + SHAPE_ASSIGN_CHECK(*out_shape, i, out); + } + // For the shape of output states. + for (size_t i = params.num_out_data; i < subg_out_shape.size(); i++) + SHAPE_ASSIGN_CHECK(*out_shape, i, subg_out_shape[i]); + + // For the shape of input data. + for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + const auto &shape = subg_in_shape[loc]; + // If the input data shape isn't inferred, we don't need to propagate the + // info. + if (shape.ndim() == 0) + continue; + + if (data_1d[i]) { + TShape s(1); + s[0] = len; + SHAPE_ASSIGN_CHECK(*in_shape, i, s); + } else { + auto in = TShape(shape.ndim() + 1); + in[0] = len; + for (size_t j = 1; j < in.ndim(); j++) + in[j] = shape[j - 1]; + SHAPE_ASSIGN_CHECK(*in_shape, i, in); + } + } + // For the shape of state. + for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + size_t loc = params.in_state_locs[i]; + SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim(), + subg_in_shape[loc]); + } + // For the shape of remaining data. + for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + size_t loc = params.remain_locs[i]; + SHAPE_ASSIGN_CHECK(*in_shape, + i + params.in_data_locs.ndim() + params.in_state_locs.ndim(), + subg_in_shape[loc]); + } + + if (infer_success) { + size_t num_states = out_shape->size() - params.num_out_data; + for (size_t i = 0; i < num_states; i++) { + CHECK_EQ((*out_shape)[i + params.num_out_data], + (*in_shape)[i + params.in_data_locs.ndim()]); + } + } + // Check if we have inferred the shapes correctly. + return infer_success; +} + +static bool ForeachType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, std::vector *out_type) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 1U); + std::vector subg_in_type(in_type->size(), 0); + remap(*in_type, 0, params.in_data_locs, &subg_in_type); + remap(*in_type, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_type); + remap(*in_type, params.in_data_locs.ndim() + params.in_state_locs.ndim(), + params.remain_locs, &subg_in_type); + bool success = InferSubgraphDataType(*attrs.subgraphs[0], &subg_in_type, out_type); + for (size_t i = 0; i < params.in_data_locs.ndim(); i++) { + size_t loc = params.in_data_locs[i]; + TYPE_ASSIGN_CHECK(*in_type, i, subg_in_type[loc]); + } + for (size_t i = 0; i < params.in_state_locs.ndim(); i++) { + size_t loc = params.in_state_locs[i]; + TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim(), subg_in_type[loc]); + } + for (size_t i = 0; i < params.remain_locs.ndim(); i++) { + size_t loc = params.remain_locs[i]; + TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(), + subg_in_type[loc]); + } + return success; +} + +static bool ForeachStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const ForeachParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 1U); + std::vector subg_in_attrs(in_attrs->size(), kUndefinedStorage); + remap(*in_attrs, 0, params.in_data_locs, &subg_in_attrs); + remap(*in_attrs, params.in_data_locs.ndim(), params.in_state_locs, 
+        &subg_in_attrs);
+  remap(*in_attrs, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_in_attrs);
+  bool success = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask,
+                                      dispatch_mode, &subg_in_attrs, out_attrs);
+  for (size_t i = 0; i < params.in_data_locs.ndim(); i++) {
+    size_t loc = params.in_data_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, subg_in_attrs[loc]);
+  }
+  for (size_t i = 0; i < params.in_state_locs.ndim(); i++) {
+    size_t loc = params.in_state_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim(),
+                              subg_in_attrs[loc]);
+  }
+  for (size_t i = 0; i < params.remain_locs.ndim(); i++) {
+    size_t loc = params.remain_locs[i];
+    STORAGE_TYPE_ASSIGN_CHECK(*in_attrs,
+                              i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+                              subg_in_attrs[loc]);
+  }
+  return success;
+}
+
+static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1);
+  CHECK_EQ(in_attrs->size(), (size_t) params.num_args - 1 + params.num_outputs * 2);
+  CHECK_EQ(attrs.subgraphs.size(), 1U);
+  CachedOp op(*attrs.subgraphs[0],
+              std::vector<std::pair<std::string, std::string> >());
+  // map the operator inputs to the subgraph inputs.
+  std::vector<int> subg_forward_ins(params.num_args - 1, kUndefinedStorage);
+  remap(*in_attrs, params.num_outputs, params.in_data_locs, &subg_forward_ins);
+  remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim(),
+        params.in_state_locs, &subg_forward_ins);
+  remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
+        params.remain_locs, &subg_forward_ins);
+
+  // Copy backward input storage to backward subgraph input storage.
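+  // The backward inputs are laid out as produced by ElemwiseGradUseInOut
+  // (see ForeachGradient below): [output gradients, forward inputs,
+  // forward outputs]. The loop below overwrites the forward-input slots with
+  // the storage types remapped into subgraph order, so that
+  // CachedOp::BackwardStorageType sees them where the subgraph expects them.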
+  std::vector<int> subg_in_attrs = *in_attrs;
+  for (size_t i = 0; i < subg_forward_ins.size(); i++)
+    subg_in_attrs[i + params.num_outputs] = subg_forward_ins[i];
+  return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
+                                &subg_in_attrs, out_attrs);
+}
+
+static OpStatePtr CreateForeachState(const NodeAttrs& attrs,
+                                     Context ctx,
+                                     const std::vector<TShape>& ishape,
+                                     const std::vector<int>& itype) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return OpStatePtr::Create<ForeachState>(*attrs.subgraphs[0], params);
+}
+
+static std::vector<nnvm::NodeEntry>
+ForeachGradient(const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+  ElemwiseGradUseInOut fgrad{"_backward_foreach"};
+  std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
+  entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
+  return entries;
+}
+
+NNVM_REGISTER_OP(_foreach)
+.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<FInferStorageType>("FInferStorageType", ForeachStorageType)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_args;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  std::vector<std::string> names;
+  names.push_back("fn");
+  for (int i = 0; i < params.num_args - 1; i++)
+    names.push_back("data" + std::to_string(i));
+  return names;
+})
+.set_attr<FInputGraph>("FInputGraph",
+    [](const NodeAttrs& attrs) {
+  return std::vector<uint32_t>{0};
+})
+.set_attr<nnvm::FGradient>("FGradient", ForeachGradient)
+.set_attr<FCreateOpState>("FCreateOpState", CreateForeachState)
+.set_attr<nnvm::FInferShape>("FInferShape", ForeachShape)
+.set_attr<nnvm::FInferType>("FInferType", ForeachType)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
+// Foreach operator works like an executor. Its code will always run on CPU.
+// So the same code can be registered for both CPU and GPU.
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachComputeExCPU)
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("fn", "Symbol", "Input graph.")
+.add_argument("data", "NDArray-or-Symbol[]",
+              "The input arrays that include data arrays and states.")
+.add_arguments(ForeachParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_foreach)
+.set_num_inputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_outputs * 2 + params.num_args - 1;
+  })
+.set_num_outputs([](const NodeAttrs& attrs){
+  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
+  return params.num_args - 1;
+  })
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
+.set_attr<FInferStorageType>("FInferStorageType", BackwardForeachStorageType)
+.set_attr_parser(ParamParser<ForeachParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc
new file mode 100644
index 000000000000..71a9a21c28c4
--- /dev/null
+++ b/src/operator/subgraph_op_common.cc
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./subgraph_op_common.h"
+#include "./operator_common.h"
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph,
+                           std::vector<int> *in_types,
+                           std::vector<int> *out_types) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
+  CHECK_EQ(idx_g.outputs().size(), out_types->size());
+
+  // Put the input and output data types to the dtype vector.
+  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_types->size());
+  for (size_t i = 0; i < in_types->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    types[eid] = in_types->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_types->size());
+  for (size_t i = 0; i < out_types->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    types[eid] = out_types->at(i);
+  }
+
+  // Infer data type of the graph.
+  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
+  g = exec::InferType(std::move(g));
+
+  const auto& types1 = g.GetAttr<nnvm::DTypeVector>("dtype");
+  // assign to in_types
+  for (size_t i = 0; i < in_types->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    TYPE_ASSIGN_CHECK(*in_types, i, types1[eid]);
+  }
+  // assign to out_types
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    TYPE_ASSIGN_CHECK(*out_types, i, types1[eid]);
+  }
+  // Check if we have inferred the dtypes correctly.
+  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
+}
+
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_stypes,
+                          std::vector<int> *out_stypes) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx_g = g.indexed_graph();
+  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size());
+  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
+  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
+
+  // Put the input and output storages to the storage vector.
+  nnvm::StorageVector stypes(idx_g.num_node_entries(), exec::kBadStorageID);
+  const auto &input_nids = idx_g.input_nodes();
+  CHECK_EQ(input_nids.size(), in_stypes->size());
+  for (size_t i = 0; i < in_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(input_nids[i], 0);
+    stypes[eid] = in_stypes->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_stypes->size());
+  for (size_t i = 0; i < out_stypes->size(); i++) {
+    auto eid = idx_g.entry_id(g.outputs[i]);
+    stypes[eid] = out_stypes->at(i);
+  }
+
+  // Infer storage type of the graph.
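+  // The graph may already carry a device mask from a previous call; only
+  // replace the attribute when the device actually changed, so repeated
+  // inference on the same device reuses what is already attached.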
+  bool dev_match = g.attrs.count("dev_mask") &&
+                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
+  if (!dev_match) {
+    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
+  }
+  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
+  g = exec::InferStorageType(std::move(g));
+
+  const auto& stypes1 = g.GetAttr<StorageTypeVector>("storage_type");
+  // assign to in_stypes
+  for (size_t i = 0; i < in_stypes->size(); ++i) {
+    const auto eid = idx_g.entry_id(input_nids[i], 0);
+    STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, stypes1[eid]);
+  }
+
+  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+  // assign to out_stypes
+  for (size_t i = 0; i < g.outputs.size(); ++i) {
+    const auto eid = idx_g.entry_id(g.outputs[i]);
+    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes1[eid]);
+  }
+  // Check if we have inferred the storages correctly.
+  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
+}
+
+bool InferSubgraphShape(const nnvm::Symbol &subgraph,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) {
+  nnvm::Graph g;
+  g.outputs = subgraph.outputs;
+  const auto& idx = g.indexed_graph();
+  CHECK_EQ(idx.input_nodes().size(), in_shape->size());
+  CHECK_EQ(idx.outputs().size(), out_shape->size());
+
+  // Put the input and output shapes to the shape vector.
+  nnvm::ShapeVector shapes(idx.num_node_entries());
+  const auto &input_nids = idx.input_nodes();
+  CHECK_EQ(input_nids.size(), in_shape->size());
+  for (size_t i = 0; i < in_shape->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    shapes[eid] = in_shape->at(i);
+  }
+  CHECK_EQ(g.outputs.size(), out_shape->size());
+  for (size_t i = 0; i < out_shape->size(); i++) {
+    auto eid = idx.entry_id(g.outputs[i]);
+    shapes[eid] = out_shape->at(i);
+  }
+
+  // Infer shape of the graph.
+  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
+  g = exec::InferShape(std::move(g));
+
+  const auto& shapes1 = g.GetAttr<nnvm::ShapeVector>("shape");
+  // Inferring the shape in the subgraph may infer the shape of the inputs.
+  // We need to copy the inferred input shapes back.
+  CHECK_EQ(input_nids.size(), in_shape->size());
+  for (size_t i = 0; i < in_shape->size(); i++) {
+    auto eid = idx.entry_id(input_nids[i], 0);
+    SHAPE_ASSIGN_CHECK(*in_shape, i, shapes1[eid]);
+  }
+
+  for (size_t i = 0; i < g.outputs.size(); i++) {
+    uint32_t eid = idx.entry_id(g.outputs[i]);
+    SHAPE_ASSIGN_CHECK(*out_shape, i, shapes1[eid]);
+  }
+  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
+}
+
+LoopState::LoopState(const Symbol &g) {
+  this->subgraph_sym = g;
+  this->subgraph.outputs = g.outputs;
+
+  std::vector<std::pair<std::string, std::string> > kwargs;
+  // Don't inline the subgraph into the caller's graph, however small it is.
+  kwargs.push_back(std::pair<std::string, std::string>("inline_limit", "0"));
+  // We turn on static_alloc for two reasons:
+  // it avoids the overhead of unnecessary memory allocation, and
+  // only static_alloc supports nested calls of CachedOp.
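+  // With static_alloc the CachedOp allocates its buffers once and reuses
+  // them across calls, rather than allocating dynamically on every
+  // invocation of the loop body.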
+  kwargs.push_back(std::pair<std::string, std::string>("static_alloc", "1"));
+  iter_op = std::make_shared<CachedOp>(subgraph_sym, kwargs);
+}
+
+void LoopState::Forward(int iter_no,
+                        const std::vector<NDArray> &cinputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray> &coutputs,
+                        bool is_recording) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  bool orig_is_record;
+  if (is_recording)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+
+  std::vector<NDArray> in_bufs = cinputs;
+  std::vector<NDArray> out_bufs = coutputs;
+  std::vector<NDArray *> inputs(cinputs.size());
+  std::vector<NDArray *> outputs(coutputs.size());
+  for (size_t i = 0; i < inputs.size(); i++)
+    inputs[i] = &in_bufs[i];
+  for (size_t i = 0; i < outputs.size(); i++)
+    outputs[i] = &out_bufs[i];
+
+  OpStatePtr state = iter_op->Forward(nullptr, inputs, outputs);
+  // If an input and an output share the array, the output array will be changed
+  // by CachedOp. We need to copy data to the real output.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(coutputs[i]))
+      CopyFromTo(out_bufs[i], coutputs[i]);
+  if (is_recording) {
+    all_inputs.push_back(cinputs);
+    all_outputs.push_back(coutputs);
+    all_states.push_back(state);
+  }
+
+  Imperative::Get()->set_is_recording(orig_is_record);
+}
+
+void LoopState::Backward(int iter_no,
+                         const std::vector<NDArray> &ograds,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<NDArray> &igrads) {
+  using namespace nnvm;
+  using namespace imperative;
+
+  CHECK_GT(all_states.size(), iter_no)
+      << "We didn't record the computation for iteration " << iter_no;
+  auto op = iter_op;
+  std::vector<NDArray *> inputs;
+  std::vector<NDArray *> outputs;
+  inputs.reserve(op->num_backward_inputs());
+  outputs.reserve(op->num_inputs());
+  std::vector<NDArray> ograd_bufs = ograds;
+  std::vector<NDArray> igrad_bufs = igrads;
+  for (size_t i = 0; i < ograds.size(); i++)
+    inputs.push_back(&ograd_bufs[i]);
+
+  const std::vector<bool> &save_inputs = op->save_inputs();
+  const std::vector<bool> &save_outputs = op->save_outputs();
+  CHECK_EQ(save_inputs.size(), all_inputs[iter_no].size());
+  CHECK_EQ(op->num_outputs(), all_outputs[iter_no].size());
+  for (size_t i = 0; i < all_inputs[iter_no].size(); i++) {
+    if (save_inputs[i])
+      inputs.push_back(&all_inputs[iter_no][i]);
+  }
+  for (size_t i = 0; i < all_outputs[iter_no].size(); i++) {
+    if (save_outputs[i])
+      inputs.push_back(&all_outputs[iter_no][i]);
+  }
+  CHECK_EQ(inputs.size(), op->num_backward_inputs());
+  for (size_t i = 0; i < igrads.size(); i++)
+    outputs.push_back(&igrad_bufs[i]);
+  CHECK_EQ(outputs.size(), op->num_inputs());
+  auto state = all_states[iter_no];
+  op->Backward(false, state, inputs, req, outputs);
+  // If an input and an output share the array, the output array will be changed
+  // by CachedOp. We need to copy data to the real output.
+  for (size_t i = 0; i < igrads.size(); i++)
+    if (!igrads[i].IsSame(igrad_bufs[i]))
+      CopyFromTo(igrad_bufs[i], igrads[i]);
+}
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h
new file mode 100644
index 000000000000..79078409e214
--- /dev/null
+++ b/src/operator/subgraph_op_common.h
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+#define MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
+
+#include <mxnet/ndarray.h>
+#include <mxnet/op_attr_types.h>
+#include <nnvm/graph.h>
+#include <vector>
+#include "../imperative/cached_op.h"
+#include "../imperative/imperative_utils.h"
+
+namespace mxnet {
+namespace op {
+
+/*
+ * Infer the data types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphDataType(const nnvm::Symbol &subgraph, std::vector<int> *in_type,
+                           std::vector<int> *out_type);
+
+/*
+ * Infer the shape of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphShape(const nnvm::Symbol &subgraph,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape);
+
+/*
+ * Infer the storage types of inputs and outputs of an operator that contains a
+ * subgraph.
+ */
+bool InferSubgraphStorage(const nnvm::Symbol &subgraph,
+                          const int dev_mask,
+                          DispatchMode* dispatch_mode,
+                          std::vector<int> *in_attrs,
+                          std::vector<int> *out_attrs);
+
+/*
+ * This contains the states for running a loop and provides methods
+ * of running the subgraph computation for an iteration.
+ */
+class LoopState {
+  // These are output arrays from all iterations.
+  // They also contain the Op state for each CachedOp.
+  std::vector<std::vector<NDArray> > all_outputs;
+  std::vector<std::vector<NDArray> > all_inputs;
+  // For inference, there should be only one cached op because we
+  // want to share the memory in iterations.
+  // For training, each iteration has a cached op because each iteration
+  // needs to maintain a set of memory buffers for all computation states,
+  // which will be used in the backward.
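+  // All iterations currently share this single CachedOp; the per-iteration
+  // execution state needed by the backward pass is kept in all_states.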
+  CachedOpPtr iter_op;
+  std::vector<OpStatePtr> all_states;
+  Symbol subgraph_sym;
+  nnvm::Graph subgraph;
+
+ public:
+  explicit LoopState(const Symbol &g);
+
+  void Forward(int iter_no,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray> &outputs,
+               bool is_recording);
+  void Backward(int iter_no,
+                const std::vector<NDArray> &ograds,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &igrads);
+  void Cleanup() {
+    all_outputs.clear();
+    all_inputs.clear();
+    all_states.clear();
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_SUBGRAPH_OP_COMMON_H_
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 169a9d47e7cf..302928b05f75 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -18,9 +18,10 @@
 import mxnet as mx
 from mxnet import gluon
 import numpy as np
+import copy
 from numpy.testing import assert_allclose
 import unittest
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, assert_almost_equal
 
 
 def test_rnn():
@@ -28,13 +29,69 @@ def test_rnn():
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
     outputs, _ = cell.unroll(3, inputs)
     outputs = mx.sym.Group(outputs)
-    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
+    assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight',
+                                                    'rnn_i2h_bias', 'rnn_i2h_weight']
     assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']
 
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+
+    def hybrid_forward(self, F, inputs, states):
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out
+
+def check_contrib_rnn(cell_type, num_states):
+    batch_size = 10
+    hidden_size = 100
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
+    layer = TestRNNLayer(cell_type, hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    res1 = layer(rnn_data, states)
+    params1 = layer.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1 = layer(rnn_data, states)
+    res1.backward()
+    trainer.step(batch_size)
+
+    layer = TestRNNLayer(cell_type, hidden_size)
+    layer.initialize(ctx=mx.cpu(0))
+    layer.hybridize()
+    res2 = layer(rnn_data, states)
+    params2 = layer.collect_params()
+    for key, val in orig_params1.items():
+        params2[key].set_data(val.data())
+
+    trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res2 = layer(rnn_data, states)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+    res2.backward()
+    trainer.step(batch_size)
+
+    for key, val in params1.items():
+        weight1 = val.data()
+        weight2 = params2[key].data()
+        assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+def test_contrib_rnn():
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+                  (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_contrib_rnn(cell_type, num_states)
+
+
 def test_lstm():
     cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 461fb63514c1..ae5cba21711a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -19,12 +19,13 @@
 from __future__ import print_function
 import numpy as np
 import mxnet as mx
+import copy
 import math
 import random
 import itertools
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
-from mxnet.base import py_str, MXNetError
+from mxnet.base import py_str, MXNetError, _as_list
 from common import setup_module, with_seed, teardown
 import unittest
@@ -5937,6 +5938,477 @@ def test_float16_min_max():
     assert np.finfo('float16').max == mx.nd.max(a).asscalar()
 
 
+@with_seed()
+def test_foreach():
+    v3 = mx.sym.var("v0")
+    v4 = mx.sym.var("v1")
+    v5 = mx.sym.var("v2")
+    v6 = mx.sym.var("v3")
+    v7 = mx.sym.var("v4")
+    v8 = mx.sym.var("v5")
+
+    def verify_foreach(step, in_syms, state_syms, free_syms,
+                       in_arrs, init_states, frees, out_grads, is_train=True,
+                       free_vars_func=None, num_iters=1):
+        step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms)
+        res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
+        out = _as_list(res)
+        num_outputs = len(out)
+        for i in range(num_outputs):
+            out[i] = out[i] * 2
+        out.extend(states)
+        out = mx.sym.Group(out)
+        js_1 = out.tojson()
+        out = mx.sym.load_json(js_1)
+        js_2 = out.tojson()
+        assert js_1 == js_2
+        arr_grads = []
+        arg_dict = {}
+        arg_grad_dict = {}
+        i = 0
+        for arr in _as_list(in_arrs):
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in init_states:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in frees:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+
+        if is_train:
+            e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict)
+        else:
+            e = out.bind(ctx=default_context(), args=arg_dict)
+        # the inputs to forward and backward are the same so forward and backward
+        # should always return the same outputs.
+        for i in range(num_iters):
+            e.forward(is_train=is_train)
+            if (is_train):
+                # backward
+                tmp_grads = out_grads[0][:]
+                tmp_grads.extend(out_grads[1])
+                e.backward(tmp_grads)
+
+        # Below we use imperative to reimplement foreach and compute its gradients.
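+        # The imperative run records the same step function with autograd and
+        # replays it; its outputs and the gradients of all inputs should match
+        # the symbolic executor e above (res vs. e.outputs, arr.grad vs.
+        # e.grad_arrays).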
+ res = [] + for i in range(len(_as_list(out_grads[0]))): + res.append([]) + for arr in _as_list(in_arrs): + arr.attach_grad() + for arr in init_states: + arr.attach_grad() + for arr in frees: + arr.attach_grad() + with mx.autograd.record(): + frees_imp = frees if free_vars_func is None else free_vars_func(frees) + step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp) + states = [mx.nd.expand_dims(s, 0) for s in init_states] + res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states) + + res2 = _as_list(res) + for i in range(len(res2)): + res2[i] = res2[i] * 2 + outs = [] + outs[:] = res2[:] + if isinstance(states, list): + outs.extend(states) + states = [mx.nd.expand_dims(s, 0) for s in states] + res2.extend(states) + else: + outs.append(states) + states = mx.nd.expand_dims(states, 0) + res2.append(states) + if is_train: + res = mx.nd.concat(*res2, dim=0) + + tmp_grads = out_grads[0][:] + tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]] + tmp_grads.extend(tmp_grads1) + if is_train: + res.backward(mx.nd.concat(*tmp_grads, dim=0)) + for i in range(len(outs)): + assert e.outputs[i].shape == outs[i].shape + assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(), + rtol=0.001, atol=0.0001) + if (is_train): + all_ins = _as_list(in_arrs)[:] + all_ins.extend(init_states) + all_ins.extend(frees) + size = min(len(all_ins), len(e.grad_arrays)) + for i in range(size): + assert_almost_equal(all_ins[i].grad.asnumpy(), + e.grad_arrays[i].asnumpy(), + rtol=0.001, atol=0.0001) + + # Test cases: + # * graph inputs are stored in different orders. + # This is to test if foreach finds the data arrays and weight arrays + # in the right location. + # * the number of iterations: odd or even. + # * multiple inputs and multiple outputs. + # * inference. + def step1(in1, states, free): + out = in1 * 2 + states[0] + free[0] + return (out, [out]) + frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1] + arrs = mx.nd.arange(6).reshape(shape=(3, 2)) + states = [mx.nd.arange(2)] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True, + lambda frees : [frees[0] + frees[1]]) + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False, + lambda frees : [frees[0] + frees[1]]) + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True, + lambda frees : [frees[0] + frees[1]], 5) + verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False, + lambda frees : [frees[0] + frees[1]], 5) + + # Test the even number of iterations. + frees = [mx.nd.random.uniform(shape=(2))] + arrs = mx.nd.random.uniform(shape=(2, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) + # Test the odd number of iterations + arrs = mx.nd.random.uniform(shape=(3, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + # Reorder the input and state in the subgraph inputs. 
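+    # step2 reads the state before the input, so the subgraph discovers its
+    # inputs in a different order than the operator receives them; foreach
+    # must remap them by location rather than by position.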
+ def step2(in1, states, free): + out = states[0] + in1 * 2 + free[0] + return (out, [out]) + # Test the even number of iterations. + arrs = mx.nd.random.uniform(shape=(2, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False) + # Test the odd number of iterations. + arrs = mx.nd.random.uniform(shape=(3, 2)) + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + # Test multiple inputs and outputs. + def step3(in1, states, free): + out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0] + return ([out, out], [out * 2, out * 3]) + arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] + states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads) + verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False) + + # Test multiple inputs and outputs. + # The order of subgraph inputs doesn't match the operator inputs + def step4(in1, states, free): + out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0] + return ([out, out * 2], [out * 2, out * 3]) + arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] + states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads) + verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False) + + # Test multiple inputs and outputs. + # The data inputs and states have different shapes. + def step5(in1, states, free): + if isinstance(in1[0], mx.nd.NDArray): + out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2) + out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2) + else: + out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2) + out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2) + return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) + frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))] + arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))] + states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) + + # Test multiple inputs and outputs. + # The data inputs and states have different shapes and data types. 
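+    # The casts below give the second free variable, input and state dtypes
+    # (float64, float16 and int32) that differ from the default float32, so
+    # type inference has to propagate a distinct dtype to each subgraph input.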
+ def step6(in1, states, free): + if isinstance(in1[0], mx.nd.NDArray): + out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'), + mx.nd.cast(in1[1], 'float32') * 2) + out2 = mx.nd.broadcast_add(in1[0], + free[0] + mx.nd.cast(states[1], 'float32') * 2) + else: + out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'), + mx.sym.cast(in1[1], 'float32') * 2) + out2 = mx.sym.broadcast_add(in1[0], + free[0] + mx.sym.cast(states[1], 'float32') * 2) + return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) + frees = [mx.nd.random.uniform(shape=(2)), + mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')] + arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), + mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')] + states = [mx.nd.random.uniform(shape=(2, 2)), + mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) + + # Test multiple inputs and outputs. + # some of the inputs are used twice. + def step7(in1, states, free): + out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0] + out2 = in1[0] + free[0] + states[1] * 2 + in1[1] + return ([out1, out2 * 2], [states[0] * 2, states[1] * 3]) + frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))] + states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]] + verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False) + + # Test the case that the output is the input. + arrs = mx.nd.random.uniform(shape=(3, 2)) + states = [mx.nd.arange(2)] + frees = [mx.nd.random.uniform(shape=(2))] + out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)], + [mx.nd.random.uniform(-10, 10, states[0].shape)]] + def step8(in1, states, free): + return (in1, [states[0] * free[0]]) + verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False) + def step9(in1, states, free): + return (in1 * free[0], states) + verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + # Test the case that not all inputs are used. 
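+    # An unused data input should simply get a zero gradient, but the loop
+    # states must be threaded through the step function; step11 and step12
+    # below don't use all of the states, so binding is expected to fail.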
+ def step10(in1, states, free): + return (in1, states) + verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False) + def step11(in1, states, free): + return (in1, free) + try: + verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False) + except AssertionError: + print("the states have to be used") + def step12(in1, states, free): + return (in1, [states[0] + 1, states[0] + 2]) + states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))] + frees = [] + try: + verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads) + verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False) + except AssertionError: + print("the states have to be used") + + # test without free variables. + def step13(in1, states, free): + return (in1, states) + states = [mx.nd.random.uniform(shape=(2))] + verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads) + verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False) + + # test when there isn't output data or output states. + def step14(in1, states, free): + return (in1 + free[0], []) + frees = [mx.nd.random.uniform(shape=(2))] + verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads) + verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False) + def step15(in1, states, free): + return ([], [in1 * states[0] * free[0]]) + out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]] + verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads) + verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False) + + # Test the case of iterating on a 1D data array. + def step16(in1, states, free): + return ([in1[0] * states[0]], [states[0] * 2]) + arrs = [mx.nd.arange(3)] + states = [mx.nd.random.uniform(shape=(1))] + out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))], + [mx.nd.random.uniform(-10, 10, (1))]] + verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads) + verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False) + def step17(in1, states, free): + return ([in1[1] * in1[0] * states[0]], [states[0] * 2]) + arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)] + states = [mx.nd.random.uniform(shape=(1))] + out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))], + [mx.nd.random.uniform(-10, 10, (1))]] + verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads) + verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False) + + +@with_seed() +def test_foreach_nested(): + # Test nested foreach. 
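+    # The outer loop iterates over axis 0 of the data, and each outer step
+    # runs an inner foreach over the current slice, so every outer iteration
+    # executes a complete inner loop.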
+ def step_in(in1, states): + out = in1 * 2 + states[0] + return (out, [out]) + + def step_sym(in1, states): + out1 = mx.sym.contrib.foreach(step_in, in1, states) + out = mx.sym.broadcast_add(out1[0], states[0]) + return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))]) + def step_nd(in1, states): + out1 = mx.nd.contrib.foreach(step_in, in1, states) + out = mx.nd.broadcast_add(out1[0], states[0]) + return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))]) + + data_sym = mx.sym.var("v1") + state_sym = mx.sym.var("v2") + out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym]) + assert isinstance(states, list) + assert len(states) == 1 + out = mx.sym.broadcast_add(out, states[0]) + + js_1 = out.tojson() + out = mx.sym.load_json(js_1) + js_2 = out.tojson() + assert js_1 == js_2 + + data = mx.nd.arange(8).reshape((2, 2, 2)) + state = mx.nd.arange(2) + data_grad = mx.nd.empty(data.shape) + state_grad = mx.nd.empty(state.shape) + e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state}, + args_grad={'v1':data_grad, 'v2':state_grad}) + e.forward(is_train=True) + out_grads = [] + for out in e.outputs: + out_grads.append(mx.nd.random.uniform(shape=out.shape)) + e.backward(out_grads) + + data.attach_grad() + state.attach_grad() + with mx.autograd.record(): + out, states = mx.nd.contrib.foreach(step_nd, data, [state]) + assert isinstance(states, list) + assert len(states) == 1 + res = mx.nd.broadcast_add(out, states[0]) + assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001, atol=0.0001) + + res.backward(out_grads[0]) + assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy()) + assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy()) + + +def check_foreach_rnn(cell_type, num_states): + data = mx.sym.var("data") + params = mx.rnn.RNNParams() + hidden_dim = 4 + input_dim = 5 + seq_len = 2 + batch_size = 2 + + # This tests foreach with accumulation sum. + def step(in1, states): + rnn = cell_type(hidden_dim, prefix='', params=params) + next_h, states = rnn(in1, states) + return (next_h, states) + + def sym_group(out): + if (isinstance(out[0], mx.sym.Symbol)): + ret = [out[0]] + else: + ret = out[0] + ret.extend(out[1]) + return mx.sym.Group(ret) + + rnn = cell_type(hidden_dim, prefix='', params=params) + if num_states == 2: + init_states = [mx.sym.var("h"), mx.sym.var("c")] + else: + init_states = [mx.sym.var("h")] + out = mx.sym.contrib.foreach(step, data, init_states) + out = sym_group(out) + arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(seq_len, batch_size, input_dim), + h=(batch_size, hidden_dim)) + rnn_inputs = out.list_inputs() + + # Inputs + args1 = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + args2 = copy.deepcopy(args1) + # gradients for the backward of the foreach symbol + args_grad1 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + # gradients for the backward of the unrolled symbol. + args_grad2 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + + # Symbol of running LSTM with foreach. + out = mx.sym.contrib.foreach(step, data, init_states) + out = sym_group(out) + js_1 = out.tojson() + out = mx.sym.load_json(js_1) + js_2 = out.tojson() + assert js_1 == js_2 + e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1) + + # Symbol of running unrolled LSTM. 
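+    # The reference below unrolls the same cell explicitly: split the sequence
+    # along the time axis, apply the cell step by step, and re-stack the
+    # outputs; e2 should then match the foreach executor e1 on both outputs
+    # and input gradients.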
+ lstm = cell_type(hidden_dim, prefix='') + unroll_outs = [] + states = init_states + for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True): + h, states = lstm(inputs, states) + unroll_outs.append(mx.sym.expand_dims(h, axis=0)) + unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0)) + unroll_outs.extend(states) + out = mx.sym.Group(unroll_outs) + js_1 = out.tojson() + out = mx.sym.load_json(js_1) + js_2 = out.tojson() + assert js_1 == js_2 + e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2) + + for i in range(5): + out_grads = [] + for arr in e1.outputs: + out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape)) + + args = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)} + + e1.forward(is_train=True, **args) + outputs1 = e1.outputs + e1.backward(out_grads) + + e2.forward(is_train=True, **args) + outputs2 = e2.outputs + e2.backward(out_grads) + + for i in range(len(outputs2)): + assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(), + rtol=0.001, atol=0.0001) + input_names = out.list_inputs() + for i in range(len(e1.grad_arrays)): + name = input_names[i] + assert_almost_equal(args_grad1[name].asnumpy(), args_grad2[name].asnumpy(), + rtol=0.001, atol=0.0001) + + +@with_seed() +def test_foreach_rnn(): + cell_types = [(mx.rnn.LSTMCell, 2), (mx.rnn.RNNCell, 1), (mx.rnn.GRUCell, 1)] + for cell_type, num_states in cell_types: + check_foreach_rnn(cell_type, num_states) + + @with_seed() def test_squeeze_op(): def check_squeeze_op(shape, axis=None):