注意
单击此处下载完整的示例代码
使用 Wav2Vec2 进行语音识别¶
作者: Moto Hira
本教程介绍如何使用 来自 wav2vec 2.0 的预训练模型 [论文]。
概述¶
语音识别的过程如下所示。
从音频波形中提取声学特征
逐帧估计声学特征的类别
从类概率序列生成假设
Torchaudio 提供了对预训练权重的轻松访问,并且
相关信息,例如预期采样率和类
标签。它们捆绑在一起,可在 module 下使用。torchaudio.pipelines
制备¶
首先,我们导入必要的包,并获取我们处理的数据。
# %matplotlib inline
import os
import torch
import torchaudio
import requests
import matplotlib
import matplotlib.pyplot as plt
import IPython
matplotlib.rcParams['figure.figsize'] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE):
os.makedirs('_assets', exist_ok=True)
with open(SPEECH_FILE, 'wb') as file:
file.write(requests.get(SPEECH_URL).content)
外:
1.10.0+cpu
0.10.0+cpu
cpu
创建管道¶
首先,我们将创建一个执行该功能的 Wav2Vec2 模型 提取和分类。
有两种类型的 Wav2Vec2 预训练权重可用 torchaudio 中。针对 ASR 任务进行微调的 API,以及未针对 ASR 任务进行微调的 微调。
Wav2Vec2 (和 HuBERT) 模型以自我监督的方式进行训练。他们 首先使用仅用于表示学习的音频进行训练,然后 针对特定任务进行微调,并带有额外的标签。
无需微调的预训练权重可以进行微调 对于其他下游任务,但本教程不会 覆盖那个。
我们将使用torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H()
这里。
有多种型号可供选择torchaudio.pipelines
.请查看文档
他们如何接受训练的细节。
bundle 对象提供了实例化 model 和其他 信息。采样率和类标签如下所示。
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
外:
Sample Rate: 16000
Labels: ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
模型可以按以下方式构建。此过程将自动 获取预先训练的权重并将其加载到模型中。
model = bundle.get_model().to(device)
print(model.__class__)
外:
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
0%| | 0.00/360M [00:00<?, ?B/s]
1%|1 | 4.09M/360M [00:00<00:09, 38.0MB/s]
3%|2 | 9.09M/360M [00:00<00:08, 41.6MB/s]
4%|3 | 14.1M/360M [00:00<00:07, 45.8MB/s]
7%|6 | 24.3M/360M [00:00<00:05, 68.6MB/s]
9%|9 | 34.1M/360M [00:00<00:05, 62.3MB/s]
11%|#1 | 40.3M/360M [00:00<00:07, 43.6MB/s]
13%|#2 | 45.2M/360M [00:01<00:09, 34.5MB/s]
14%|#3 | 49.1M/360M [00:01<00:11, 28.2MB/s]
15%|#5 | 55.6M/360M [00:01<00:09, 34.5MB/s]
17%|#6 | 60.2M/360M [00:01<00:08, 37.3MB/s]
20%|## | 73.0M/360M [00:01<00:05, 58.6MB/s]
23%|##2 | 82.8M/360M [00:01<00:04, 68.8MB/s]
25%|##5 | 90.5M/360M [00:01<00:04, 66.8MB/s]
27%|##7 | 97.6M/360M [00:02<00:04, 58.5MB/s]
29%|##8 | 104M/360M [00:02<00:04, 57.3MB/s]
30%|### | 110M/360M [00:02<00:05, 49.5MB/s]
32%|###1 | 115M/360M [00:02<00:05, 45.9MB/s]
35%|###5 | 128M/360M [00:02<00:03, 62.9MB/s]
37%|###7 | 134M/360M [00:02<00:04, 55.6MB/s]
40%|###9 | 144M/360M [00:02<00:03, 65.3MB/s]
42%|####2 | 152M/360M [00:03<00:03, 71.0MB/s]
44%|####4 | 160M/360M [00:03<00:03, 61.1MB/s]
46%|####6 | 166M/360M [00:03<00:05, 40.1MB/s]
50%|##### | 181M/360M [00:03<00:03, 61.7MB/s]
53%|#####3 | 192M/360M [00:03<00:02, 66.4MB/s]
57%|#####6 | 205M/360M [00:03<00:02, 80.2MB/s]
59%|#####9 | 214M/360M [00:04<00:02, 61.2MB/s]
62%|######1 | 223M/360M [00:04<00:02, 56.8MB/s]
64%|######3 | 229M/360M [00:04<00:03, 42.7MB/s]
67%|######6 | 240M/360M [00:04<00:02, 50.5MB/s]
71%|####### | 255M/360M [00:04<00:01, 57.6MB/s]
72%|#######2 | 261M/360M [00:05<00:02, 47.0MB/s]
75%|#######4 | 269M/360M [00:05<00:01, 52.8MB/s]
76%|#######6 | 275M/360M [00:05<00:01, 49.3MB/s]
80%|#######9 | 288M/360M [00:05<00:01, 62.7MB/s]
82%|########1 | 294M/360M [00:05<00:01, 55.5MB/s]
83%|########3 | 300M/360M [00:05<00:01, 46.0MB/s]
85%|########4 | 305M/360M [00:06<00:01, 45.1MB/s]
89%|########8 | 320M/360M [00:06<00:00, 67.2MB/s]
91%|######### | 327M/360M [00:06<00:00, 68.3MB/s]
93%|#########3| 336M/360M [00:06<00:00, 71.6MB/s]
96%|#########6| 347M/360M [00:06<00:00, 83.0MB/s]
99%|#########8| 356M/360M [00:06<00:00, 67.2MB/s]
100%|##########| 360M/360M [00:06<00:00, 55.7MB/s]
<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>
加载数据¶
我们将使用来自 VOiCES 的语音数据 数据集,该数据集在 创意 Commos BY 4.0。
IPython.display.Audio(SPEECH_FILE)
为了加载数据,我们使用torchaudio.load()
.
如果采样率与管道预期的采样率不同,则
我们可以使用torchaudio.functional.resample()
进行重新采样。
注意
torchaudio.functional.resample()
也适用于 CUDA 张量。当对同一组采样率执行多次重采样时, 用
torchaudio.transforms.Resample()
可能会提高性能。
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
提取声学特征¶
下一步是从音频中提取声学特征。
注意
针对 ASR 任务进行微调的 Wav2Vec2 模型可以执行功能 提取和分类一步完成,但为了 教程中,我们还在此处展示了如何执行特征提取。
with torch.inference_mode():
features, _ = model.extract_features(waveform)
返回的 features 是一个 tensor 列表。每个张量都是 transformer 层。
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()

