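ASR Inference with CTC Decoder¶
Author: Caroline Chen
This tutorial shows how to perform speech recognition inference using a CTC beam search decoder with lexicon constraint and KenLM language model support. We demonstrate this on a pretrained wav2vec 2.0 model trained using CTC loss.
Overview¶
Beam search decoding works by iteratively expanding text hypotheses (beams) with the next possible characters, keeping only the highest-scoring hypotheses at each time step. A language model can be incorporated into the score computation, and adding a lexicon constraint restricts the next possible tokens for each hypothesis so that only words from the lexicon can be generated.
The underlying implementation is ported from Flashlight's beam search decoder. A mathematical formula for the decoder optimization can be found in the Wav2Letter paper, and a more detailed algorithm can be found in this blog.
Running ASR inference using a CTC beam search decoder with a KenLM language model and lexicon constraint requires the following components:
Acoustic Model: model predicting phonetics from audio waveforms
Tokens: the possible predicted tokens from the acoustic model
Lexicon: mapping between possible words and their corresponding token sequences
KenLM: n-gram language model trained with the KenLM library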
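To make the "expand, score, keep the best" loop concrete, here is a minimal sketch of a beam search over frame-level alignments. It is a deliberate simplification and not the Flashlight algorithm: it uses no lexicon or language model, and it does not merge alignments that collapse to the same text. The function names and the toy probabilities are made up for illustration.

```python
import math


def collapse(path, blank="-"):
    """Apply the CTC rule: merge repeated symbols, then drop blanks."""
    out = []
    prev = None
    for sym in path:
        if sym != prev and sym != blank:
            out.append(sym)
        prev = sym
    return "".join(out)


def toy_beam_search(log_probs, beam_size):
    """log_probs: one {token: log-probability} dict per time step."""
    beams = [((), 0.0)]  # (alignment so far, accumulated log score)
    for frame in log_probs:
        # Expand every kept hypothesis with every possible next token.
        candidates = [
            (path + (tok,), score + lp)
            for path, score in beams
            for tok, lp in frame.items()
        ]
        # Keep only the highest-scoring hypotheses for the next step.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return [(collapse(path), score) for path, score in beams]


# Hypothetical per-frame probabilities over a three-token vocabulary.
toy_frames = [
    {"-": math.log(0.1), "a": math.log(0.6), "b": math.log(0.3)},
    {"-": math.log(0.5), "a": math.log(0.4), "b": math.log(0.1)},
    {"-": math.log(0.2), "a": math.log(0.1), "b": math.log(0.7)},
]
toy_results = toy_beam_search(toy_frames, beam_size=2)
for text, score in toy_results:
    print(text, round(math.exp(score), 3))
```

With these toy numbers both surviving alignments collapse to the same text, which hints at why a full CTC decoder tracks hypotheses by their collapsed text rather than by raw alignment.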
Preparation¶
First we import the necessary utilities and fetch the data that we are working with.
import time
from typing import List
import IPython
import matplotlib.pyplot as plt
import torch
import torchaudio
try:
    from torchaudio.models.decoder import ctc_decoder
except ModuleNotFoundError:
    try:
        import google.colab

        print(
            """
            To enable running this notebook in Google Colab, install nightly
            torch and torchaudio builds by adding the following code block to the top
            of the notebook before running it:
            !pip3 uninstall -y torch torchvision torchaudio
            !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
            """
        )
    except ModuleNotFoundError:
        pass
    raise
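Acoustic Model and Data¶
We use the pretrained Wav2Vec 2.0 Base model that is finetuned on 10 minutes of the LibriSpeech dataset, which can be loaded using torchaudio.pipelines. For more detail on running Wav2Vec 2.0 speech recognition pipelines in torchaudio, please refer to this tutorial.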
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()
Out:
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ll10m.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ll10m.pth
100%|##########| 360M/360M [00:04<00:00, 87.6MB/s]
We will load a sample from the LibriSpeech test-other dataset.
hub_dir = torch.hub.get_dir()
speech_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/1688-142285-0007.wav"
speech_file = f"{hub_dir}/speech.wav"
torch.hub.download_url_to_file(speech_url, speech_file)
IPython.display.Audio(speech_file)
Out:
0%| | 0.00/441k [00:00<?, ?B/s]
100%|##########| 441k/441k [00:00<00:00, 7.35MB/s]
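The transcript corresponding to this audio file is "i really was very much afraid of showing him how much shocked i was at some parts of what he said".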
waveform, sample_rate = torchaudio.load(speech_file)
if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
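Files and Data for Decoder¶
Next, we load our token, lexicon, and KenLM data, which the decoder uses to predict words from the acoustic model output. Pretrained files for the LibriSpeech dataset can be downloaded through torchaudio, or users can provide their own files.
Tokens¶
The tokens are the possible symbols that the acoustic model can predict, including the blank and silence symbols. They can be passed in either as a file, where each line holds the token for the index matching that line's position, or as a list of tokens, each mapping to a unique index.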
# tokens.txt
_
|
e
t
...
tokens = [label.lower() for label in bundle.get_labels()]
print(tokens)
Out:
['-', '|', 'e', 't', 'a', 'o', 'n', 'i', 'h', 's', 'r', 'd', 'l', 'u', 'm', 'w', 'c', 'f', 'g', 'y', 'p', 'b', 'v', 'k', "'", 'x', 'j', 'q', 'z']
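As a small illustration of "each token maps to a unique index", the sketch below builds the two lookup tables by hand from the first few entries of the list printed above. The variable names are hypothetical, and this is only meant to show the position-to-index convention, not how torchaudio stores tokens internally.

```python
# First ten entries of the token list printed above; position == index.
example_tokens = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s"]

# Token -> index and index -> token lookups.
token_to_index = {tok: idx for idx, tok in enumerate(example_tokens)}
index_to_token = dict(enumerate(example_tokens))

# Spell the word "the" followed by the word-boundary token "|".
encoded = [token_to_index[ch] for ch in "the|"]
decoded = "".join(index_to_token[i] for i in encoded)
print(encoded, decoded)
```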
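Lexicon¶
The lexicon is a mapping from words to their corresponding token sequences, and is used to restrict the search space of the decoder to words from the lexicon only. The expected format of the lexicon file is one word per line, with the word followed by its space-separated tokens.
# lexicon.txt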
a a |
able a b l e |
about a b o u t |
...
...
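The line format above can be read with a few lines of plain Python. The helper below is a hypothetical sketch, not the loader torchaudio uses; storing a list of spellings per word (so that a word may appear on more than one line) is an assumption of this sketch.

```python
def parse_lexicon(lines):
    """Map each word to a list of token sequences (one per lexicon line)."""
    parsed = {}
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        word, spelling = parts[0], parts[1:]
        parsed.setdefault(word, []).append(spelling)
    return parsed


sample_lines = ["a a |", "able a b l e |", "about a b o u t |"]
sample_lexicon = parse_lexicon(sample_lines)
print(sample_lexicon["able"])
```

To read a real file, the same function could be called on an open file handle, for example `parse_lexicon(open(path))`, where `path` points to a lexicon file in this format.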
KenLM¶
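This is an n-gram language model trained with the KenLM library. Both the .arpa and the binarized .bin LM can be used, but the binary format is recommended for faster loading.
The language model used in this tutorial is a 4-gram KenLM trained using LibriSpeech.
Downloading Pretrained Files¶
Pretrained files for the LibriSpeech dataset can be downloaded using download_pretrained_files.
Note: this cell may take a couple of minutes to run, as the language model can be large.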
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
print(files)
Out:
0%| | 0.00/4.97M [00:00<?, ?B/s]
100%|##########| 4.97M/4.97M [00:00<00:00, 69.4MB/s]
0%| | 0.00/57.0 [00:00<?, ?B/s]
100%|##########| 57.0/57.0 [00:00<00:00, 41.5kB/s]
100%|##########| 2.91G/2.91G [00:50<00:00, 62.4MB/s]
PretrainedFiles(lexicon='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lexicon.txt', tokens='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/tokens.txt', lm='/root/.cache/torch/hub/torchaudio/decoder-assets/librispeech-4-gram/lm.bin')
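Construct Decoders¶
In this tutorial, we construct both a beam search decoder and a greedy decoder for comparison.
Beam Search Decoder¶
The decoder can be constructed using the factory function ctc_decoder.
In addition to the previously mentioned components, it also takes in various beam search decoding parameters and token/word parameters.
This decoder can also be run without a language model by passing None to the lm parameter.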
from torchaudio.models.decoder import ctc_decoder
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=3,
    beam_size=1500,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)
Greedy Decoder¶
class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
        Returns:
          List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        joined = "".join([self.labels[i] for i in indices])
        return joined.replace("|", " ").strip().split()
greedy_decoder = GreedyCTCDecoder(tokens)
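The collapse step inside the greedy decoder can be traced by hand on a short sequence. The sketch below mirrors the same three steps (merge consecutive repeats, drop blanks, split on the word boundary) in plain Python, starting from per-frame argmax indices that are made up for illustration.

```python
def greedy_ctc_collapse(frame_indices, labels, blank=0):
    """Merge consecutive repeats, drop blanks, then split on '|'."""
    collapsed = []
    previous = None
    for idx in frame_indices:
        if idx != previous:
            collapsed.append(idx)
        previous = idx
    text = "".join(labels[i] for i in collapsed if i != blank)
    return text.replace("|", " ").strip().split()


example_labels = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s"]
# Hypothetical per-frame argmax: t t - h e e | - i i s |
example_frames = [3, 3, 0, 8, 2, 2, 1, 0, 7, 7, 9, 1]
example_words = greedy_ctc_collapse(example_frames, example_labels)
print(example_words)
```

Because repeats are merged before blanks are removed, a genuinely doubled letter only survives when a blank frame separates its two occurrences.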
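Run Inference¶
Now that we have the data, acoustic model, and decoder, we can perform inference. The output of the beam search decoder is of type torchaudio.models.decoder.CTCHypothesis, consisting of the predicted token IDs, corresponding words, hypothesis score, and timesteps corresponding to the token IDs. Recall that the transcript corresponding to the waveform is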
actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()
emission, _ = acoustic_model(waveform)
The greedy decoder gives the following result.
greedy_result = greedy_decoder(emission[0])
greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}")
Out:
Transcript: i reily was very much affrayd of showing him howmuch shoktd i wause at some parte of what he seid
WER: 0.38095238095238093
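The WER above is the word-level edit distance divided by the number of reference words. As a cross-check that does not depend on torchaudio, the following sketch computes the same quantity with a hand-written Levenshtein distance over word lists; the function and variable names are illustrative only.

```python
def word_edit_distance(reference, hypothesis):
    """Levenshtein distance between two lists of words."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        current = [i]
        for j, hyp_word in enumerate(hypothesis, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]


ref_words = (
    "i really was very much afraid of showing him how much shocked "
    "i was at some parts of what he said"
).split()
greedy_words = (
    "i reily was very much affrayd of showing him howmuch shoktd "
    "i wause at some parte of what he seid"
).split()

distance = word_edit_distance(ref_words, greedy_words)
wer = distance / len(ref_words)
print(distance, len(ref_words), wer)
```

Eight word errors against 21 reference words reproduces the figure printed above.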
Using the beam search decoder:
beam_search_result = beam_search_decoder(emission)
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
    actual_transcript
)
print(f"Transcript: {beam_search_transcript}")
print(f"WER: {beam_search_wer}")
Out:
Transcript: i really was very much afraid of showing him how much shocked i was at some part of what he said
WER: 0.047619047619047616
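We see that the lexicon-constrained beam search decoder produces a more accurate transcript consisting of real words, while the greedy decoder can predict misspelled words such as "affrayd" and "shoktd".
Timestep Alignments¶
Recall that one of the components of the resulting hypotheses is the timesteps corresponding to the token IDs.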
timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
print(predicted_tokens, len(predicted_tokens))
print(timesteps, timesteps.shape[0])
Out:
['|', 'i', '|', 'r', 'e', 'a', 'l', 'l', 'y', '|', 'w', 'a', 's', '|', 'v', 'e', 'r', 'y', '|', 'm', 'u', 'c', 'h', '|', 'a', 'f', 'r', 'a', 'i', 'd', '|', 'o', 'f', '|', 's', 'h', 'o', 'w', 'i', 'n', 'g', '|', 'h', 'i', 'm', '|', 'h', 'o', 'w', '|', 'm', 'u', 'c', 'h', '|', 's', 'h', 'o', 'c', 'k', 'e', 'd', '|', 'i', '|', 'w', 'a', 's', '|', 'a', 't', '|', 's', 'o', 'm', 'e', '|', 'p', 'a', 'r', 't', '|', 'o', 'f', '|', 'w', 'h', 'a', 't', '|', 'h', 'e', '|', 's', 'a', 'i', 'd', '|', '|'] 99
tensor([ 0, 31, 33, 36, 39, 41, 42, 44, 46, 48, 49, 52, 54, 58,
64, 66, 69, 73, 74, 76, 80, 82, 84, 86, 88, 94, 97, 107,
111, 112, 116, 134, 136, 138, 140, 142, 146, 148, 151, 153, 155, 157,
159, 161, 162, 166, 170, 176, 177, 178, 179, 182, 184, 186, 187, 191,
193, 198, 201, 202, 203, 205, 207, 212, 213, 216, 222, 224, 230, 250,
251, 254, 256, 261, 262, 264, 267, 270, 276, 277, 281, 284, 288, 289,
292, 295, 297, 299, 300, 303, 305, 307, 310, 311, 324, 325, 329, 331,
353], dtype=torch.int32) 99
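The timesteps are frame indices into the emission, not seconds. A rough conversion scales each index by the number of waveform samples per emission frame and divides by the sample rate. The sketch below shows the arithmetic; the sample count, frame count, and sample rate are hypothetical placeholders, not values measured from this audio file.

```python
def timesteps_to_seconds(frame_indices, num_samples, num_frames, sample_rate):
    """Convert emission frame indices to approximate times in seconds."""
    samples_per_frame = num_samples / num_frames
    return [t * samples_per_frame / sample_rate for t in frame_indices]


# Hypothetical shapes: 160000 samples at 16 kHz mapped to 500 emission frames.
example_seconds = timesteps_to_seconds(
    [0, 31, 33, 353], num_samples=160000, num_frames=500, sample_rate=16000
)
print(example_seconds)
```

In the tutorial's own variables, the corresponding inputs would be waveform.shape[1], emission.shape[1], and bundle.sample_rate.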
Below, we visualize the token timestep alignments relative to the original waveform.
def plot_alignments(waveform, emission, tokens, timesteps):
    fig, ax = plt.subplots(figsize=(32, 10))
    ax.plot(waveform)
    ratio = waveform.shape[0] / emission.shape[1]
    word_start = 0
    for i in range(len(tokens)):
        if i != 0 and tokens[i - 1] == "|":
            word_start = timesteps[i]
        if tokens[i] != "|":
            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
        elif i != 0:
            word_end = timesteps[i]
            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
    xticks = ax.get_xticks()
    plt.xticks(xticks, xticks / bundle.sample_rate)
    ax.set_xlabel("time (sec)")
    ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)

Beam Search Decoder Parameters¶
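In this section, we go into more depth on some of the parameters and their tradeoffs. For the full list of customizable parameters, please refer to the documentation.
Helper Function¶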
def print_decoded(decoder, emission, param, param_value):
    start_time = time.monotonic()
    result = decoder(emission)
    decode_time = time.monotonic() - start_time
    transcript = " ".join(result[0][0].words).lower().strip()
    score = result[0][0].score
    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")
nbest¶
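This parameter indicates the number of best hypotheses to return, which is not possible with the greedy decoder. For instance, because we set nbest=3 when constructing the beam search decoder earlier, we can now access the hypotheses with the top 3 scores.
for i in range(3):
    transcript = " ".join(beam_search_result[0][i].words).strip()
    score = beam_search_result[0][i].score
    print(f"{transcript} (score: {score})")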
Out:
i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.8238231825794)
i really was very much afraid of showing him how much shocked i was at some parts of what he said (score: 3697.8580900895563)
i reply was very much afraid of showing him how much shocked i was at some part of what he said (score: 3695.015467226502)
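beam size¶
The beam_size parameter determines the maximum number of best hypotheses to hold after each decoding step. Using a larger beam size allows a wider range of possible hypotheses to be explored, which can produce hypotheses with higher scores, but it is computationally more expensive and provides no additional gain beyond a certain point.
In the example below, decoding quality improves as the beam size increases from 1 to 5 to 50, but a beam size of 500 produces the same output as a beam size of 50 while taking longer to compute.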
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size=beam_size,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
Out:
beam size 1 : i you ery much afra of shongut shot i was at some arte what he sad (score: 3144.93; 0.2201 secs)
beam size 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3688.02; 0.0646 secs)
beam size 50 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2912 secs)
beam size 500: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.7506 secs)
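beam size token¶
The beam_size_token parameter corresponds to the number of tokens considered for expanding each hypothesis at a decoding step. Exploring more of the next possible tokens increases the range of potential hypotheses at the cost of more computation.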
num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_size_token=beam_size_token,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
Out:
beam size token 1 : i rely was very much affray of showing him hoch shot i was at some part of what he sed (score: 3584.80; 0.3286 secs)
beam size token 5 : i rely was very much afraid of showing him how much shocked i was at some part of what he said (score: 3694.83; 0.2777 secs)
beam size token 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3696.25; 0.2314 secs)
beam size token 29 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.4088 secs)
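Conceptually, limiting the tokens considered per step amounts to a top-k selection over one frame of the emission. The sketch below shows that selection on a made-up frame of log scores; it illustrates the idea only and is not how the decoder is implemented internally.

```python
import heapq


def top_tokens(frame_scores, beam_size_token):
    """Return the beam_size_token highest-scoring tokens of one frame."""
    return heapq.nlargest(beam_size_token, frame_scores, key=frame_scores.get)


# Hypothetical log scores for one emission frame.
example_frame = {"-": -0.2, "|": -4.0, "e": -2.5, "t": -1.9, "a": -3.1}
kept_two = top_tokens(example_frame, beam_size_token=2)
kept_all = top_tokens(example_frame, beam_size_token=len(example_frame))
print(kept_two, kept_all)
```

With a value of 1, each hypothesis can only be extended by the single most likely token of the frame, which is consistent with the weaker transcript in the first output row above.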
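beam threshold¶
The beam_threshold parameter is used to prune the stored hypothesis set at each decoding step, removing hypotheses whose scores are more than beam_threshold away from the highest-scoring hypothesis. There is a balance between choosing a smaller threshold to prune more hypotheses and reduce the search space, and choosing a threshold large enough that plausible hypotheses are not pruned.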
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        beam_threshold=beam_threshold,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
Out:
beam threshold 1 : i ila ery much afraid of shongut shot i was at some parts of what he said (score: 3316.20; 0.0337 secs)
beam threshold 5 : i rely was very much afraid of showing him how much shot i was at some parts of what he said (score: 3682.23; 0.0850 secs)
beam threshold 10 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3094 secs)
beam threshold 25 : i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.2865 secs)
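The pruning rule can be sketched as a filter relative to the best score. The hypotheses and scores below are invented for illustration, and whether the real decoder treats a hypothesis exactly at the threshold as kept or pruned is an assumption here.

```python
def prune_by_threshold(hypotheses, beam_threshold):
    """Keep hypotheses scoring within beam_threshold of the best one."""
    best = max(score for _, score in hypotheses)
    return [(text, score) for text, score in hypotheses if best - score <= beam_threshold]


# Hypothetical partial hypotheses with accumulated scores.
example_hyps = [("i really was", 100.0), ("i rely was", 96.5), ("i ila", 88.0)]
kept_tight = prune_by_threshold(example_hyps, beam_threshold=1)
kept_loose = prune_by_threshold(example_hyps, beam_threshold=5)
print(len(kept_tight), len(kept_loose))
```

A tight threshold leaves only the current best hypothesis, so a path that starts slightly behind but would later overtake it is lost early.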
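language model weight¶
The lm_weight parameter is the weight assigned to the language model score, which is accumulated with the acoustic model score to determine the overall score. Larger weights encourage the model to predict next words based on the language model, while smaller weights give more weight to the acoustic model score instead.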
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
    beam_search_decoder = ctc_decoder(
        lexicon=files.lexicon,
        tokens=files.tokens,
        lm=files.lm,
        lm_weight=lm_weight,
        word_score=WORD_SCORE,
    )
    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
Out:
lm weight 0 : i rely was very much affraid of showing him ho much shoke i was at some parte of what he seid (score: 3834.05; 0.3061 secs)
lm weight 3.23: i really was very much afraid of showing him how much shocked i was at some part of what he said (score: 3699.82; 0.3269 secs)
lm weight 15 : was there in his was at some of what he said (score: 2918.98; 0.3175 secs)
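A simplified way to see the effect is to treat the overall score as the acoustic score plus lm_weight times the language model score plus a per-word term. The sketch below uses that simplified form with invented scores for two competing words; the real decoder's scoring includes further terms, so this is an illustration under stated assumptions rather than its exact formula.

```python
def total_score(acoustic_score, lm_score, num_words, lm_weight, word_score):
    """Weighted sum of acoustic score, LM score, and a per-word term."""
    return acoustic_score + lm_weight * lm_score + word_score * num_words


# Hypothetical log scores: "parte" sounds closer, "part" is likelier text.
example_candidates = {
    "parte": {"acoustic": -1.0, "lm": -6.0},
    "part": {"acoustic": -2.0, "lm": -2.0},
}


def pick_word(lm_weight, word_score=-0.26):
    """Return the candidate with the highest combined score."""
    return max(
        example_candidates,
        key=lambda w: total_score(
            example_candidates[w]["acoustic"],
            example_candidates[w]["lm"],
            1,
            lm_weight,
            word_score,
        ),
    )


print(pick_word(lm_weight=0))
print(pick_word(lm_weight=3.23))
```

With a weight of 0 the acoustically closer spelling wins, and with the tutorial's weight the language model tips the choice, matching the pattern in the outputs above.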
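additional parameters¶
Additional parameters that can be optimized include the following:
word_score: score to add when a word finishes
unk_score: unknown word appearance score to add
sil_score: silence appearance score to add
log_add: whether to use log add for lexicon Trie smearing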
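For readers unfamiliar with the term, "log add" combines two log-domain scores by summing the underlying probabilities, as opposed to simply taking the larger score. The sketch below contrasts the two operations on made-up values; how the decoder applies this during trie smearing is internal to the implementation and is not shown here.

```python
import math


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


# Two hypothetical log probabilities to be combined.
score_a, score_b = math.log(0.3), math.log(0.2)
merged_max = max(score_a, score_b)
merged_log_add = log_add(score_a, score_b)
print(math.exp(merged_max), math.exp(merged_log_add))
```

Taking the maximum keeps only the stronger of the two scores, while log add credits both, so the combined value is never smaller than the maximum.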