This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

DataParallelExecutorGroup: layout handling for symbols #6736

Closed
leezu opened this issue Jun 19, 2017 · 6 comments

@leezu
Contributor

leezu commented Jun 19, 2017

When merge_multi_context=True in a call to get_outputs of a DataParallelExecutorGroup, _merge_multi_context will try to concatenate the outputs from the different devices along the major axis.

The major axis is computed based on [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__')) for name in self.output_names] in the initializer of DataParallelExecutorGroup.

What is the recommended way to set the attr('__layout__') of a symbol? Simply pass attr={'__layout__': layout} when constructing the symbol? Can the attribute be set automatically during module binding?

Setting attr('__layout__') is necessary because it is otherwise None, in which case _merge_multi_context tries to concatenate along dim=0. That fails whenever the batch size is not divisible by the number of devices and the symbol outputs a shape of (1, batch_size_per_device, X).

For example, with 3 devices and batch size 128, concatenating (1, 43), (1, 43) and (1, 42) along dim=0 will fail.
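For concreteness, this is the kind of thing I have in mind (a minimal sketch; the variable name and the 'LNC' layout are just placeholders):

import mxnet as mx
from mxnet.io import DataDesc

# attach the layout while constructing the symbol
state = mx.sym.Variable('state_output', attr={'__layout__': 'LNC'})

# this is what DataParallelExecutorGroup later reads back
layout = state.attr('__layout__')        # 'LNC'
print(DataDesc.get_batch_axis(layout))   # 1, i.e. merge along axis 1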

@formath
Contributor

formath commented Jun 19, 2017

Most of the time the default layout for DataDesc is 'NCHW' or 'NTC'. If the layout given to get_batch_axis is None, it returns 0, which is just the default batch axis. If your layout does not start with 'N', set the layout attribute on the op yourself. Also, concatenate can merge tensors that differ in size along the batch axis.

def get_batch_axis(layout):
    """Get the dimension that corresponds to the batch size.

    When data parallelism is used, the data will be automatically split and
    concatenated along the batch-size dimension. Axis can be -1, which means
    the whole array will be copied for each data-parallelism device.

    Parameters
    ----------
    layout : str
        layout string. For example, "NCHW".

    Returns
    -------
    int
        An axis indicating the batch_size dimension.
    """
    if layout is None:
        return 0
    return layout.find('N')

def concatenate(arrays, axis=0, always_copy=True):
    """DEPRECATED, use ``concat`` instead

    Parameters
    ----------
    arrays : list of `NDArray`
        Arrays to be concatenated. They must have identical shapes except
        along the concatenation axis. They also must have the same data type.
    axis : int
        The axis along which to concatenate.
    always_copy : bool
        Default `True`. When not `True`, if the arrays only contain one
        `NDArray`, that element will be returned directly, avoid copying.

    Returns
    -------
    NDArray
        An `NDArray` that lives on the same context as `arrays[0].context`.
    """
    assert isinstance(arrays, list)
    assert len(arrays) > 0
    assert isinstance(arrays[0], NDArray)

    if not always_copy and len(arrays) == 1:
        return arrays[0]

    shape_axis = arrays[0].shape[axis]
    shape_rest1 = arrays[0].shape[0:axis]
    shape_rest2 = arrays[0].shape[axis+1:]
    dtype = arrays[0].dtype
    for arr in arrays[1:]:
        shape_axis += arr.shape[axis]
        assert shape_rest1 == arr.shape[0:axis]
        assert shape_rest2 == arr.shape[axis+1:]
        assert dtype == arr.dtype
    ret_shape = shape_rest1 + (shape_axis,) + shape_rest2
    ret = empty(ret_shape, ctx=arrays[0].context, dtype=dtype)

    idx = 0
    begin = [0 for _ in ret_shape]
    end = list(ret_shape)
    for arr in arrays:
        if axis == 0:
            ret[idx:idx+arr.shape[0]] = arr
        else:
            begin[axis] = idx
            end[axis] = idx+arr.shape[axis]
            # pylint: disable=no-member,protected-access
            _internal._crop_assign(ret, arr, out=ret,
                                   begin=tuple(begin),
                                   end=tuple(end))
            # pylint: enable=no-member,protected-access
        idx += arr.shape[axis]

    return ret
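
For example (a sketch with dummy shapes), device outputs that split the batch unevenly can still be merged, as long as the concatenation happens along the batch axis:

import mxnet as mx

parts = [mx.nd.zeros((1, 43, 4)), mx.nd.zeros((1, 43, 4)), mx.nd.zeros((1, 42, 4))]
print(mx.nd.concatenate(parts, axis=1).shape)   # (1, 128, 4)
# along axis=0 the same call would fail, because the non-concat dims 43 and 42 differ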

@leezu
Contributor Author

leezu commented Jun 19, 2017

If your layout does not start with 'N', set the layout attribute on the op yourself.

What do you have in mind for setting the layout attribute?

Concretely, my use case is about the state outputs from an RNN, i.e. the states returned by unroll.

I briefly tried setting attr={'__layout__': layout} when creating the symbols by adapting rnn_cell.py. For now my code works around the issue by setting merge_multi_context=False.
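
Roughly, the workaround looks like this (only a sketch; mod stands for my bound Module and I assume the state I care about is the last output):

# per-device outputs instead of letting the executor group merge them
outputs = mod.get_outputs(merge_multi_context=False)
# outputs[-1] holds one NDArray per device, each shaped (1, batch_per_device, X)
states = mx.nd.concatenate(outputs[-1], axis=1)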

@formath
Contributor

formath commented Jun 19, 2017

If you are using FusedRNNCell, the shapes of the states may not be swapped even though the outputs are. Please double-check that, because I'm not absolutely sure. However, the __layout__ attributes of the RNN cells have all been set, so it should work with [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__')) for name in self.output_names].

def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
    self.reset()

    inputs, axis = _normalize_sequence(length, inputs, layout, True)
    if axis == 1:
        warnings.warn("NTC layout detected. Consider using "
                      "TNC for FusedRNNCell for faster speed")
        inputs = symbol.swapaxes(inputs, dim1=0, dim2=1)
    else:
        assert axis == 0, "Unsupported layout %s"%layout
    if begin_state is None:
        begin_state = self.begin_state()

    states = begin_state
    if self._mode == 'lstm':
        states = {'state': states[0], 'state_cell': states[1]} # pylint: disable=redefined-variable-type
    else:
        states = {'state': states[0]}

    rnn = symbol.RNN(data=inputs, parameters=self._parameter,
                     state_size=self._num_hidden, num_layers=self._num_layers,
                     bidirectional=self._bidirectional, p=self._dropout,
                     state_outputs=self._get_next_state,
                     mode=self._mode, name=self._prefix+'rnn',
                     **states)

    if not self._get_next_state:
        outputs, states = rnn, []
    elif self._mode == 'lstm':
        outputs, states = rnn[0], [rnn[1], rnn[2]]
    else:
        outputs, states = rnn[0], [rnn[1]]

    if axis == 1:
        outputs = symbol.swapaxes(outputs, dim1=0, dim2=1)

    outputs, _ = _normalize_sequence(length, outputs, layout, merge_outputs)

    return outputs, states

def state_info(self):
    b = self._bidirectional + 1
    n = (self._mode == 'lstm') + 1
    return [{'shape': (b*self._num_layers, 0, self._num_hidden), '__layout__': 'LNC'}
            for _ in range(n)]
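
To check what the state symbols from unroll actually carry, something like this could work (just a sketch; I'm not absolutely sure what the attribute comes back as):

import mxnet as mx
from mxnet.io import DataDesc

cell = mx.rnn.FusedRNNCell(num_hidden=16, num_layers=1, mode='lstm',
                           get_next_state=True)
outputs, states = cell.unroll(length=10, inputs=mx.sym.Variable('data'),
                              layout='TNC', merge_outputs=True)
for s in states:
    # if this prints None, get_batch_axis falls back to axis 0
    print(s.attr('__layout__'), DataDesc.get_batch_axis(s.attr('__layout__')))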

@leezu
Contributor Author

leezu commented Jun 19, 2017

For FusedRNNCell (with BlockGrad on the states) I get shapes (1, 43, X), (1, 43, X), (1, 42, X) for the states when specifying a batch size of 128 on 3 devices.
Currently _merge_multi_context tries to concatenate these 3 ndarrays along dim=0, which is not possible. That is why, as far as I understand, the layout attribute of the state symbols should be set at some point. Perhaps this should be done by mx.sym.RNN?

@formath
Contributor

formath commented Jun 19, 2017

rnn = symbol.RNN(data=inputs, parameters=self._parameter,
                         state_size=self._num_hidden, num_layers=self._num_layers,
                         bidirectional=self._bidirectional, p=self._dropout,
                         state_outputs=self._get_next_state,
                         mode=self._mode, name=self._prefix+'rnn',
                         **states)

Yes. It only sets __layout__ on the begin states, not on rnn. Simply setting the __layout__ attribute on rnn again might be the easiest fix, but it would only be a temporary solution.
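
One way to patch it by hand for now (a rough sketch, untested; it assumes the lstm state layout 'LNC' and that the attr keyword is accepted here) would be to wrap the states in an op that carries the layout before unroll returns them:

# inside unroll, before `return outputs, states` (hypothetical patch)
states = [symbol.BlockGrad(s, attr={'__layout__': 'LNC'}) for s in states]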

@yajiedesign
Contributor

This issue is closed due to lack of activity in the last 90 days. Feel free to reopen if this is still an active issue. Thanks!
