A fast parallel implementation of the RNN Transducer loss (Graves 2013 joint network) for mxnet, on both CPU and GPU.
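To make clear what the loss computes, here is a slow, single-utterance NumPy reference of the RNN-T negative log-likelihood using the forward (alpha) recursion over the T x (U+1) output lattice. It is a sketch for sanity-checking results, not the parallel CPU/GPU code in this repo. The function name `rnnt_loss_reference`, the assumption that inputs are already log-softmax-normalised with shape (T, U+1, V), and blank index 0 (matching the `blank_label=0` default of the Python API) are all assumptions made for illustration.

```python
import numpy as np


def rnnt_loss_reference(log_probs, labels, blank=0):
    """Hypothetical reference: -log P(labels | input) for one utterance.

    log_probs: (T, U + 1, V) array, assumed to be log-softmax over the last axis.
    labels:    sequence of U label ids (no blanks).
    """
    T, U1, _ = log_probs.shape
    U = len(labels)
    assert U1 == U + 1, "second axis must have len(labels) + 1 entries"

    # alpha[t, u] = log-probability of having emitted labels[:u] by frame t.
    alpha = np.full((T, U + 1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Reach (t, u) by emitting blank at (t - 1, u): advance one frame.
            from_blank = alpha[t - 1, u] + log_probs[t - 1, u, blank] if t > 0 else -np.inf
            # Reach (t, u) by emitting labels[u - 1] at (t, u - 1): advance one label.
            from_label = alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]] if u > 0 else -np.inf
            alpha[t, u] = np.logaddexp(from_blank, from_label)

    # Every alignment ends with a final blank emitted from the last node.
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])
```

A reference like this is mainly useful for comparing against the operator's output on tiny random inputs.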
A GPU version is now also available for the Graves 2012 add network.
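Assuming, as in Graves 2012, that the "add network" forms the joint output by summing the transcription-network output for each frame with the prediction-network output for each label step and then applying a softmax, the combination can be sketched in NumPy as below. The function name `add_joint_log_probs` and the shapes f: (T, V), g: (U+1, V) are hypothetical, and the repo's kernel is not claimed to materialise the full tensor this way.

```python
import numpy as np


def add_joint_log_probs(f, g):
    """Hypothetical sketch of an additive joint network.

    f: (T, V) transcription outputs, g: (U + 1, V) prediction outputs.
    Returns (T, U + 1, V) log-probabilities.
    """
    logits = f[:, None, :] + g[None, :, :]                # broadcast sum -> (T, U + 1, V)
    logits = logits - logits.max(axis=-1, keepdims=True)  # shift for numerical stability
    return logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))  # log-softmax over V
```

Because the joint is a plain sum, the T x (U+1) x V tensor is fully determined by the much smaller f and g.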
First, get mxnet and the code:
git clone --recursive https://github.com/apache/incubator-mxnet
git clone https://github.com/HawkAaron/mxnet-transducer
Copy the RNN-T operator files into the mxnet source tree:
cp -r mxnet-transducer/rnnt* incubator-mxnet/src/operator/contrib/
Then build and install mxnet from source, following its installation instructions:
https://mxnet.incubator.apache.org/install/index.html
Finally, add the Python API to /path/to/mxnet_root/mxnet/gluon/loss.py:
class RNNTLoss(Loss):
    def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
        # pred has its batch on axis 0 if batch_first, otherwise on axis 2.
        batch_axis = 0 if batch_first else 2
        super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_first = batch_first
        self.blank_label = blank_label

    def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths):
        if not self.batch_first:
            # Move the batch axis (axis 2) to the front.
            pred = F.transpose(pred, (2, 0, 1, 3))

        # The operator expects int32 labels and lengths; copy=False avoids a copy if already int32.
        loss = F.contrib.RNNTLoss(pred, label.astype('int32', False),
                                  pred_lengths.astype('int32', False),
                                  label_lengths.astype('int32', False),
                                  blank_label=self.blank_label)
        return loss
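The `batch_first=False` branch relies on `F.transpose(pred, (2, 0, 1, 3))` to move the batch axis from position 2 to the front. The NumPy sketch below mirrors that permutation so the layout is easy to check. The axis names (T frames, U+1 label steps, N utterances, V symbols) and the helper `to_batch_first` are assumptions for illustration; only the permutation itself comes from the code above.

```python
import numpy as np


def to_batch_first(pred):
    """Hypothetical mirror of F.transpose(pred, (2, 0, 1, 3)).

    Maps an assumed (T, U + 1, N, V) layout to (N, T, U + 1, V).
    """
    return np.transpose(pred, (2, 0, 1, 3))


# Small made-up sizes so individual entries can be compared by hand.
T, U1, N, V = 4, 3, 2, 5
pred = np.arange(T * U1 * N * V, dtype=np.float32).reshape(T, U1, N, V)
out = to_batch_first(pred)  # out[n, t, u, v] == pred[t, u, n, v]
```

This is also why the constructor passes `batch_axis = 2` to `Loss` when `batch_first` is False: before the transpose, the batch lives on axis 2.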
To test, run the following from the mxnet-transducer repository directory:
python test/test.py 10 300 100 50 --mx