HawkAaron/mxnet-transducer

mxnet-transducer

A fast parallel implementation of the RNN Transducer loss (Graves 2013 joint network) for MXNet, on both CPU and GPU.
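To make the computed quantity concrete, here is a minimal NumPy reference sketch of the RNN-T loss for a single utterance, using the forward (alpha) recursion over the (frame, label) lattice. It assumes `log_probs` is already log-softmax normalized with shape (T, U + 1, V) and that the blank id is 0; `rnnt_loss_single` is a hypothetical name, not part of this repository or of MXNet. A slow loop like this is only meant for checking faster kernels on small inputs.

```python
import numpy as np

def rnnt_loss_single(log_probs, labels, blank=0):
    """Negative log-likelihood of `labels` under one RNN-T output lattice.

    log_probs: (T, U + 1, V) log-softmax outputs of the joint network,
        indexed by frame t and number of labels already emitted u.
    labels: sequence of U target ids (none equal to `blank`).
    """
    T, U1, _ = log_probs.shape
    U = len(labels)
    assert U1 == U + 1, "second axis must be len(labels) + 1"
    # alpha[t, u]: log prob of all partial alignments that have emitted
    # labels[:u] and reached frame t (lattice node (t, u)).
    alpha = np.full((T, U + 1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            candidates = []
            if t > 0:  # arrive by emitting blank at (t - 1, u)
                candidates.append(alpha[t - 1, u] + log_probs[t - 1, u, blank])
            if u > 0:  # arrive by emitting labels[u - 1] at (t, u - 1)
                candidates.append(alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]])
            alpha[t, u] = np.logaddexp.reduce(candidates)
    # Every alignment ends with a final blank at the top-right node.
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])
```

Gradients with respect to `log_probs` follow from an analogous backward (beta) recursion, which this sketch omits.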

A GPU version is now available for the Graves 2012 add network.
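The two network variants differ in how the transcription (encoder) outputs and the prediction-network outputs are combined before the softmax. The NumPy sketch below assumes both inputs are already projected to a common width; `add_network` and `joint_network` are illustrative names, not this repository's API, and the joint version is a simplified form of the Graves 2013 output network. It normalizes with log-softmax as the papers do; whether the operator expects raw logits or log-probabilities should be checked in the rnnt source files.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax over the vocabulary axis.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def add_network(f, g):
    """Graves 2012 style: logits[b, t, u, k] = f[b, t, k] + g[b, u, k].

    f: transcription outputs, shape (B, T, V)
    g: prediction-network outputs, shape (B, U + 1, V)
    """
    return log_softmax(f[:, :, None, :] + g[:, None, :, :])

def joint_network(f, g, b_h, w_out, b_out):
    """Graves 2013 style, simplified: tanh of the summed projections,
    followed by a linear map to V logits.

    f: projected transcription outputs, shape (B, T, H)
    g: projected prediction outputs, shape (B, U + 1, H)
    b_h: (H,), w_out: (H, V), b_out: (V,)
    """
    hidden = np.tanh(f[:, :, None, :] + g[:, None, :, :] + b_h)  # (B, T, U + 1, H)
    return log_softmax(hidden @ w_out + b_out)                     # (B, T, U + 1, V)
```

Both variants broadcast to a 4-D tensor indexed (batch, frame, label position, vocabulary), which is the layout the `RNNTLoss` wrapper passes to the operator.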

Install and Test

First, get MXNet and this code:

git clone --recursive https://github.com/apache/incubator-mxnet
git clone https://github.com/HawkAaron/mxnet-transducer

Copy the rnnt* source files into MXNet's contrib operator directory:

cp -r mxnet-transducer/rnnt* incubator-mxnet/src/operator/contrib/

Then build and install MXNet from source, following its installation instructions:

https://mxnet.incubator.apache.org/install/index.html

Finally, add the Python API to /path/to/mxnet_root/mxnet/gluon/loss.py:

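class RNNTLoss(Loss):
    def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
        # pred is (B, T, U, V) when batch_first, otherwise (T, U, B, V)
        batch_axis = 0 if batch_first else 2
        super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_first = batch_first
        self.blank_label = blank_label

    def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths,
                       sample_weight=None):
        if not self.batch_first:
            # (T, U, B, V) -> (B, T, U, V), the layout the operator expects
            pred = F.transpose(pred, (2, 0, 1, 3))

        # The operator takes int32 labels and lengths. F.cast works for both
        # NDArray and Symbol inputs, so the block stays hybridizable.
        loss = F.contrib.RNNTLoss(pred, F.cast(label, dtype='int32'),
                                  F.cast(pred_lengths, dtype='int32'),
                                  F.cast(label_lengths, dtype='int32'),
                                  blank_label=self.blank_label)
        # Apply the global weight and optional per-sample weight, as the
        # other Gluon losses in loss.py do.
        return _apply_weighting(F, loss, self._weight, sample_weight)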
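For reference, the NumPy sketch below spells out the input layouts that `RNNTLoss` appears to expect. It assumes the label axis of `pred` has one more position than the longest target (U + 1, for the empty prefix) and that padded label entries past `label_lengths` are ignored; `to_batch_first` and `pad_labels` are hypothetical helpers, and NumPy stands in for MXNet so the sketch runs without a custom build.

```python
import numpy as np

def to_batch_first(pred):
    """Same permutation hybrid_forward applies when batch_first=False:
    (T, U + 1, B, V) -> (B, T, U + 1, V)."""
    return np.transpose(pred, (2, 0, 1, 3))

def pad_labels(seqs, pad_value=0):
    """Pack variable-length label sequences into an int32 (B, U_max) array
    plus int32 lengths. Entries past each length are assumed to be ignored."""
    lengths = np.array([len(s) for s in seqs], dtype=np.int32)
    padded = np.full((len(seqs), int(lengths.max())), pad_value, dtype=np.int32)
    for i, seq in enumerate(seqs):
        padded[i, :len(seq)] = seq
    return padded, lengths

# Hypothetical sizes: batch B, frames T, vocabulary V (blank at index 0).
B, T, V = 2, 5, 10
label, label_lengths = pad_labels([[3, 7, 1], [4, 2]])
U = label.shape[1]
pred_tb = np.zeros((T, U + 1, B, V), dtype=np.float32)  # batch_first=False layout
pred = to_batch_first(pred_tb)                           # (B, T, U + 1, V)
pred_lengths = np.full(B, T, dtype=np.int32)             # every utterance uses all T frames
```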

Then run the test from the repository root with:

python test/test.py 10 300 100 50 --mx
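On tiny inputs, a useful independent check for any RNN-T loss is to sum over every alignment explicitly and compare the result with the operator's output. The NumPy sketch below assumes one utterance's log-probabilities laid out as (T, U + 1, V), already log-softmax normalized, with blank id 0; `rnnt_loss_bruteforce` is a hypothetical helper and is not part of test/test.py. Its cost grows combinatorially, so keep T and U in the single digits.

```python
import itertools
import numpy as np

def rnnt_loss_bruteforce(log_probs, labels, blank=0):
    """Negative log-likelihood by explicit enumeration of all alignments."""
    T, U1, _ = log_probs.shape
    U = len(labels)
    assert U1 == U + 1, "second axis must be len(labels) + 1"
    path_log_probs = []
    # An alignment is an ordering of T - 1 blank moves (advance t) and
    # U label moves (advance u), followed by one final blank at (T - 1, U).
    for label_steps in itertools.combinations(range(T - 1 + U), U):
        chosen = set(label_steps)
        t = u = 0
        total = 0.0
        for step in range(T - 1 + U):
            if step in chosen:
                total += log_probs[t, u, labels[u]]
                u += 1
            else:
                total += log_probs[t, u, blank]
                t += 1
        total += log_probs[T - 1, U, blank]
        path_log_probs.append(total)
    return -np.logaddexp.reduce(path_log_probs)
```

With uniform outputs over V symbols, every alignment has probability V^-(T+U) and there are C(T+U-1, U) of them, which gives a closed form to test against.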

Reference