Next-gen Kaldi: Reworked Conformer Model

BAAI Conference

This blog post was created from Daniel Povey’s speech at the BAAI Conference. The slides are taken from his presentation. https://www.youtube.com/watch?v=BXYZI9FEOgU&ab_channel=NadiraPovey

The conformer is just the latest version of the transformer that everyone is using in speech recognition: it's the transformer plus a convolution module. Our reworked conformer is the conformer with a lot of changes that I'm going to describe below. It gives us faster training, more stable training, and a slightly better Word Error Rate. The reason I think you'll be interested in this talk is that we have some insights about things that can go wrong with transformer training.

Conformer encoder

So here is a picture of the conformer encoder. This is the main part of the transformer:

image

It's a modified transformer. In each layer we have a feed forward module, then self-attention, then a convolution module, then another feed forward module. If you just had the first two - the bottom two - that would be a standard transformer. The red parts in this diagram are the LayerNorms. We use so-called pre-norm, meaning that at the start of each module there is a LayerNorm, and at the output of the whole layer there is also a LayerNorm.

BatchNorm: considered harmful

image

We were originally using BatchNorm at the input to the whole thing to normalize the speech features. I'm saying that BatchNorm should be considered harmful, and here's the reason. BatchNorm itself works great; there's no problem with it. The problem is that if you do anything unusual, BatchNorm is going to create a lot of problems. If you group the training utterances by length, that's a problem for BatchNorm, because BatchNorm assumes that the elements of the mini-batch are sampled independently.

So if you do something like grouping the utterances by length, or maybe you have two different data sets and you want to put them in different mini-batches for some reason, this is going to cause a lot of problems with BatchNorm. Basically it's going to affect your results because it creates a train/test mismatch: at training time your model sees utterances that are grouped in a non-random way, while at test time, because we use the averaged BatchNorm stats, it's as if the utterances were grouped randomly, and that will affect the results.

Also, if you do fine-tuning on a different type of data, it's going to be a problem. It's actually ironic, because the original BatchNorm paper presents it as a solution to covariate shift, and covariate shift just means that you have data with a different distribution. But this is exactly the thing that BatchNorm cannot deal with: if you have different data distributions in different parts of your training and you separate them, BatchNorm is really bad at dealing with that.

My position is that, if possible, it's going to save you a lot of headache if you just take BatchNorm out of the model.

Removing BatchNorm

image

One of the first changes we made to the model is to replace BatchNorm with LayerNorm in the convolution module, and also at the input to the network. Actually, at the input we didn't even use LayerNorm: we just replaced it with a fixed scale of 0.1. The reason that's necessary is that the log mel features have a very high dynamic range, like -20 to 0, and because that range is so large the bias parameters cannot really get large enough to be comparable to it. Even after you multiply this -20 to 0 range by a weight, it's still going to be large, and the bias cannot train to be big enough. So we just replaced the normalization with a scale.
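As a rough sketch (the shapes and values here are illustrative, not the exact Icefall code), the input "normalization" reduces to a single fixed multiplication:

```python
import torch

# Illustrative: instead of BatchNorm/LayerNorm on the log-mel input, multiply
# the features by a fixed 0.1 so their range (roughly -20..0) becomes
# comparable to the biases the model can actually learn.
log_mel = torch.randn(8, 300, 80) * 6.0 - 10.0   # fake (batch, time, mel) features
x = log_mel * 0.1                                # fixed scale replaces the normalization layer
```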

A lot of you are probably thinking: why are you using log mel features when everyone is using the waveform? A lot of people do use the waveform as input, and that's fine, but it doesn't really give you any advantage. A lot of people have published papers at conferences like NIPS saying: “oh, we have this thing and it's better than log mel cepstrum”, but usually there's some problem with how they did that comparison, and whenever anyone does a reasonable comparison it turns out it's not really better.

This may not always be true - maybe in the future someone will really show that it's better to use the waveform - but we use filterbanks because they're more compressible and you can load them much faster from disk; otherwise you become limited by IO.

Removing LayerNorm

image

The next thing is that we'd like to remove LayerNorm: if you look at this model, it seems like there's too much of it. Let me get into why we need LayerNorm at all, and then I'll be able to explain why I think the model has too much of it.

LayerNorm recap

image

Let's recap what LayerNorm is. LayerNorm is just a formula applied to a vector, which could be one frame of speech or one pixel of an image. It has three stages (see the sketch after this list):

  • We remove the mean - which is actually a waste of time in most cases.
  • We normalize the length of the vector; there's an epsilon in that formula.
  • We apply a learned per-dimension scale and offset, γ and β, like BatchNorm.
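Here's a minimal sketch of those three stages in PyTorch (this is just the textbook formula, not Icefall code):

```python
import torch

def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
               eps: float = 1e-5) -> torch.Tensor:
    # x: (..., num_channels); gamma, beta: (num_channels,)
    mean = x.mean(dim=-1, keepdim=True)
    x = x - mean                                      # stage 1: remove the mean
    rms = (x.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    x = x / rms                                       # stage 2: normalize the length (with epsilon)
    return x * gamma + beta                           # stage 3: learned per-dimension scale and offset
```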

Why LayerNorm needed

image

Why do we need LayerNorm? In a deep model, unless you normalize every so often, eventually the activations are going to blow up to infinity or shrink to zero, and then your training is not going to be stable; it's going to fail. In addition to the normalization, the learnable per-dimension scales (the gammas) help to balance the size of the activations and the relative contributions of the different modules. There are four modules per layer in the conformer, and we want to be able to scale each one up or down individually. It's very hard to do that by learning the size of the weight matrices themselves, because they're so large: to scale them down you're basically trying to learn a very high dimensional direction, and that's too hard to learn. It's much easier if you have a separate scale, like the per-dimension scale in BatchNorm or LayerNorm, so you can just scale down the output of each module to balance the contribution of those four things.

The formula on the slide shows how, just due to the gradient noise from each mini-batch, the size of a weight matrix will naturally change during training. If the learning rate is constant, it will follow roughly a square-root-of-t curve, and if you look at the weights in transformer models you'll see that they're quite close to this formula. The formula is only valid if the optimizer is Adam with no weight decay; with weight decay you have to compensate. The point is that it's better to just let the weight matrix naturally do its thing and get larger and larger, and have something else that scales it down, and that's where the BatchNorm and LayerNorm scales come in.
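The rough argument behind that square-root-of-t behavior (my reconstruction of the slide's formula, not a verbatim copy) is a random-walk one: Adam's normalized updates change each parameter by roughly plus or minus the learning rate per step, and uncorrelated steps add in quadrature:

```latex
% Each Adam step changes a parameter by roughly \pm\alpha (the learning rate),
% and steps driven by gradient noise are roughly uncorrelated, so
\mathbb{E}\!\left[\|W_t\|^2\right] \;\approx\; \|W_0\|^2 + c\,t\,\alpha^2
\quad\Longrightarrow\quad
\|W_t\| \;\approx\; \sqrt{\|W_0\|^2 + c\,t\,\alpha^2} \;\sim\; \sqrt{t}\ \text{ for large } t,
% where c depends on the number of elements of W (valid for Adam with no weight decay).
```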

LayerNorm failures 1: Module Death

image

Now, there are certain bad things that can happen during transformer training. In our Icefall recipes we implemented a lot of diagnostics. PyTorch has forward hooks and backward hooks that you can add, and it's possible to use those to accumulate a bunch of statistics that tell you, for instance, how big the outputs of the different modules are.
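As a rough sketch of that kind of diagnostic (not the actual Icefall code; the names here are made up), a forward hook can accumulate the RMS of each module's output:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: attach a forward hook to every module and accumulate the
# root-mean-square of its output, to spot modules whose output collapses toward
# zero or blows up during training.
output_rms = {}

def make_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            rms = output.detach().pow(2).mean().sqrt().item()
            prev, count = output_rms.get(name, (0.0, 0))
            output_rms[name] = ((prev * count + rms) / (count + 1), count + 1)  # running average
    return hook

def add_diagnostics(model: nn.Module):
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
```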

That's how we discovered certain problems that can happen in transformer training. You can get whole modules dying. What this means is that, for one of the feed forward modules say, after training you'll see that the output of the module is almost zero: it's like 10^(-6) or something like that. That seems to be connected with the norm's learnable weight: if early in training that module was not useful, the LayerNorm weights go almost to zero to turn it off, and the problem is that if those weights are oscillating around zero, becoming positive and negative, the module never learns anything useful, because the sign of the gradient is constantly flipping and it doesn't know which direction to learn. Basically, in that case the module never learns anything. This happens particularly for the feed forward modules lower down in the network, and it can happen for quite a few of them.

LayerNorm failures (?) 2: Large constant channel

image

Now there's another odd, or perhaps undesirable, thing that happens with LayerNorm, especially in the middle of the network: after training the model, we noticed that often one channel out of the, let's say, 512 or 256 gets very large and constant. You look at the activation and it's, say, 51 ± 1; it's almost constant. The individual modules don't really “see” this large activation, because it gets scaled down by their input LayerNorms. My belief is that the reason the system learns this large value is that it doesn't want to be LayerNormed: LayerNorm removes the length information from the vector, because it normalizes the length, and the model doesn't want that. So it adds this large constant value, which functions like a very large epsilon in the LayerNorm formula, and that means the scaling applied to the rest of the vector elements, excluding that one large value, is almost constant at the output of the LayerNorm. We're going to use this insight to help replace LayerNorm and fix this problem. A large value like this is also going to be a problem if you want to use integer quantization (8-bit activations).

Removing LayerNorm

image

Here's how we remove LayerNorm. There are a few things we had to change. Firstly, we remove the LayerNorm at the inputs of the modules and at the output of the whole layer, and we replace it with something called “BasicNorm”. BasicNorm just normalizes the length of the vector, but with a large epsilon. We initialize this epsilon to one and learn it in log space; in the end it's going to be something like two or four or five. Because the sum inside the square root is dominated by the epsilon, the output vector length still depends on the input vector length: if the input vector is longer, the output is going to be longer. Now, because we need to balance the contributions of the different modules and we no longer have the learnable LayerNorm weight, we modify all our weights and biases: instead of self.weight we use self.weight * self.weight_scale.exp(), i.e. we multiply by a learned scale. That scale can only be positive, and it's important that it should only be positive, because otherwise we can get the oscillating gradient problem I mentioned before: if the scale goes negative, positive, negative, positive, we never learn anything because the gradient keeps flipping. That's how we removed LayerNorm.
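Here's a minimal sketch of the idea (simplified from what's in the Icefall recipe; treat the exact details as illustrative):

```python
import torch
import torch.nn as nn

class BasicNorm(nn.Module):
    """Length normalization with a large, learnable epsilon (learned in log space).
    No mean subtraction, no learnable scale/offset."""
    def __init__(self, eps: float = 1.0):
        super().__init__()
        self.eps = nn.Parameter(torch.tensor(eps).log())  # log-space parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., num_channels)
        scale = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
        return x * scale

class ScaledLinear(nn.Linear):
    """Linear layer whose effective weight is weight * exp(weight_scale), so each
    module's output can be scaled up or down by a single positive factor."""
    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.weight_scale = nn.Parameter(torch.zeros(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight * self.weight_scale.exp(), self.bias)
```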

Another failure mode: “dead” neurons

image

We made some other changes to the model too. There's another failure mode, another bad thing that can happen during training, inside the feed forward modules (there are two feed forward modules per layer): a lot of neurons can die. What that means is that there's a hidden dimension inside the feed forward module - let's say it's 2048 - and we might find that 10 or even 50 or even 80 of those neurons always have negative input. The activation is very similar to ReLU (it's actually something called Swish), and if the inputs are always negative the output is always zero, or at least always close to zero, so the neuron isn't doing anything useful. We have a way to fix this by modifying the gradients, which I'll explain later.

Another failure mode: too-large or too-small activations

image

There is another failure mode that's less common; again we discovered this by looking at the diagnostics I mentioned. The actual activation inside the feed forward modules, and also the convolution module, is something called Swish; in PyTorch it's called SiLU. It's just x times sigmoid of x. If the model is well trained, you'll see that the typical magnitude of the input to this SiLU - the root mean square value over a dimension - is between about one and three. There are two ways this can go wrong and the model fails to train properly (see the quick check after this list):

  1. If the inputs become too small, like 10^-4 or 10^-5, the function is almost linear: if you scale it up, it's just y = x, and in that case it's not learning anything useful because there's no nonlinearity.
  2. The other problem is that the inputs can become very large, like 100. In that case it's just the same as ReLU, because if you zoom out it looks like ReLU, and that's not optimal because ReLU is not quite as good as Swish. Again, this is something we're going to fix by detecting it and modifying the gradients.
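A quick numerical check of those two regimes (just to illustrate the point; the example inputs are mine):

```python
import torch

x_small = torch.tensor([-1e-4, -1e-5, 1e-5, 1e-4])
x_large = torch.tensor([-100.0, -10.0, 10.0, 100.0])

# For tiny inputs, silu(x) = x * sigmoid(x) is essentially 0.5 * x: a pure linear function.
print(torch.nn.functional.silu(x_small), 0.5 * x_small)
# For huge inputs, silu(x) is essentially relu(x): ~0 for x << 0 and ~x for x >> 0.
print(torch.nn.functional.silu(x_large), torch.relu(x_large))
```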

Our solution to both of these problems that I mentioned just now is something called the activation balancer.

Our solution: ActivationBalancer

image

This is a module that does nothing in the forward pass: in the forward pass it's just y = x, but in the backward pass it modifies the gradients a little bit. We place this module just before the nonlinearity - just before the Swish. In the forward pass it accumulates some statistics and works out which channel dimensions are problematic. Problematic means, for instance, that fewer than 5% of the values are greater than zero. This five percent number actually comes from Kaldi; we used to do the same thing in Kaldi. For those dimensions, we want to penalize the values when they're negative.

What we discovered in early experiments is that if you just add a penalty to the loss function, what the model does is concentrate all the positive values in the padding frames. Because the utterances have different lengths, we have to pad with zeros, and the model learns to just make those padding frames very positive, which is not what we want. Instead, we fix this by modifying the back-propagated derivatives: we multiply them by (1 + ε) or (1 - ε), and which one we choose depends on the sign of the input and also the sign of the derivative. You can work out what the rule is; it's quite simple. We also use a similar approach for the wrong-magnitude problem: if the mean absolute value is less than, say, 0.2 or more than 10, we can use the same kind of formula with epsilon to encourage the inputs to have the magnitude we want.
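Here's a simplified sketch of that gradient trick for the "too few positive values" case (the real Icefall ActivationBalancer handles more cases and is more careful; the names and thresholds here are illustrative):

```python
import torch

class ActivationBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, for 'bad' channels
    (here: channels where fewer than 5% of values are positive), bias the
    gradients so negative inputs are nudged toward positive values."""

    @staticmethod
    def forward(ctx, x, min_positive=0.05, eps=0.01):
        # proportion of positive values per channel (channels on the last dim)
        dims = tuple(range(x.dim() - 1))
        proportion_positive = (x > 0).float().mean(dim=dims, keepdim=True)
        bad_channel = proportion_positive < min_positive
        ctx.save_for_backward(x < 0, bad_channel)
        ctx.eps = eps
        return x

    @staticmethod
    def backward(ctx, grad_output):
        x_negative, bad_channel = ctx.saved_tensors
        # Multiply by (1 - eps) when the gradient would push a negative input
        # further down (grad > 0), and by (1 + eps) when it would push it up
        # (grad < 0): on average this biases the input toward positive values.
        factor = 1.0 - ctx.eps * grad_output.sign()
        where_apply = x_negative & bad_channel
        grad = torch.where(where_apply, grad_output * factor, grad_output)
        return grad, None, None

# usage: y = ActivationBalancerFunction.apply(x)   # placed just before the Swish
```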

New nonlinearity: DoubleSwish

image

There's also something quite cute that we discovered accidentally. At some point I had a bug in a script that applied the Swish nonlinearity twice, and I noticed that the experiment was a little bit better; after I fixed the bug it got worse. So I wondered why that was, and it turns out it's actually a better nonlinearity. I'm calling it DoubleSwish because it's just swish(swish(x)). You can approximate that function by x times sigmoid of (x - 1); it's an almost identical function. If you look at the figure, those two lines are very close together: one is one function and one is the other. So we're actually using x * sigmoid(x - 1). It gives a very small improvement, like 0.1% Word Error Rate, but we've validated it in a few different setups and it seems to be consistent and real.
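A minimal sketch of the two forms (the exact in-recipe implementation may differ, e.g. for memory efficiency):

```python
import torch

def swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)            # also known as SiLU

def double_swish_exact(x: torch.Tensor) -> torch.Tensor:
    return swish(swish(x))                 # the "bug" that turned out to work better

def double_swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x - 1.0)      # cheap approximation, almost identical

x = torch.linspace(-6, 6, 13)
print((double_swish(x) - double_swish_exact(x)).abs().max())  # the difference is small
```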

New nonlinearity: DoubleSwish [larger version]

image

This is just a larger version of the figure showing the nonlinearity. It's not just a scaled version of Swish; it's qualitatively different, because the derivative at zero is one quarter instead of one half, and eventually it becomes just linear, like y = x, but the behavior around zero is different. There's another problem that we fixed; you can argue that maybe it's not really a problem, but I didn't like it.
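The derivative claim is easy to check (my arithmetic, not from the slide): writing s(x) = x·sigmoid(x), we have s(0) = 0 and s'(0) = 1/2, so by the chain rule

```latex
\frac{d}{dx}\, s\big(s(x)\big)\Big|_{x=0} \;=\; s'\big(s(0)\big)\, s'(0) \;=\; s'(0)^2 \;=\; \tfrac{1}{4},
\qquad \text{versus } s'(0) = \tfrac{1}{2} \text{ for plain Swish.}
```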

Fixed parameter norms

image

As I mentioned before, the parameter tensors will naturally grow throughout training: each iteration gives you gradient noise, and that noise adds up over the whole of training. One way to handle this is just to use the learnable scales in BatchNorm/LayerNorm to scale things down, and that kind of works, but eventually it can become unstable: once those BatchNorm/LayerNorm scales get very close to zero, the output becomes too sensitive to them, because the underlying tensors are too large. So eventually, after many iterations, your transformer or conformer can suddenly become unstable and go to infinity. Another reason I don't like these growing parameter tensors is that they make the learning rate schedule very hard to interpret: even while the learning rate is growing, you can actually be learning slower and slower in relative terms, because the parameter matrices are growing even faster than the learning rate.

Fixed parameter norms

image

Our solution to these growing parameter norms is to limit the root mean square value of all non-scalar tensors to 0.1: once it gets larger than that threshold, we start to apply weight decay. This is just a change to Adam. If you take AdamW, which is just Adam with shrinkage in place of the usual weight decay - instead of modifying the gradient, you shrink the whole parameter matrix - and change it so that the shrinkage is only applied once the root mean square value exceeds 0.1, that's what we call Eve. This makes the learning rate schedule much easier to interpret, but we also have to use a different schedule, because the conventional one doesn't work anymore: we have to decay it more aggressively for large t.
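A minimal sketch of that idea on top of AdamW-style shrinkage (this is just the conditional-shrinkage concept, not the actual Eve optimizer code; the names and defaults are illustrative):

```python
import torch

@torch.no_grad()
def eve_style_shrinkage(param: torch.Tensor, lr: float,
                        weight_decay: float = 1e-3, target_rms: float = 0.1):
    """Apply AdamW-style shrinkage to a non-scalar parameter, but only once its
    root-mean-square value has grown beyond target_rms."""
    if param.dim() > 0:                          # non-scalar tensors only
        rms = param.pow(2).mean().sqrt()
        if rms > target_rms:
            param.mul_(1.0 - lr * weight_decay)  # shrink the whole tensor
```

In the real optimizer this happens inside the Adam step, alongside the usual first- and second-moment updates.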

Learning rate schedules 1

image

Here is the standard learning rate schedule for transformers. It's the so-called Noam schedule, or sometimes people just call it the transformer learning rate schedule. It increases linearly at first and then decreases as 1/sqrt(t). What you'll find is that if you try to train a transformer using a conventional learning rate schedule, it just won't train; it will never learn anything. Certainly that's true for speech, in our experience.
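For reference, the usual Noam formula (from "Attention Is All You Need"; the factor and warm-up constant are whatever you tune):

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 25000, factor: float = 1.0) -> float:
    """Standard transformer ('Noam') schedule: linear warm-up, then 1/sqrt(step) decay."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```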

Learning rate schedules 2

image

As part of this rework of the conformer, we remove the warm-up period from the learning rate schedule and instead implement warm-up at the model level. The curved arrow at the side of the picture indicates a layer-level bypass. We have a constant called warmup that we feed into the model; it begins at zero, increases to one, and then stays at one. So for the first 3000 iterations this warmup value is increasing, and for each layer we bypass the layer in proportion to it. Look at the formula: we return warmup * output + (1 - warmup) * input. Effectively, if warmup is close to zero, the model acts like a very shallow model, which is much easier to train. After we made this change to do warm-up at the model level and used a different learning rate schedule, we found it optimizes way, way faster. The learning rate schedule we arrived at after tuning is on the next slide.
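A sketch of that layer-level bypass (schematic; in the real encoder the warm-up value is threaded through the layer's forward function):

```python
import torch
import torch.nn as nn

class BypassedLayer(nn.Module):
    """Wrap a conformer layer so that early in training (warmup near 0) the layer
    is mostly bypassed, making the model behave like a much shallower one."""
    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, x: torch.Tensor, warmup: float) -> torch.Tensor:
        out = self.layer(x)
        # warmup goes 0 -> 1 over roughly the first 3000 iterations, then stays at 1.
        return warmup * out + (1.0 - warmup) * x
```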

Learning rate schedules 3

image

Here is what we came up with. I'm putting in actual constants - these are obviously tunable - but if n is the epoch and t is the mini-batch index from the very start of training, including all epochs, this is the formula. It starts off approximately constant, then it decreases like 1/sqrt(t), and eventually it decreases like 1/t. The reason we make it depend on both the mini-batch index and the epoch index is that it gives us a better kind of invariance when we change the mini-batch size or the number of workers. Basically, we want about the same rate of progress per epoch regardless of how many workers we have; if we have more workers, we want to learn a little bit faster per mini-batch. This formula gives us the right kind of invariance. The idea is that we tune it once, and then when we change the number of workers or the mini-batch size we don't have to retune it.
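I don't have the exact constants from the slide here, but a schedule with the shape described above (roughly constant, then ~1/sqrt(t), then ~1/t once the epoch term kicks in) looks something like this; treat the constants as placeholders:

```python
def reworked_lr(base_lr: float, t: int, n: int,
                warm_batches: float = 5000.0, warm_epochs: float = 4.0) -> float:
    """Illustrative schedule: t = global mini-batch index, n = epoch index.
    Roughly constant for small t, then ~1/sqrt(t), with an extra epoch-dependent
    factor so the overall decay late in training is closer to ~1/t."""
    batch_factor = ((t ** 2 + warm_batches ** 2) / warm_batches ** 2) ** -0.25
    epoch_factor = ((n ** 2 + warm_epochs ** 2) / warm_epochs ** 2) ** -0.25
    return base_lr * batch_factor * epoch_factor
```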

Learning rate schedules 4

image

This figure compares our formula (the red curve) with the standard transformer learning rate schedule (the orange one). You cannot really compare these two directly, because they're intended for different types of model: ours is for the model with fixed parameter norms. In reality the transformer schedule is decaying faster than it appears, because the matrices are getting larger, so the relative change per step is smaller.

Learning rate schedules 5

image

We did notice that after you fix the parameter norms, the learning speed becomes much more sensitive to the exact learning rate schedule. The reason is that if you don't fix the parameter norms and you use a larger learning rate, your parameter matrices simply get larger, so the relative change per step is smaller. In fact, if you multiply the whole learning rate schedule by two, you'll train faster initially, but for large t you're training at about the same speed, because eventually the parameter matrices will just be about twice as large. It turns out you can do some analysis on this using the formula I mentioned previously: even if you have a t^p type schedule, changing the power only changes the relative learning rate by a constant factor for large t. This is true only for p > -1/2 (see the derivation below).
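Here's a sketch of that analysis, using the same random-walk approximation for the parameter norm as before (my reconstruction, not the slide's exact derivation):

```latex
% Suppose the learning rate is a power law, \alpha_t = a\,t^{p}.  With Adam and no
% weight decay, the squared parameter norm grows roughly like the sum of squared steps:
\|W_t\|^2 \;\approx\; c \sum_{s\le t} \alpha_s^2 \;\propto\; t^{\,2p+1}
\quad (\text{valid for } p > -\tfrac12),
% so the relative learning rate is
\frac{\alpha_t}{\|W_t\|} \;\propto\; \frac{t^{p}}{t^{\,p+1/2}} \;=\; t^{-1/2},
% i.e. changing p (or the overall scale a) only changes this by a constant factor for large t.
```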

You can work this out from the formula I mentioned a few slides earlier - the one with the square root in it. The point is that even though the transformer learning rate schedule looks very weird, it's actually not as weird as it looks, because even while the learning rate is increasing, you're actually learning slower and slower in relative terms, at least once the parameters grow beyond their initial size.

Convergence

image

This graph shows the speed of convergence of our new model (in blue) compared to the old model (in orange). You can see that for the first part of training we're training way faster - ten times faster or more - because we don't have the learning-rate warm-up period and we use the model-level warm-up instead. Even later on we're still faster, maybe twice as fast, or at least 40% faster or so.

WER/accuracy changes

image

We also get small Word Error Rate improvements. Unfortunately, all the things I mentioned above were tuned on a smaller subset, like LibriSpeech 100 hours, and it turns out that not all of those changes are fully additive; but when we test the whole thing on full LibriSpeech we still get an improvement. The main improvement from all of these things is that the recipe is more stable and way faster: first, we don't need so many epochs (we're using like 30 epochs instead of 16), and per epoch it's a little bit faster because we took out all of those LayerNorms; plus it's open, plus it's more stable. We're also working on a further improved version of this optimizer where we use the scatter of the gradients to do a change of coordinates before we apply Adam, because Adam works per element, so if you rotate the parameter space it actually changes how it behaves. I'm not going to describe that right now; it's not really finished.

Summary of changes

image

This is just a summary of what I've said before, to help you remember it. First, I made the argument that BatchNorm should be considered harmful, because it causes hassles for things like fine-tuning and the data loaders.

Second, we replaced LayerNorm with this BasicNorm, which is just length normalization with a learnable epsilon. We took out the learnable scales of LayerNorm from most of the modules, and instead we make everything scalable by having this weight-scale thing, which has to be positive to prevent the oscillating-sign problem I mentioned. We fixed a bunch of different problems with the activation balancer.

With it, we detect dead neurons and values that are too small or too large. To make the learning rate schedule more interpretable and to make everything more predictable, we limit the root mean square value of all the non-scalar parameter matrices to 0.1.

So we fix the size of all the matrices. We do warm-up in a different way, at the model level: we have this variable (not learnable) bypass of all the layers to help it converge early in training. We have a new learning rate schedule with no warm-up and a different formula, and we found that replacing Swish with DoubleSwish is helpful.

The link there is just to the recipe in our repository that has all of these changes. We've actually made some further changes since then, so we're up to something like pruned_transducer_stateless 5 or 6 now, because we've improved the recipe in other ways, but this is the first place where we used all of this.

image

That link is to a WeChat group for discussing Next-gen Kaldi, if you're interested; you may be able to get one of these t-shirts. Those are links to our main open source repositories - we have others too - but Icefall is probably the top-level one, with the recipes, and it will download the other ones.

References:

Icefall:

Dan #30 Daniel Povey BAAI 2022 Full Version: https://youtu.be/Q3gNj7XlArs
PowerPoint slides: https://shorturl.at/KMVY4
Try the latest k2 model here: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition