注意
点击 这里 下载完整示例代码
语音识别与Wav2Vec2¶
作者: Moto Hira
本教程展示了如何使用预训练的wav2vec 2.0模型进行语音识别 [论文]。
概述¶
语音识别的过程如下所示。
从音频波形中提取声学特征
逐帧估计声学特征的类别
从类别概率序列中生成假设
Torchaudio 提供了对预训练权重及相关信息(如预期采样率和类别标签)的便捷访问。它们被打包在一起,并在
torchaudio.pipelines() 模块下提供。
准备¶
首先,我们导入必要的包,并获取我们要处理的数据。
# %matplotlib inline
import os
import IPython
import matplotlib
import matplotlib.pyplot as plt
import requests
import torch
import torchaudio
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__)
print(torchaudio.__version__)
print(device)
SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SPEECH_FILE = "_assets/speech.wav"
if not os.path.exists(SPEECH_FILE):
os.makedirs("_assets", exist_ok=True)
with open(SPEECH_FILE, "wb") as file:
file.write(requests.get(SPEECH_URL).content)
Out:
1.12.0
0.12.0
cpu
创建一个流程¶
首先,我们将创建一个 Wav2Vec2 模型,该模型负责特征提取和分类。
在 torchaudio 中有两种类型的 Wav2Vec2 预训练权重可供使用。一种是针对 ASR 任务进行微调的,另一种是没有进行微调的。
Wav2Vec2(以及HuBERT)模型是以自监督的方式进行训练的。它们首先仅使用音频进行训练以学习表示,然后通过附加标签对特定任务进行微调。
未进行微调的预训练权重也可以用于其他下游任务的微调,但本教程不涵盖这一点。
我们在这里将使用 torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H()。
有多种模型可供选择,如
torchaudio.pipelines。请查阅文档以了解它们的训练细节。
该 bundle 对象提供了实例化模型和其他信息的接口。采样率和类别标签的获取方式如下。
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
print("Sample Rate:", bundle.sample_rate)
print("Labels:", bundle.get_labels())
Out:
Sample Rate: 16000
Labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
模型可以按以下方式构建。此过程将自动获取预训练的权重并将其加载到模型中。
model = bundle.get_model().to(device)
print(model.__class__)
Out:
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
0%| | 0.00/360M [00:00<?, ?B/s]
2%|1 | 5.41M/360M [00:00<00:06, 56.5MB/s]
3%|3 | 11.1M/360M [00:00<00:06, 58.2MB/s]
5%|4 | 17.7M/360M [00:00<00:05, 63.2MB/s]
7%|6 | 25.0M/360M [00:00<00:05, 68.4MB/s]
9%|9 | 32.9M/360M [00:00<00:04, 73.6MB/s]
11%|#1 | 40.7M/360M [00:00<00:04, 76.1MB/s]
13%|#3 | 48.0M/360M [00:00<00:04, 71.0MB/s]
15%|#5 | 54.8M/360M [00:00<00:04, 69.0MB/s]
18%|#7 | 63.6M/360M [00:00<00:04, 75.9MB/s]
20%|## | 72.0M/360M [00:01<00:03, 79.3MB/s]
22%|##2 | 80.4M/360M [00:01<00:03, 81.9MB/s]
25%|##4 | 88.5M/360M [00:01<00:03, 82.4MB/s]
27%|##6 | 96.4M/360M [00:01<00:03, 81.9MB/s]
29%|##8 | 104M/360M [00:01<00:03, 80.0MB/s]
31%|###1 | 112M/360M [00:01<00:03, 79.8MB/s]
33%|###3 | 119M/360M [00:01<00:03, 76.8MB/s]
35%|###5 | 127M/360M [00:01<00:03, 76.0MB/s]
37%|###7 | 134M/360M [00:01<00:03, 71.7MB/s]
40%|###9 | 142M/360M [00:01<00:02, 76.3MB/s]
42%|####1 | 150M/360M [00:02<00:02, 74.7MB/s]
44%|####3 | 158M/360M [00:02<00:02, 77.9MB/s]
46%|####5 | 166M/360M [00:02<00:02, 71.7MB/s]
48%|####8 | 173M/360M [00:02<00:02, 75.0MB/s]
51%|##### | 182M/360M [00:02<00:02, 79.4MB/s]
53%|#####2 | 191M/360M [00:02<00:02, 82.2MB/s]
55%|#####5 | 199M/360M [00:02<00:02, 84.3MB/s]
58%|#####7 | 207M/360M [00:02<00:01, 82.5MB/s]
60%|#####9 | 215M/360M [00:02<00:01, 80.6MB/s]
62%|######2 | 224M/360M [00:03<00:01, 84.2MB/s]
65%|######4 | 233M/360M [00:03<00:01, 84.6MB/s]
67%|######7 | 241M/360M [00:03<00:01, 86.5MB/s]
69%|######9 | 250M/360M [00:03<00:01, 85.9MB/s]
72%|#######1 | 258M/360M [00:03<00:01, 82.6MB/s]
74%|#######3 | 266M/360M [00:03<00:01, 82.2MB/s]
76%|#######6 | 274M/360M [00:03<00:01, 80.1MB/s]
78%|#######8 | 282M/360M [00:03<00:00, 82.3MB/s]
81%|######## | 290M/360M [00:03<00:00, 74.9MB/s]
83%|########2 | 298M/360M [00:04<00:00, 76.4MB/s]
85%|########4 | 306M/360M [00:04<00:00, 77.6MB/s]
87%|########7 | 315M/360M [00:04<00:00, 82.0MB/s]
90%|########9 | 322M/360M [00:04<00:00, 81.8MB/s]
92%|#########1| 330M/360M [00:04<00:00, 79.6MB/s]
94%|#########4| 339M/360M [00:04<00:00, 82.1MB/s]
97%|#########6| 348M/360M [00:04<00:00, 85.5MB/s]
99%|#########8| 356M/360M [00:04<00:00, 87.1MB/s]
100%|##########| 360M/360M [00:04<00:00, 79.1MB/s]
<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>
加载数据¶
我们将使用来自VOiCES数据集的语音数据,该数据集授权于 Creative Commos BY 4.0。
IPython.display.Audio(SPEECH_FILE)
要加载数据,我们使用 torchaudio.load()。
如果采样率与管道预期的不同,那么
我们可以使用 torchaudio.functional.resample() 进行重采样。
注意
torchaudio.functional.resample()也可以在CUDA张量上运行。当对同一组采样率进行多次重采样时, 使用
torchaudio.transforms.Resample()可能会提高性能。
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)
if sample_rate != bundle.sample_rate:
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
提取音频特征¶
下一步是从音频中提取声学特征。
注意
为ASR任务微调的Wav2Vec2模型可以一步完成特征提取和分类,但为了教程的完整性,我们在这里也展示如何进行特征提取。
with torch.inference_mode():
features, _ = model.extract_features(waveform)
返回的特征是一个张量列表。每个张量是某个 transformer 层的输出。
fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
ax[i].imshow(feats[0].cpu())
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()