特征分类¶
提取声学特征后,下一步是进行分类 它们被归类为一组类别。
Wav2Vec2 模型提供了执行特征提取和 分类一步完成。
with torch.inference_mode():
emission, _ = model(waveform)
输出采用 logits 的形式。它不是以 概率。
让我们想象一下。
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())

外:
Class labels: ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
我们可以看到,某些标签有很强的迹象 时间线。
生成转录¶
根据标签概率序列,现在我们要生成 成绩单。生成假设的过程通常称为 “解码”。
解码比简单分类更复杂,因为 解码在某个时间步会受到周围环境的影响 观察。
例如,以 和 之类的单词为例。即使他们的
先验概率分布是不同的(在典型的对话中,会比 ) 更频繁地发生 ),以准确
生成成绩单,例如 ,
解码过程必须推迟最终决定,直到它看到
足够的背景。night
knight
night
knight
knight
a knight with a sword
提出了许多解码技术,它们需要外部 资源,例如单词词典和语言模型。
在本教程中,为简单起见,我们将执行贪婪 解码,它不依赖于此类外部组件,并且只需 在每个时间步长中选取最佳假设。因此,上下文 信息,并且只能生成一个成绩单。
我们首先定义贪婪解码算法。
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return ''.join([self.labels[i] for i in indices])
现在创建 decoder 对象并解码转录文本。
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
让我们检查结果并再次收听音频。
print(transcript)
IPython.display.Audio(SPEECH_FILE)
外:
I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|
ASR 模型使用称为连接主义时间分类 (CTC) 的损失函数进行微调。 CTC 丢失的详细信息在此处解释。在 CTC 中,空白令牌 (ε) 是 特殊标记,它表示前一个符号的重复。在 decoding,这些都会被简单地忽略。
结论¶
在本教程中,我们了解了如何使用torchaudio.pipelines
自
执行声学特征提取和语音识别。构建
一个模型并获得 emission 短至两行。
model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
emission = model(waveforms, ...)
脚本总运行时间:(0 分 14.667 秒)