Dissecting deepspeech.pytorch Part 2

This article is part 2 of a dissection of deepspeech.pytorch, an implementation of Baidu's DeepSpeech2 paper. I find deepspeech.pytorch clean, relatively simple, and very educational.

In part 1, I explained the overall concept and the data processing. In part 2, I'm going to explain the model architecture.

Architecture

The architecture consists of two 2D convolution layers, a stack of RNN layers, and a fully connected layer.

Let's see what the key layers do.

Input Data

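The input to this model is a spectrogram of the audio (in deepspeech.pytorch, the log-magnitude of a short-time Fourier transform rather than a mel spectrogram). You can think of it as a sequence of feature vectors, one per time step.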
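
To make one column of that spectrogram concrete, here is a minimal pure-Python sketch: window a 320-sample frame, take its DFT, and keep log(1 + magnitude) for the non-negative frequency bins. The 16 kHz sample rate, the 0.02 s window (320 samples), and the periodic Hamming window are my assumptions about the defaults, and the naive O(n²) DFT is only for illustration; the library computes this with an FFT-based STFT.

```python
import cmath
import math


def log_magnitude_frame(frame):
    """Return log(1 + |DFT|) of one real-valued frame, non-negative bins only."""
    n = len(frame)
    # Periodic Hamming window (assumed window type)
    window = [0.54 - 0.46 * math.cos(2 * math.pi * i / n) for i in range(n)]
    windowed = [x * w for x, w in zip(frame, window)]
    spectrum = []
    # A real signal's DFT is symmetric, so bins 0..n//2 carry all the information
    for k in range(n // 2 + 1):
        acc = sum(x * cmath.exp(-2j * math.pi * k * i / n) for i, x in enumerate(windowed))
        spectrum.append(math.log1p(abs(acc)))
    return spectrum


sample_rate = 16000
n_fft = 320  # 0.02 s window at 16 kHz
# A 500 Hz tone lands exactly on bin 10, since each bin is 16000 / 320 = 50 Hz wide
frame = [math.sin(2 * math.pi * 500 * i / sample_rate) for i in range(n_fft)]
features = log_magnitude_frame(frame)
print(len(features))  # 161 frequency bins
print(max(range(len(features)), key=features.__getitem__))  # 10
```

A 320-sample window gives 320 / 2 + 1 = 161 bins, which is where the vertical size of the input comes from.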

The vertical axis is the feature count, which is determined by the data-processing arguments such as the window size. With the default settings, which I used, the feature count is 161. The horizontal axis is time, so its size varies: longer audio gives a bigger number and shorter audio a smaller one. 457 is just one example.
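
Both numbers can be estimated from the audio length with a little arithmetic. This sketch assumes 16 kHz audio, a 0.02 s window, a 0.01 s stride, and a librosa-style centered STFT that yields 1 + num_samples // hop_length frames; spectrogram_shape is a hypothetical helper, not part of deepspeech.pytorch.

```python
def spectrogram_shape(num_samples, sample_rate=16000, window_size=0.02, window_stride=0.01):
    """Return (frequency_bins, time_frames) for a centered STFT (assumed settings)."""
    n_fft = round(sample_rate * window_size)         # 320 samples per window
    hop_length = round(sample_rate * window_stride)  # 160 samples between frames
    freq_bins = n_fft // 2 + 1                       # 161 with the assumed defaults
    time_frames = 1 + num_samples // hop_length      # assumes center=True padding
    return freq_bins, time_frames


print(spectrogram_shape(72960))  # 4.56 s of 16 kHz audio -> (161, 457)
print(spectrogram_shape(16000))  # 1.00 s -> (161, 101)
```

So the feature count depends only on the window size, while the time axis grows linearly with the clip duration.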

Let's check the shape of what we feed into the model.

first_batch = next(iter(train_loader))  # train_loader comes from the data-loading setup in part 1
inputs, targets, input_percentages, target_sizes = first_batch
print(inputs.shape)
# torch.Size([16, 1, 161, 457])

We know where 161 and 457 come from, but where do 16 and 1 come from?

16 is the batch size. The 1 is the channel dimension. The first part of the model uses nn.Conv2d, as in computer vision, and nn.Conv2d expects input shaped (batch, channels, height, width). The spectrogram is treated like a grayscale image with a single channel.

MaskConv

MaskConv is the trickiest part of the DeepSpeech model.
It wraps two nn.Conv2d layers, each followed by batch normalization and a Hardtanh activation, in a class named MaskConv.

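self.conv = MaskConv(nn.Sequential(
    # (N, 1, 161, T) -> (N, 32, 81, ~T/2): stride 2 on both frequency and time
    nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
    nn.BatchNorm2d(32),
    nn.Hardtanh(0, 20, inplace=True),  # clipped ReLU: min(max(x, 0), 20)
    # (N, 32, 81, ~T/2) -> (N, 32, 41, ~T/2): stride 2 on frequency only
    nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
    nn.BatchNorm2d(32),
    nn.Hardtanh(0, 20, inplace=True)
))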

MaskConv builds a mask from each clip's length, which is derived from a parameter called input_percentages, and applies the mask to the activations. I'll explain input_percentages later.

But why do we need a mask in the first place?

Because we train in batches, a single batch often contains audio clips of different durations. The data-processing script (the _collate_fn function in data_loader) already handles this by zero-padding the shorter clips.

All inputs in a batch therefore have the same length, but the relatively shorter clips have zeros at the end.
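
The padding idea can be sketched in pure Python. pad_batch is a hypothetical helper that mirrors what I understand _collate_fn to do (zero-pad every spectrogram to the longest one, add the channel dimension, and record each length as a fraction of the longest); it is not the library's code.

```python
def pad_batch(spectrograms):
    """Zero-pad spectrograms (lists of frequency rows) to the longest time length.

    Returns the padded batch, shaped (batch, 1, freq, max_time), and each
    clip's length as a fraction of max_time.
    """
    max_time = max(len(spect[0]) for spect in spectrograms)
    batch, input_percentages = [], []
    for spect in spectrograms:
        time = len(spect[0])
        padded = [row + [0.0] * (max_time - time) for row in spect]
        batch.append([padded])  # extra list level = the single "image" channel
        input_percentages.append(time / max_time)
    return batch, input_percentages


# Two toy "spectrograms" with 2 frequency bins and 4 and 3 time steps
long_clip = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
short_clip = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]
batch, percentages = pad_batch([long_clip, short_clip])
print(batch[1][0])  # [[1.0, 1.0, 1.0, 0.0], [2.0, 2.0, 2.0, 0.0]]
print(percentages)  # [1.0, 0.75]
```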

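After the Conv2d layers, however, the padded region is no longer all zeros: the convolution kernels mix in values from neighboring time steps, and batch normalization shifts the values. The shape changes as well.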

So MaskConv calculates which part to mask from input_percentages and masks it again. input_percentages holds each clip's length as a fraction of the longest clip in the batch, for example 0.75 or 0.80.
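
Here is a minimal sketch of that masking on nested lists standing in for an (N, C, H, T) tensor. It assumes each clip's input length is int(percentage * max_time) and pushes that length through the time axis of both convolutions (kernel 11, padding 5, strides 2 and 1). The helper names are hypothetical, and as far as I can tell the library re-applies the mask after every module inside the Sequential, not just once.

```python
def conv_time_length(length, kernel, stride, padding):
    """Output length along the time axis for one Conv2d layer (dilation 1)."""
    return (length + 2 * padding - (kernel - 1) - 1) // stride + 1


def time_lengths_after_conv(input_percentages, max_time):
    lengths = []
    for p in input_percentages:
        length = int(p * max_time)                   # frames before the convolutions
        length = conv_time_length(length, 11, 2, 5)  # first Conv2d: time stride 2
        length = conv_time_length(length, 11, 1, 5)  # second Conv2d: time stride 1
        lengths.append(length)
    return lengths


def mask_time_steps(activations, lengths):
    """Zero every time step at or beyond each sample's length.

    activations[b][c][h][t] is a nested-list stand-in for an (N, C, H, T) tensor.
    """
    return [
        [[[value if t < length else 0.0 for t, value in enumerate(row)] for row in channel]
         for channel in sample]
        for sample, length in zip(activations, lengths)
    ]


print(time_lengths_after_conv([1.0, 0.75, 0.80], 457))  # [229, 171, 183]
acts = [[[[1.0, 2.0, 3.0, 4.0]]], [[[5.0, 6.0, 7.0, 8.0]]]]
print(mask_time_steps(acts, [4, 2]))  # second sample keeps only its first 2 steps
```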

Now the activations have the shape (16, 32, 41, 229). The number of time steps shrank from 457 to 229. Again, 457 varies from batch to batch depending on the audio durations.
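
Where do 41 and 229 come from? Each nn.Conv2d shrinks height and width according to the output-size formula in the PyTorch documentation, while BatchNorm2d and Hardtanh leave the shape unchanged. conv2d_output_shape is a hypothetical helper that applies that formula (with dilation 1) to the two layers above.

```python
def conv2d_output_shape(shape, out_channels, kernel_size, stride, padding):
    """Output shape (N, C, H, W) of nn.Conv2d with dilation 1."""
    n, _, h, w = shape
    out_h = (h + 2 * padding[0] - (kernel_size[0] - 1) - 1) // stride[0] + 1
    out_w = (w + 2 * padding[1] - (kernel_size[1] - 1) - 1) // stride[1] + 1
    return (n, out_channels, out_h, out_w)


shape = (16, 1, 161, 457)
shape = conv2d_output_shape(shape, 32, (41, 11), (2, 2), (20, 5))
print(shape)  # (16, 32, 81, 229)
shape = conv2d_output_shape(shape, 32, (21, 11), (2, 1), (10, 5))
print(shape)  # (16, 32, 41, 229)
```

The paddings are chosen so that only the strides change the sizes: stride 2 roughly halves an axis and stride 1 keeps it.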

RNN

Before feeding the activations into the RNN layers, we reshape them from (16, 32, 41, 229) to (229, 16, 1312) to fit the RNN's expected layout: time steps first, then the batch, then 32 × 41 = 1312 features per time step.
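
PyTorch's RNN modules expect (seq_len, batch, input_size) by default (batch_first=False). I believe the model gets there with x.view(N, C * H, T) followed by transposes; the pure-Python sketch below (to_rnn_layout is a hypothetical name) shows which element ends up where, using values that encode their own indices.

```python
def to_rnn_layout(x):
    """Rearrange nested lists from (N, C, H, T) to (T, N, C * H).

    Mirrors collapsing channels and frequency rows into one feature axis
    (feature index c * H + h) and then moving time to the front.
    """
    n, c, h, t = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    return [
        [[x[b][ch][row][step] for ch in range(c) for row in range(h)] for b in range(n)]
        for step in range(t)
    ]


# Toy tensor with N=2, C=2, H=3, T=4; each value encodes (b, c, h, t) as digits
x = [[[[1000 * b + 100 * c + 10 * h + t for t in range(4)] for h in range(3)]
      for c in range(2)] for b in range(2)]
y = to_rnn_layout(x)
print(len(y), len(y[0]), len(y[0][0]))  # 4 2 6
print(y[0][0])  # [0, 10, 20, 100, 110, 120]
```

With the real shapes this turns (16, 32, 41, 229) into (229, 16, 1312).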

Fully Connected

The final step is a fully connected layer. 29 is the number of tokens: the 26 letters plus a few special tokens such as the apostrophe, the space, and the CTC blank.

The shape (229, 16, 29) means that each of the 229 time steps has a score for each of the 29 tokens, and a softmax turns these scores into probabilities. 16 is the batch size.
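
To see how per-time-step token probabilities become text, here is a sketch of greedy CTC decoding for a single utterance: pick the most likely token at each step, collapse repeats, then drop blanks. The LABELS string (blank "_" at index 0, apostrophe, A to Z, space) is my assumption about the 29-token label set. The library ships a greedy decoder of its own; this only illustrates the idea.

```python
import math

# Assumed 29-token label set: CTC blank "_", apostrophe, A-Z, and space
LABELS = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ "
BLANK = 0


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_decode(frame_scores):
    """frame_scores: one list of 29 scores per time step for one utterance."""
    best = []
    for scores in frame_scores:
        probs = softmax(scores)  # argmax is the same as on raw scores
        best.append(max(range(len(probs)), key=probs.__getitem__))
    text, previous = [], None
    for index in best:
        # Collapse repeated tokens, then drop blanks
        if index != previous and index != BLANK:
            text.append(LABELS[index])
        previous = index
    return "".join(text)


def token_scores(ch):
    """Hypothetical helper: scores that strongly favor one token."""
    scores = [0.0] * len(LABELS)
    scores[LABELS.index(ch)] = 5.0
    return scores


print(greedy_decode([token_scores(ch) for ch in "HHELL_LLO"]))  # HELLO
```

The blank between the two groups of L is what lets CTC emit a doubled letter.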

That output is then fed into the CTC loss, which I explained briefly in part 1.
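
As a reminder of what the CTC loss computes, here is a minimal pure-Python sketch for a single utterance: the textbook forward (alpha) recursion over the target interleaved with blanks, done in log space. It is not the library's implementation; ctc_loss and logsumexp are hypothetical helpers, and blank index 0 is an assumption.

```python
import math

NEG_INF = float("-inf")


def logsumexp(values):
    m = max(values)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(v - m) for v in values))


def ctc_loss(log_probs, target, blank=0):
    """Negative log-likelihood of `target` under CTC for one utterance.

    log_probs: T lists of per-token log-probabilities.
    target: token indices without blanks.
    """
    extended = [blank]
    for token in target:
        extended += [token, blank]  # blank, t1, blank, t2, ..., blank
    S = len(extended)
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][blank]  # a path may start with a blank...
    if S > 1:
        alpha[1] = log_probs[0][extended[1]]  # ...or with the first token
    for t in range(1, len(log_probs)):
        new_alpha = [NEG_INF] * S
        for s in range(S):
            candidates = [alpha[s]]  # stay on the same symbol
            if s >= 1:
                candidates.append(alpha[s - 1])  # advance by one
            # Skip over a blank unless that would merge two identical tokens
            if s >= 2 and extended[s] != blank and extended[s] != extended[s - 2]:
                candidates.append(alpha[s - 2])
            new_alpha[s] = logsumexp(candidates) + log_probs[t][extended[s]]
        alpha = new_alpha
    # Valid paths end on the last token or the trailing blank
    ends = [alpha[-1]] + ([alpha[-2]] if S > 1 else [])
    return -logsumexp(ends)


half = math.log(0.5)
# Two tokens (blank, "a"), two uniform steps: paths aa, a_, _a give probability 0.75
print(ctc_loss([[half, half], [half, half]], [1]))  # -log(0.75), about 0.2877
```

PyTorch's built-in torch.nn.CTCLoss expects log-probabilities shaped (T, N, C), the same time-first layout as the (229, 16, 29) output above.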

After that, it's just a normal training process: if the loss decreases and the score on the validation dataset improves, you're good.

I hope this article is helpful.

I created a Jupyter notebook that walks through the entire deepspeech.pytorch code. You can check it here.

I tried to make it work on CPU as well, but the CTC loss doesn't seem to work there, so you need a GPU. Sorry about that. You can still try it on Google Colab.

I'm a serial entrepreneur. I enjoy AI, UI, and blockchain. I like history and reading too.