Source: Deep Learning on Medium
Speech Recognition — Maximum Mutual Information Estimation (MMIE)
Many automatic speech recognition (ASR) systems are trained with maximum likelihood estimation (MLE), one of the most popular methods in machine learning (ML) and deep learning (DL). In this article, we present an alternative training criterion for ASR, maximum mutual information estimation (MMIE), and look at some of its benefits and implementations.
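So far, the ASR training discussed in this series has been based on maximum likelihood estimation with the standard HMM topology: the hidden states form a Markov chain, and each observation is generated by the current hidden state.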
This topology makes a strong assumption: the observation at time t depends only on the corresponding hidden state, i.e. P(xₜ | state sequence) = P(xₜ | sₜ). This is not exactly true. In reality, observations are also related to past and future states. Even if we can find the global optimum for MLE, this assumption may lead us to a sub-optimal model. Indeed, many computed probabilities are likely overestimated: neighboring frames are strongly correlated, yet their likelihoods are multiplied as if each frame were independent evidence.
To counteract this problem, we previously introduced context-dependent phones (triphones) and MFCC delta features to model the speech context. These move the HMM closer to reality, but they are only workarounds.
Maximum mutual information estimation (MMIE)
MLE finds the model parameters λ that maximize the likelihood of the training observations given their reference transcripts.
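Writing X_r and W_r for the observations and the reference transcript of the r-th of R training utterances (notation assumed here), MLE can be written as:

```latex
\hat{\lambda}_{\text{MLE}} = \arg\max_{\lambda} \sum_{r=1}^{R} \log p_{\lambda}(X_r \mid W_r)
```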
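In ASR, MMIE maximizes the mutual information between the observations X and the word sequence W under the model λ. The acoustic likelihood is raised to the power κ (about 1/12), a scaling fudge factor that corrects the overestimation.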
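One common way to write this, assuming X_r and W_r denote the observations and reference transcript of utterance r and the inner sum runs over all word sequences W: the first line is the mutual-information ratio for one utterance (expanded with p(X_r) = Σ_W p(X_r | W) P(W)), and the second line is the training objective with the acoustic likelihoods scaled by κ.

```latex
\begin{aligned}
\log \frac{p_{\lambda}(X_r, W_r)}{p_{\lambda}(X_r)\, P(W_r)}
  &= \log \frac{p_{\lambda}(X_r \mid W_r)}{\sum_{W} p_{\lambda}(X_r \mid W)\, P(W)} \\
\mathcal{F}_{\text{MMIE}}(\lambda)
  &= \sum_{r=1}^{R} \log \frac{p_{\lambda}(X_r \mid W_r)^{\kappa}}{\sum_{W} p_{\lambda}(X_r \mid W)^{\kappa}\, P(W)}
\end{aligned}
```

If p_λ(X_r | W) is the same for every W, the numerator and denominator cancel and each term is log 1 = 0.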
If X (observations) and W (a word sequence) are completely independent according to the model λ, the mutual information equals 0: the observations tell us nothing about the word sequence. For the reference transcript we want the opposite, so we look for the model λ that maximizes the MMIE objective. Because P(W) does not depend on λ, this objective can be simplified further.
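With X_r and W_r the observations and reference transcript of utterance r, and κ applied to the acoustic likelihood (assumed notation), adding the λ-independent term log P(W_r) to each utterance's objective gives the κ-scaled log posterior of the reference transcript:

```latex
\hat{\lambda}_{\text{MMIE}}
  = \arg\max_{\lambda} \sum_{r=1}^{R} \log \frac{p_{\lambda}(X_r \mid W_r)^{\kappa}\, P(W_r)}{\sum_{W} p_{\lambda}(X_r \mid W)^{\kappa}\, P(W)}
  = \arg\max_{\lambda} \sum_{r=1}^{R} \log P_{\lambda}(W_r \mid X_r)
```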
In other words, MMIE maximizes the posterior probability of the reference transcript given the observations, which is why it is also called conditional maximum likelihood estimation. Intuitively, we maximize the numerator and minimize the denominator: we increase the probability that the model predicts the reference label (the numerator) and decrease the probability of all other word sequences (the denominator).
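A minimal sketch of the per-utterance objective, assuming we already have a short list of competing transcripts, each with an acoustic log-likelihood log p(X | W) and a language-model log-probability log P(W). The function name, the dictionary layout, and all scores are hypothetical; κ scales only the acoustic term, matching the formulation above.

```python
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mmie_objective(hyps, ref, kappa=1.0 / 12):
    """kappa-scaled log posterior of the reference transcript for one utterance.

    hyps maps each candidate word sequence W to a pair
    (log p(X | W), log P(W)); ref must be one of the keys.
    """
    def joint(w):
        acoustic, lm = hyps[w]
        return kappa * acoustic + lm          # log [p(X|W)^kappa * P(W)]

    numerator = joint(ref)
    denominator = log_sum_exp([joint(w) for w in hyps])
    return numerator - denominator           # a log probability, so <= 0

# Hypothetical scores for three competing transcripts of one utterance.
hyps = {
    ("recognize", "speech"): (-1200.0, -6.0),
    ("wreck", "a", "nice", "beach"): (-1210.0, -9.0),
    ("recognise", "peach"): (-1230.0, -8.5),
}
ref = ("recognize", "speech")
print(mmie_objective(hyps, ref))
```

Raising the reference's acoustic score increases the objective, while strengthening a competitor lowers it: the numerator is pulled up and the denominator pushed down.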
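Compared with MMIE, MLE maximizes only the numerator term, the likelihood of the reference transcript, and ignores the competing word sequences in the denominator.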
In theory, both criteria can lead to the same optimal solution, for example when the model assumptions hold and training data are unlimited. Because the model is only an approximation, however, their solutions may differ.
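Expanding the κ-scaled log posterior (X_r and W_r: observations and reference transcript of utterance r; κ: acoustic scale) shows the relationship explicitly. The first term is κ times the MLE objective, the second does not depend on λ, and the third is the denominator that MMIE additionally pushes down.

```latex
\sum_{r=1}^{R} \log P_{\lambda}(W_r \mid X_r)
  = \kappa \sum_{r=1}^{R} \log p_{\lambda}(X_r \mid W_r)
  + \sum_{r=1}^{R} \log P(W_r)
  - \sum_{r=1}^{R} \log \sum_{W} p_{\lambda}(X_r \mid W)^{\kappa}\, P(W)
```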
MMIE reduces the likelihood of word sequences other than the reference transcript. This is called discriminative training: it boosts the right answer and suppresses the wrong ones. MMIE is also sequence training, because the objective is defined over the likelihood of whole sequences rather than individual frames.
For classification problems, MMIE can create cleaner decision boundaries than MLE training. For example, when three classes of data are classified with Gaussian models using diagonal covariance (all off-diagonal covariance elements are zero), MMIE training yields better decision boundaries than MLE training.
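A minimal sketch of the same idea on a hypothetical 1-D, two-class toy problem (not the three-class example): class-conditional Gaussians with a shared, fixed variance are first fit by MLE, then only the means are refined by gradient ascent on the conditional log-likelihood Σ log P(class | x), the classification analogue of MMIE. The data, learning rate, and step count are made-up values.

```python
import math

# Hypothetical 1-D data (x, class). Class 1 has an outlier at 4.0, so a
# Gaussian fit by MLE is a poor model of its shape.
DATA = [(-2.0, 0), (-1.5, 0), (-1.0, 0), (0.8, 0),
        (0.5, 1), (1.0, 1), (1.5, 1), (4.0, 1)]
VAR = 1.0  # shared, fixed variance; equal class priors are assumed

def log_gauss(x, mu):
    return -0.5 * math.log(2 * math.pi * VAR) - (x - mu) ** 2 / (2 * VAR)

def posteriors(x, means):
    """P(class | x) under the Gaussian class models with equal priors."""
    scores = [log_gauss(x, mu) for mu in means]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mle_means(data):
    """MLE: each class mean is the average of that class's samples."""
    return [sum(x for x, y in data if y == c) / sum(1 for _, y in data if y == c)
            for c in (0, 1)]

def conditional_log_likelihood(data, means):
    """MMI-style objective: sum of log P(correct class | x)."""
    return sum(math.log(posteriors(x, means)[y]) for x, y in data)

def mmi_refine(data, means, lr=0.01, steps=500):
    """Gradient ascent on the conditional log-likelihood w.r.t. the means."""
    means = list(means)
    for _ in range(steps):
        grad = [0.0, 0.0]
        for x, y in data:
            post = posteriors(x, means)
            for c in (0, 1):
                # d/d mu_c log P(y|x) = (1[c == y] - P(c|x)) * (x - mu_c) / VAR
                grad[c] += ((1.0 if c == y else 0.0) - post[c]) * (x - means[c]) / VAR
        means = [mu + lr * g for mu, g in zip(means, grad)]
    return means

mle = mle_means(DATA)
mmi = mmi_refine(DATA, mle)
# With equal variances and priors, the decision boundary is the midpoint of the means.
print("MLE boundary:", sum(mle) / 2, "CLL:", conditional_log_likelihood(DATA, mle))
print("MMI boundary:", sum(mmi) / 2, "CLL:", conditional_log_likelihood(DATA, mmi))
```

MLE places each mean where the class data are most likely, so the outlier drags class 1 to the right. The MMI refinement instead moves the means to wherever they best separate the classes.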
However, summing over all possible word sequences in the denominator is hard, which makes discriminative sequence training much more complex than MLE training. We will first explore a lattice-based method for approximating this sum.
To approximate the denominator in MMIE, we create a lattice for each utterance that represents the alternative predictions. To simplify the calculation, only plausible sequences are included; sequences with relatively small probabilities are ignored.
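A real lattice is a compact graph, but the effect of keeping only plausible hypotheses can be sketched with a flat list of scores. The hypothetical helper below keeps the hypotheses whose scaled log score lies within a beam of the best one and sums only those. The beam width and scores are made-up values, and real lattice generation prunes during decoding rather than afterwards.

```python
import math

def log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def pruned_log_denominator(scores, beam):
    """Approximate log(sum(exp(s))) using only scores within `beam` of the best.

    Returns the approximation and the number of hypotheses kept. Because the
    dropped terms are positive, the result never exceeds the exact value.
    """
    best = max(scores)
    kept = [s for s in scores if s >= best - beam]
    return log_sum_exp(kept), len(kept)

# Hypothetical log[p(X|W)^kappa P(W)] scores for five competing word sequences.
scores = [-10.0, -10.5, -12.0, -30.0, -45.0]
exact = log_sum_exp(scores)
approx, n_kept = pruned_log_denominator(scores, beam=5.0)
print(n_kept, exact - approx)
```

The two pruned hypotheses are about e^-20 times less likely than the best, so dropping them changes the log denominator by only about 1e-9.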
To create this lattice, we decode with a WFST using a weaker language model (one with a shorter span, such as a bigram). This is the same mechanism that produces the candidates encoded in a word lattice during ASR decoding. We simply use a weaker language model to generate the candidates we need, so that more acoustically confusable competitors appear in the lattice. In addition, each word in the lattice is expanded into its HMM phone representation.
For the numerator, we create another lattice. It contains only the word sequence of the transcript, expanded into HMM phones to account for the different pronunciations of each word. Conceptually, we sum over all paths in each of the two lattices to compute the sequence probabilities for the numerator and the denominator, respectively.
To accomplish that, we apply the forward-backward algorithm to compute the forward and backward probabilities α and β, and use them to compute the state occupancy probability γ for the numerator and the denominator.
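A minimal plain-Python sketch of this step for a tiny hypothetical 3-state HMM. The denominator graph lets any state follow any other, while the numerator graph only allows left-to-right moves through the transcript's states 0 → 1 → 2. The occupancy is γ_t(s) = α_t(s) β_t(s) / p(X). For brevity the sketch works in the probability domain; real implementations use log or scaled arithmetic to avoid underflow.

```python
def forward_backward(init, trans, emit):
    """Forward-backward for a discrete-state HMM.

    init[i]     : probability of starting in state i
    trans[i][j] : probability of moving from state i to state j
    emit[t][i]  : observation likelihood p(x_t | state i)
    Returns alpha, beta, gamma (state occupancy) and the total likelihood p(X).
    """
    T, N = len(emit), len(init)
    alpha = [[0.0] * N for _ in range(T)]
    beta = [[1.0] * N for _ in range(T)]   # beta[T-1][i] = 1
    for i in range(N):
        alpha[0][i] = init[i] * emit[0][i]
    for t in range(1, T):
        for j in range(N):
            alpha[t][j] = emit[t][j] * sum(alpha[t - 1][i] * trans[i][j] for i in range(N))
    for t in range(T - 2, -1, -1):
        for i in range(N):
            beta[t][i] = sum(trans[i][j] * emit[t + 1][j] * beta[t + 1][j] for j in range(N))
    total = sum(alpha[T - 1])
    gamma = [[alpha[t][i] * beta[t][i] / total for i in range(N)] for t in range(T)]
    return alpha, beta, gamma, total

# Hypothetical per-frame likelihoods p(x_t | s) for 4 frames and 3 states.
emit = [[0.7, 0.2, 0.1],
        [0.3, 0.6, 0.1],
        [0.1, 0.5, 0.4],
        [0.1, 0.2, 0.7]]

# Denominator graph: every state may follow every state.
den_init = [1 / 3, 1 / 3, 1 / 3]
den_trans = [[1 / 3] * 3 for _ in range(3)]

# Numerator graph: start in state 0 and only stay or move one state to the right.
num_init = [1.0, 0.0, 0.0]
num_trans = [[0.5, 0.5, 0.0],
             [0.0, 0.5, 0.5],
             [0.0, 0.0, 1.0]]

_, _, gamma_den, p_den = forward_backward(den_init, den_trans, emit)
_, _, gamma_num, p_num = forward_backward(num_init, num_trans, emit)
for t in range(len(emit)):
    print(t, [round(g, 3) for g in gamma_num[t]], [round(g, 3) for g in gamma_den[t]])
```

At every frame each γ row sums to 1. The numerator occupancies are concentrated on the transcript's states, while the denominator spreads probability over all competitors.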
The forward-backward algorithm was covered in detail in a previous HMM article, so we will not repeat it here.
In a typical application, MMIE training uses the forward-backward algorithm to compute the state occupancy probabilities and then uses them to re-adjust a GMM acoustic model (for example, one initially trained with MLE).
For details, see the original paper. Here we only want to show the skeleton of applying MMIE with the forward-backward algorithm.
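Let’s optimize the MMI objective with gradient ascent (or, equivalently, gradient descent on its negative).
Its gradient w.r.t. the acoustic model's log-likelihoods is κ times the difference between the numerator and denominator state occupation probabilities.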
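To check this numerically, the sketch below uses a toy setup small enough to enumerate every state path instead of running forward-backward. The "denominator" is all N^T paths, the "numerator" is a hypothetical set of left-to-right paths, and each path is scored by κ times the sum of its frame log-likelihoods log p(x_t | s). Transition and LM scores are omitted for brevity. It compares κ(γ_num − γ_den) with a central finite-difference estimate of the objective's gradient.

```python
import itertools
import math

def log_sum_exp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def path_score(path, loglik, kappa):
    """kappa * sum_t log p(x_t | s_t) along one state path."""
    return kappa * sum(loglik[t][s] for t, s in enumerate(path))

def objective(loglik, num_paths, den_paths, kappa):
    """log(numerator sum) - log(denominator sum): the MMI objective."""
    num = log_sum_exp([path_score(p, loglik, kappa) for p in num_paths])
    den = log_sum_exp([path_score(p, loglik, kappa) for p in den_paths])
    return num - den

def occupancy(loglik, paths, kappa):
    """gamma[t][s]: posterior probability of being in state s at frame t."""
    scores = [path_score(p, loglik, kappa) for p in paths]
    z = log_sum_exp(scores)
    gamma = [[0.0] * len(loglik[0]) for _ in loglik]
    for p, sc in zip(paths, scores):
        w = math.exp(sc - z)
        for t, s in enumerate(p):
            gamma[t][s] += w
    return gamma

T, N, KAPPA = 4, 3, 1.0 / 12
loglik = [[-1.0, -2.0, -3.0], [-2.0, -0.5, -2.5],
          [-3.0, -1.0, -0.7], [-2.5, -2.0, -0.2]]   # hypothetical values
den_paths = list(itertools.product(range(N), repeat=T))
# Numerator: start in state 0, end in state 2, only stay or move one step right.
num_paths = [p for p in den_paths
             if p[0] == 0 and p[-1] == N - 1
             and all(b - a in (0, 1) for a, b in zip(p, p[1:]))]

g_num = occupancy(loglik, num_paths, KAPPA)
g_den = occupancy(loglik, den_paths, KAPPA)
analytic = [[KAPPA * (g_num[t][s] - g_den[t][s]) for s in range(N)] for t in range(T)]

h = 1e-5
numeric = [[0.0] * N for _ in range(T)]
for t in range(T):
    for s in range(N):
        plus = [row[:] for row in loglik]
        minus = [row[:] for row in loglik]
        plus[t][s] += h
        minus[t][s] -= h
        numeric[t][s] = (objective(plus, num_paths, den_paths, KAPPA)
                         - objective(minus, num_paths, den_paths, KAPPA)) / (2 * h)

print(max(abs(a - n) for ra, rn in zip(analytic, numeric) for a, n in zip(ra, rn)))
```

Forward-backward computes the same γ values without enumerating paths, which is what makes the method practical on real lattices.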
This formula holds regardless of the acoustic model used: it can be a GMM-HMM or a deep network. In both cases, we compute the state occupation probability γ, given either the reference transcript (numerator) or all other possible word sequences (denominator), with the forward-backward algorithm. If the phone transitions and observations are modeled with an HMM whose states emit Gaussian densities, the gradient w.r.t. each Gaussian's parameters follows from the chain rule.
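In symbols, let γ^num_{r,t}(s) and γ^den_{r,t}(s) be the numerator and denominator occupancies of state s at frame t of utterance r. Assuming for brevity a single Gaussian N(μ_s, Σ_s) per state (a GMM adds a per-component occupancy), the gradients w.r.t. the frame log-likelihood and the state mean are:

```latex
\begin{aligned}
\frac{\partial \mathcal{F}_{\text{MMIE}}}{\partial \log p_{\lambda}(x_{r,t} \mid s)}
  &= \kappa \left( \gamma^{\text{num}}_{r,t}(s) - \gamma^{\text{den}}_{r,t}(s) \right) \\
\frac{\partial \mathcal{F}_{\text{MMIE}}}{\partial \mu_s}
  &= \kappa \sum_{r} \sum_{t} \left( \gamma^{\text{num}}_{r,t}(s) - \gamma^{\text{den}}_{r,t}(s) \right) \Sigma_s^{-1} \left( x_{r,t} - \mu_s \right)
\end{aligned}
```

In practice, GMM systems usually turn these same numerator and denominator statistics into extended Baum-Welch updates rather than taking plain gradient steps.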
Minimum phone error (MPE)
We can modify MMIE to optimize ASR for an error measure in the spirit of the word error rate (WER). MPE, however, measures the edit distance at the phone level instead of the word level.
This introduces a factor A in the numerator that measures the phone transcription accuracy of each hypothesis against the reference. The MPE objective is a weighted average of this phone accuracy, where each word sequence is weighted by its (scaled) posterior probability among all possible word sequences. It therefore rewards high phone accuracy for the word sequences that the acoustic and language models favor.
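A minimal sketch of the MPE objective for one utterance over a hypothetical n-best list. Each hypothesis's phone sequence gets an accuracy A = (number of reference phones) − (phone edit distance), and the objective is the posterior-weighted average of A. Using the exact edit distance is a simplification assumed here; MPE as originally proposed approximates phone accuracy locally on lattice arcs. The scores are assumed to be κ-scaled log joint scores.

```python
import math

def edit_distance(hyp, ref):
    """Levenshtein distance between two phone sequences."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        cur = [i]
        for j, r in enumerate(ref, 1):
            cur.append(min(prev[j] + 1,              # skip a hypothesis phone
                           cur[j - 1] + 1,           # skip a reference phone
                           prev[j - 1] + (h != r)))  # substitution or match
        prev = cur
    return prev[-1]

def phone_accuracy(hyp, ref):
    return len(ref) - edit_distance(hyp, ref)

def mpe_objective(hyps, ref):
    """hyps: list of (phone_sequence, scaled_log_score) pairs."""
    best = max(score for _, score in hyps)
    weights = [math.exp(score - best) for _, score in hyps]
    total = sum(weights)
    return sum(w / total * phone_accuracy(seq, ref)
               for (seq, _), w in zip(hyps, weights))

ref = ["k", "ae", "t"]
hyps = [(["k", "ae", "t"], -1.0),
        (["k", "ae", "p"], -2.0),
        (["b", "ae", "t", "s"], -4.0)]
print(mpe_objective(hyps, ref))
```

Shifting posterior mass toward hypotheses with fewer phone errors raises the objective, even when no hypothesis matches the reference exactly.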
Deep networks excel at extracting features and discovering correlations among them, which lets us exploit context in making predictions. In ASR, we can use a deep network to classify phones based on the features extracted from acoustic frames. We treat it as a classifier that uses a softmax to output the probability distribution P(phone | xₜ). Trained with cross-entropy, the softmax pulls up the ground truth while pushing down the others, the same concept as in MMIE.
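The "pull up / push down" behavior shows up directly in the gradient of the cross-entropy loss −log softmax(z)[y] with respect to the logits z, which equals softmax(z) − onehot(y). A minimal sketch with made-up logits for three phone classes:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_grad(logits, target):
    """Gradient of -log softmax(logits)[target] w.r.t. the logits."""
    probs = softmax(logits)
    return [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]

# Hypothetical logits for three phone classes; class 1 is the ground truth.
grad = cross_entropy_grad([2.0, 1.0, 0.1], target=1)
print(grad)
```

A gradient-descent step subtracts this gradient, so the ground-truth logit (the only negative entry) rises and every other logit (positive entries) falls.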
The MLE training in our discussion is sequence training, but it is not discriminative. The softmax in the deep model here is discriminative, but it is trained frame by frame rather than on whole sequences.
How can we turn it into discriminative sequence training? First, we train the classifier by minimizing the cross-entropy and use the model to generate alignments and lattices. This is the discriminative training phase. The second phase is sequence training. Since both the deep network and the lattice are networks, we can train them together: we compute the MMIE or MPE objective over the lattices with the forward-backward algorithm and use backpropagation to learn the model parameters.
Lattice-Free MMI (LF-MMI)
As discussed before, in lattice-based MMI we first generate the word lattice for the denominator. If we use a deep network to classify phones, we also pretrain it with cross-entropy, and we need the scaling fudge factor κ to correct the overestimation. All of this seems rather ad hoc for training a deep network. Can we avoid it?
In ASR, we decode audio with the composed transducer H ∘ C ∘ L ∘ G. It is a WFST and can be integrated with the deep network classifier. If we step back for a second, the combination is just one big, complex network, and we can train it like any other deep learning model using backpropagation. We don’t need to introduce a lattice to approximate the denominator.
Lattice-based methods were proposed before the GPU era, and training such a network without a GPU is not feasible. New possibilities opened up when GPUs demonstrated great success in deep learning in 2012. But there are physical limitations, in particular memory consumption. For example, to fit the method into GPU memory, the training utterances are chopped into 1–1.5 s chunks. That is not enough, though. To take advantage of the GPU, we also need to avoid branching, because a GPU executes the same instruction on many data elements at once and divergent branches are serialized. A pruned tree search is therefore less appealing on a GPU, and we need a smaller model. Therefore, LF-MMI uses:
- A phone-level language model (LM) instead of a word-level one (typically a 4-gram phone LM).
- No LM backoff (smoothing), since backoff introduces many states.
- A 30 ms frame rate at the network output instead of 10 ms (the input features are still extracted every 10 ms).
- Only one HMM state per phone instead of three.
This model can be trained directly, without word lattices, cross-entropy pretraining, or κ.
Gradient-based training requires two sets of posterior probabilities: one from the numerator graph, which is specific to each utterance, and one from the denominator graph, which encodes all possible phone sequences allowed by the phone LM. Unlike lattice-based MMI, where the denominator lattice is different for each utterance, LF-MMI uses the same denominator graph for all utterances.
Both the numerator and denominator state sequences are encoded as finite state transducers (FSTs), built like the HCLG decoding graph discussed before. The denominator forward-backward computation is parallelized on the GPU. To speed up the process, the denominator FST is carefully optimized to minimize its size, including reversal, weight pushing, and minimization followed by epsilon removal.
Because the utterances are chopped into 1.5 s pieces, a chunk usually starts in the middle of an utterance, so the initial probabilities of the corresponding FSTs need to be readjusted. These initial probabilities are obtained by running the HMM for 100 time steps starting from the initial state and then averaging the state distributions over these 100 time steps.
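A minimal sketch of this averaging, treating the denominator graph as a plain state transition matrix (an assumption, since the real graph is an FST). The sketch propagates the state distribution for 100 steps from the initial state and averages the 100 distributions reached after each step. Whether the starting distribution itself belongs in the average is not specified here. The 3-state matrix is hypothetical.

```python
def chunk_initial_probs(init, trans, steps=100):
    """Average state distribution over `steps` HMM steps starting from `init`.

    init[i]     : initial state distribution of the full-utterance graph
    trans[i][j] : state transition probabilities (rows sum to 1)
    """
    n = len(init)
    dist = list(init)
    total = [0.0] * n
    for _ in range(steps):
        # One HMM step: dist_new[j] = sum_i dist[i] * trans[i][j]
        dist = [sum(dist[i] * trans[i][j] for i in range(n)) for j in range(n)]
        total = [acc + p for acc, p in zip(total, dist)]
    return [acc / steps for acc in total]

init = [1.0, 0.0, 0.0]
trans = [[0.6, 0.4, 0.0],
         [0.1, 0.6, 0.3],
         [0.3, 0.0, 0.7]]
probs = chunk_initial_probs(init, trans)
print([round(p, 3) for p in probs])
```

Averaging over many steps gives every reachable state some starting probability, which suits a chunk that may begin anywhere in an utterance.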
To avoid overfitting, LF-MMI applies two techniques. The first is L2 regularization, applied to the network outputs. In addition, the classifier has two separate output layers, one trained with MMIE and the other with cross-entropy. The shared layers are therefore trained with both objectives, which reduces the overfitting a single objective can cause.
The research paper also describes other implementation details, in particular ways to avoid expensive operations and to prevent overflow and underflow. We encourage readers to read the paper for more information.