DonkeyShot21/uis-rnn-sml

A better, faster, stronger version of the unbounded interleaved-state recurrent neural network (UIS-RNN)


Better, Faster, Stronger UIS-RNN

This repository adds the following features on top of the original UIS-RNN repository. Some of them are described in our paper, Supervised Online Diarization with Sample Mean Loss for Multi-Domain Data.

  • Sample Mean Loss (SML), a loss function that improves performance and training efficiency. See our paper for details.
  • Estimation of crp_alpha, the parameter of the distance-dependent Chinese Restaurant Process (ddCRP) that determines the probability of switching to a new speaker. This is also described in our paper.
  • Parallel prediction using torch.multiprocessing, which mitigates slow decoding and enables higher GPU utilization.
  • TensorBoard logging for visualizing training.

[Figure: diagram of the Sample Mean Loss]

The UIS-RNN was originally proposed in Fully Supervised Speaker Diarization.

Run the demo

To get started, simply run this command:

python3 demo.py --train_iteration=1000 -l=0.001

This trains a UIS-RNN model on data/toy_training_data.npz, saves the model to disk, runs inference on data/toy_testing_data.npz, prints the inference results, and saves the averaged accuracy to a text file. Here -l sets the learning rate.
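If you want to inspect the toy data before running the demo, a small loader like the following may help. It is a sketch that assumes the key names used by the original UIS-RNN demo (train_sequence, train_cluster_id, test_sequences, test_cluster_ids); check demo.py if this fork stores them differently. The function name load_toy_data is hypothetical.

```python
import numpy as np


def load_toy_data(train_path, test_path):
    """Load the demo's training and testing data from .npz files.

    Assumption: key names follow the original UIS-RNN demo.
    """
    # allow_pickle is needed because the test set stores variable-length
    # sequences as an object array
    with np.load(train_path, allow_pickle=True) as train:
        train_sequence = train["train_sequence"]      # expected: (num_observations, embedding_dim)
        train_cluster_id = train["train_cluster_id"]  # one speaker label per observation
    with np.load(test_path, allow_pickle=True) as test:
        test_sequences = test["test_sequences"].tolist()      # list of 2-D arrays
        test_cluster_ids = test["test_cluster_ids"].tolist()  # list of label arrays
    return train_sequence, train_cluster_id, test_sequences, test_cluster_ids
```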

Note: the files under data/ are manually generated toy data for demonstration purposes only. Because the data are very simple, the model is expected to reach 100% accuracy on the testing data.
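Since the speaker IDs produced at inference time are arbitrary, accuracy has to be computed under the best one-to-one relabeling of predicted speakers onto true speakers. The sketch below does this by brute force; the function name is hypothetical, and the repository's own accuracy computation may differ (for instance, by solving the matching with the Hungarian algorithm instead of enumerating permutations).

```python
from collections import Counter
from itertools import permutations


def sequence_match_accuracy(predicted, truth):
    """Fraction of positions that agree under the best one-to-one relabeling.

    Brute force over label matchings, so only suitable for a handful of speakers.
    """
    if len(predicted) != len(truth):
        raise ValueError("sequences must have the same length")
    if not truth:
        raise ValueError("sequences must not be empty")
    # co-occurrence counts of (predicted label, true label) pairs
    pairs = Counter(zip(predicted, truth))
    pred_labels = list(dict.fromkeys(predicted))
    true_labels = list(dict.fromkeys(truth))
    # pad with None when there are more predicted than true speakers,
    # so surplus predicted labels can stay unmatched
    true_labels += [None] * max(0, len(pred_labels) - len(true_labels))
    best = 0
    for assignment in permutations(true_labels, len(pred_labels)):
        matched = sum(pairs[(p, t)] for p, t in zip(pred_labels, assignment) if t is not None)
        best = max(best, matched)
    return best / len(truth)
```

For example, sequence_match_accuracy([1, 1, 2, 2, 3], ["a", "a", "b", "b", "c"]) returns 1.0, since the two labelings are identical up to renaming.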

Arguments

  • --loss_samples: the number of samples for the Sample Mean Loss. If loss_samples <= 0, it is ignored and the loss is computed as in the original UIS-RNN.
  • --fc_depth: the number of fully connected layers after the GRU.
  • --non_lin: whether to use a non-linearity (ReLU) in the fully connected layers.
  • NUM_WORKERS: the number of worker processes used for multiprocessing. It is defined in demo.py.

All other arguments are the same as in the original repository.
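To make --fc_depth and --non_lin more concrete, here is a minimal NumPy sketch of a stack of fully connected layers applied to GRU outputs. The layer width (equal to the GRU output size), the initialization, and the placement of the ReLU (between layers, not after the last one) are assumptions for illustration, not read from this repository's model code; the function names are hypothetical.

```python
import numpy as np


def build_fc_head(dim, fc_depth, seed=0):
    """Hypothetical head: fc_depth square (dim x dim) linear layers with zero biases."""
    rng = np.random.default_rng(seed)
    return [(rng.standard_normal((dim, dim)) / np.sqrt(dim), np.zeros(dim))
            for _ in range(fc_depth)]


def fc_head_forward(gru_output, layers, non_lin):
    """Apply the fully connected stack to GRU outputs of shape (..., dim)."""
    h = gru_output
    for i, (weight, bias) in enumerate(layers):
        h = h @ weight + bias
        # Assumption: ReLU only between layers, so the final projection stays unconstrained
        if non_lin and i < len(layers) - 1:
            h = np.maximum(h, 0.0)
    return h
```

With fc_depth=0 the head is the identity, which corresponds to having no extra layers after the GRU.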

Citations

To cite our paper, please use:

@article{fini2019supervised,
  title={Supervised online diarization with sample mean loss for multi-domain data},
  author={Fini, Enrico and Brutti, Alessio},
  journal={arXiv preprint arXiv:1911.01266},
  year={2019}
}
