Next-gen Kaldi: recent work with RNN-T

I'm going to briefly go through some work on accelerating RNN-T.

Pruned training for RNN-T (1)

There's a challenge with RNN-T if you want to build a large system, especially for a language like Chinese. Suppose you want to build a Chinese system with characters as the units: standard RNN-T training is actually too slow. The problem is that you have a 2-D recursion over the acoustic frames and the positions in the transcript, and every node in this recursion involves the whole vocabulary, because the joiner output at each (frame, position) pair comes from a big matrix multiplication and is normalized over all the symbols. So if the vocabulary is around 6000 characters, it's too slow; you can't really do it.
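To make that cost concrete, here is a minimal pure-Python sketch of the standard (unpruned) RNN-T forward recursion. It is an illustration, not k2's implementation; the input layout `joiner_logits[t][u]` and the function names are assumptions. Every node of the T × (U+1) lattice is normalized over all V vocabulary entries, so time and memory scale with T·U·V.

```python
import math

NEG_INF = float("-inf")


def log_softmax(logits):
    """Normalize one joiner output over the whole vocabulary: O(V) work."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def logaddexp(a, b):
    """log(exp(a) + exp(b)), safe when either argument is -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss(joiner_logits, labels, blank=0):
    """Full RNN-T loss (negative log-likelihood).

    joiner_logits[t][u] is the length-V joiner output for frame t after u labels
    (hypothetical layout), so the input alone holds T * (U + 1) * V numbers.
    """
    T, U = len(joiner_logits), len(labels)
    # Every (t, u) node of the 2-D lattice is normalized over the full vocabulary.
    lp = [[log_softmax(joiner_logits[t][u]) for u in range(U + 1)] for t in range(T)]
    alpha = [[NEG_INF] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by a blank from the previous frame...
            from_blank = alpha[t - 1][u] + lp[t - 1][u][blank] if t > 0 else NEG_INF
            # ...or by emitting label u on the same frame.
            from_label = alpha[t][u - 1] + lp[t][u - 1][labels[u - 1]] if u > 0 else NEG_INF
            alpha[t][u] = logaddexp(from_blank, from_label)
    # A final blank from the last node terminates the sequence.
    return -(alpha[T - 1][U] + lp[T - 1][U][blank])
```

With character units the `log_softmax` over V entries sits inside the double loop, so a vocabulary in the thousands multiplies the work of every lattice node by thousands; that inner cost is what the pruning idea targets.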

Pruned training for RNN-T (2)

Our basic idea is that we don't really need to do the computation over the entire matrix of transcript length versus acoustic sequence length, because most positions in that matrix contribute almost nothing to the loss function. If you look at panel (b) of the figure, the contribution to the derivative, or to the loss, is only non-negligible along a diagonal-like band in the time-by-sequence-position plane. If you just focus on that band, you can make it much faster. The way we do this is to first use a simpler version of the RNN-T joiner, where the so-called joiner just adds two vectors and renormalizes. With that simple version we discover which positions are important, and then we run the full version only on the positions we need.
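The two steps above (a cheap "simple joiner", then a per-frame window of label positions) can be sketched in plain Python. This is a simplified illustration under stated assumptions, not k2's code: `occupancy[t][u]` is assumed to already hold how much node (t, u) contributes to the simple-joiner loss (for example from a forward-backward pass), `s_range` is the window width, and both function names are hypothetical. k2's own routines (such as `k2.get_rnnt_prune_ranges`) differ in detail.

```python
import math


def simple_joiner_logprobs(am_t, lm_u):
    """Hypothetical 'simple joiner': add the encoder vector for frame t and the
    prediction-network vector for position u (both length V), then renormalize."""
    logits = [a + b for a, b in zip(am_t, lm_u)]
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def prune_ranges(occupancy, s_range):
    """Choose, for each frame t, the first label position of an s_range-wide window.

    occupancy[t][u] is assumed to measure how much node (t, u) contributes to the
    simple-joiner loss. Window starts are kept non-decreasing so the windows follow
    the diagonal band instead of jumping backwards.
    """
    starts = []
    for row in occupancy:
        width = min(s_range, len(row))
        # Window start with the largest total occupancy on this frame.
        best = max(range(len(row) - width + 1), key=lambda s: sum(row[s:s + width]))
        if starts:
            best = max(best, starts[-1])
        starts.append(best)
    return starts
```

The full joiner would then be evaluated only at nodes with `starts[t] <= u < starts[t] + s_range`, i.e. T·s_range nodes instead of T·(U+1). A real implementation also has to make sure consecutive windows still connect into a valid path; this sketch only enforces that the starts never decrease.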

Pruned training for RNN-T (3)

These are some tables from our paper showing that our method is much faster and uses a lot less memory than standard RNN-T. We also get a better word error rate. The reason we can do better is the simpler version of the RNN-T joiner that I mentioned: we also use it for regularization, as an extra term in the loss function, so the word error rates get a little bit better.
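Written as a formula, the regularization described above amounts to adding the simple-joiner loss as a second term next to the pruned loss. The weight λ is an illustrative assumption; the talk does not give its value.

```latex
\mathcal{L} = \mathcal{L}_{\text{pruned}} + \lambda \, \mathcal{L}_{\text{simple}}
```

Since the simple-joiner loss has to be computed anyway to find the pruning bounds, using it as a training term adds little extra cost.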

Fast and parallel decoding for RNN-T (1)

The other thing we did is speed up decoding for RNN-T. In traditional RNN-T decoding, on each frame of the acoustic input you can output an unlimited number of symbols. Sometimes RNN-T systems get into a loop where they emit the same symbol over and over, and you have to find a way to stop that from happening. But in practice, in a well-trained RNN-T system, we find that it almost never emits more than one symbol per frame. This does depend on the frame rate. We use 25 frames per second, and at that rate on LibriSpeech it basically never emits more than one symbol per frame at test time. We actually tried a modified version of RNN-T where, even at training time, you can emit only one symbol per frame, and for some reason that did not work: even though at inference you never have more than one symbol per frame, you apparently need to allow it in training. Anyway, the way we speed up RNN-T decoding is to limit it to one symbol per frame. Then every utterance in a batch advances exactly one frame per step, so we can use our k2 ragged-tensor utilities to do the whole thing fast, in parallel, on the GPU.
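A minimal sketch of time-synchronous greedy search limited to one symbol per frame, written in plain Python rather than with k2 ragged tensors. `decoder` and `joiner` are hypothetical batched callables standing in for the prediction network and the joiner, and the `encoder_out[b][t]` layout is an assumption. Because each active utterance takes exactly one step per frame, the joiner is called once per frame for the whole batch; utterances that run out of frames simply drop out, which is the variable-length bookkeeping that ragged tensors handle on the GPU.

```python
def greedy_search_batch(encoder_out, lengths, decoder, joiner, blank=0):
    """Greedy RNN-T search with at most one symbol per frame, batched over utterances.

    encoder_out[b][t]: encoder output for utterance b, frame t (hypothetical layout).
    decoder(tokens) -> one prediction-network output per token (hypothetical).
    joiner(encs, decs) -> one list of V logits per (enc, dec) pair (hypothetical).
    """
    batch = len(encoder_out)
    hyps = [[] for _ in range(batch)]
    dec_out = decoder([blank] * batch)  # every hypothesis starts from a blank context
    for t in range(max(lengths)):
        # Utterances that still have frames left; the batch shrinks as they finish.
        active = [b for b in range(batch) if t < lengths[b]]
        # One batched joiner call per frame covers every active utterance.
        logits = joiner([encoder_out[b][t] for b in active], [dec_out[b] for b in active])
        emitted = []
        for b, row in zip(active, logits):
            token = max(range(len(row)), key=row.__getitem__)
            if token != blank:
                hyps[b].append(token)
                emitted.append((b, token))
        # Only utterances that emitted a symbol need a new decoder output.
        if emitted:
            new_out = decoder([tok for _, tok in emitted])
            for (b, _), out in zip(emitted, new_out):
                dec_out[b] = out
        # No inner loop: after at most one symbol, every utterance moves to frame t + 1.
    return hyps
```

The design choice is the missing inner "while not blank" loop: without it the number of steps is fixed at the number of frames, so all utterances stay in lockstep and no per-utterance loop limit is needed to stop runaway repeated symbols.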

Fast and parallel decoding for RNN-T (2)

This again is a graph from our paper showing that our method is very fast. There are a few comparisons. The "25% modified" system relates to what I said earlier: at training time, if you allow only one symbol per frame, it doesn't work so well. But if we use an interpolated loss function where 25% of it is that one-symbol-per-frame loss, that works OK. It acts as a kind of penalty that discourages the model from emitting more than one symbol per frame at inference, so we can be more confident that limiting decoding to one symbol per frame won't hurt the word error rate too badly.
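The difference between the two lattice topologies, and the 25% interpolation, can be illustrated with a small pure-Python forward pass. It assumes the interpolated loss is 75% regular RNN-T loss plus 25% of the one-symbol-per-frame ("modified") loss; the talk only gives the 25% figure, so that split, the input layout `lp[t][u]`, and the function names are assumptions. In the modified topology, emitting a label also consumes the frame, so a transcript with more labels than frames gets zero probability.

```python
import math

NEG_INF = float("-inf")


def logaddexp(a, b):
    """log(exp(a) + exp(b)), safe when either argument is -inf."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_log_likelihood(lp, labels, blank=0, modified=False):
    """log p(labels | audio) by a forward pass over the RNN-T lattice.

    lp[t][u] holds normalized log-probs (length V) at frame t after u labels.
    modified=False: a label stays on the same frame (regular RNN-T).
    modified=True: a label also consumes the frame (at most one symbol per frame).
    """
    T, U = len(lp), len(labels)
    alpha = [[NEG_INF] * (U + 1) for _ in range(T + 1)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            a = alpha[t][u]
            if a == NEG_INF:
                continue
            # Blank: move to the next frame without emitting anything.
            alpha[t + 1][u] = logaddexp(alpha[t + 1][u], a + lp[t][u][blank])
            if u < U:
                y = a + lp[t][u][labels[u]]
                if modified:
                    alpha[t + 1][u + 1] = logaddexp(alpha[t + 1][u + 1], y)
                else:
                    alpha[t][u + 1] = logaddexp(alpha[t][u + 1], y)
    # Both topologies end once all T frames are consumed and all U labels emitted.
    return alpha[T][U]


def interpolated_loss(lp, labels, modified_weight=0.25, blank=0):
    """Assumed form: (1 - w) * regular loss + w * modified loss, with w = 0.25."""
    regular_loss = -rnnt_log_likelihood(lp, labels, blank, modified=False)
    modified_loss = -rnnt_log_likelihood(lp, labels, blank, modified=True)
    return (1.0 - modified_weight) * regular_loss + modified_weight * modified_loss
```

The modified term only gives probability to alignments with at most one symbol per frame, so increasing its weight pushes the model toward the behaviour the one-symbol-per-frame decoder assumes, while the regular term keeps the training signal that the pure modified loss apparently lacks.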