特征分类¶
一旦提取了声学特征,下一步就是将它们分类到一组类别中。
Wav2Vec2 模型提供了一种在一步中完成特征提取和分类的方法。
with torch.inference_mode():
emission, _ = model(waveform)
输出的形式是logits。它不是以概率的形式呈现的。
让我们来可视化一下。
plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())

Out:
Class labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
我们可以看到,在时间轴上某些标签有很强的指示性。
生成字幕¶
从标签概率序列中,我们现在想要生成文本。生成假设的过程通常被称为“解码”。
解码比简单的分类更为复杂,因为某些时间步的解码可能会受到周围观测值的影响。
例如,取一个词,比如 night 和 knight。即使它们的
先验概率分布不同(在典型的对话中,night 出现的频率远高于 knight),为了准确
地生成包含 knight 的文本,例如 a knight with a sword,
解码过程必须推迟最终决定,直到看到
足够的上下文。
已经提出了许多解码技术,它们需要外部资源,例如词典和语言模型。
在本教程中,为了简化起见,我们将执行贪心解码,这种解码方式不依赖于这些外部组件,而是在每个时间步简单地选择最佳假设。因此,上下文信息未被使用,只能生成一个转录文本。
我们首先从定义贪心解码算法开始。
class GreedyCTCDecoder(torch.nn.Module):
def __init__(self, labels, blank=0):
super().__init__()
self.labels = labels
self.blank = blank
def forward(self, emission: torch.Tensor) -> str:
"""Given a sequence emission over labels, get the best path string
Args:
emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
Returns:
str: The resulting transcript
"""
indices = torch.argmax(emission, dim=-1) # [num_seq,]
indices = torch.unique_consecutive(indices, dim=-1)
indices = [i for i in indices if i != self.blank]
return "".join([self.labels[i] for i in indices])
现在创建解码器对象并解码转录文本。
decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])
让我们查看结果,并再次聆听音频。
print(transcript)
IPython.display.Audio(SPEECH_FILE)
Out:
I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|
ASR模型使用称为连接时序分类(CTC)的损失函数进行微调。 CTC损失的详细信息在 这里解释。在CTC中,空白令牌(ϵ)是一个 特殊令牌,表示前一个符号的重复。在 解码过程中,这些被简单地忽略。
结论¶
在本教程中,我们学习了如何使用 torchaudio.pipelines 来
执行声学特征提取和语音识别。构建模型并获取发射结果仅需两行代码。
model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
emission = model(waveforms, ...)
脚本的总运行时间: ( 0 分钟 10.757 秒)