DataParallelExecutorGroup: layout handling for symbols #6736
Comments
Most of the time, the default layout for DataDesc is 'NCHW' or 'NTC'. And, if the
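The batch axis is derived from the layout string. A minimal sketch of this lookup, assuming (as MXNet's `DataDesc.get_batch_axis` does) that the batch axis is the position of `'N'` in the layout, with axis 0 as the fallback when no layout is set:

```python
# Sketch (assumption): mimics how a batch axis can be derived from a
# layout string, in the style of MXNet's DataDesc.get_batch_axis.
def get_batch_axis(layout):
    """Return the index of the batch dimension 'N', or 0 if no layout is set."""
    if layout is None:
        return 0
    return layout.find('N')

print(get_batch_axis('NCHW'))  # 0: batch-major image layout
print(get_batch_axis('TNC'))   # 1: time-major sequence layout
```

This is why a `None` layout behaves like a batch-major one: the fallback axis is 0.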
What do you have in mind to set the layout attribute? Concretely, my use case is about state outputs from an RNN, i.e. the states returned by I have briefly tried to set
If you are using 'FusedRNNCell', the shape for
For FusedRNNCell (with BlockGrad on states) I get shapes
Yes. It just sets the
This issue is closed due to lack of activity in the last 90 days. Feel free to reopen if this is still an active issue. Thanks!
When `merge_multi_context=True` in a call to `get_outputs` of a `DataParallelExecutorGroup`, `_merge_multi_context` will try to concatenate the outputs from the different devices along the major axis. The major axis is computed based on

```python
[DataDesc.get_batch_axis(self.symbol[name].attr('__layout__')) for name in self.output_names]
```

in the initializer of `DataParallelExecutorGroup`.

What is the recommended way to set the `attr('__layout__')` of a symbol? Simply pass `attr={'__layout__': layout}` when constructing the symbol? Can the attribute be set automatically during module binding?

Setting `attr('__layout__')` is necessary, as it is otherwise `None`, leading to `_merge_multi_context` trying to concatenate along `dim=0`, which will fail if the batch size is not divisible by the number of devices and the symbol outputs a shape `(1, batch_size_per_device, X)`. I.e. in the case of 3 devices and batch size 128, concatenating `(1, 43)`, `(1, 43)`, `(1, 42)` along `dim=0` will fail.