目录

使用 Wav2Vec2 进行语音识别

作者Moto Hira

本教程介绍如何使用 来自 wav2vec 2.0 的预训练模型 [论文]。

概述

语音识别的过程如下所示。

  1. 从音频波形中提取声学特征

  2. 逐帧估计声学特征的类别

  3. 从类概率序列生成假设

Torchaudio 提供了对预训练权重的轻松访问,并且 相关信息,例如预期采样率和类 标签。它们捆绑在一起,可在torchaudio.pipelines()模块。

制备

首先,我们导入必要的包,并获取我们处理的数据。

# %matplotlib inline

import os

import IPython
import matplotlib
import matplotlib.pyplot as plt
import requests
import torch
import torchaudio

matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]

torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(torch.__version__)
print(torchaudio.__version__)
print(device)

SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"  # noqa: E501
SPEECH_FILE = "_assets/speech.wav"

if not os.path.exists(SPEECH_FILE):
    os.makedirs("_assets", exist_ok=True)
    with open(SPEECH_FILE, "wb") as file:
        file.write(requests.get(SPEECH_URL).content)

外:

1.12.0
0.12.0
cpu

创建管道

首先,我们将创建一个执行该功能的 Wav2Vec2 模型 提取和分类。

有两种类型的 Wav2Vec2 预训练权重可用 torchaudio 中。针对 ASR 任务进行微调的 API,以及未针对 ASR 任务进行微调的 微调。

Wav2Vec2 (和 HuBERT) 模型以自我监督的方式进行训练。他们 首先使用仅用于表示学习的音频进行训练,然后 针对特定任务进行微调,并带有额外的标签。

无需微调的预训练权重可以进行微调 对于其他下游任务,但本教程不会 覆盖那个。

我们将使用torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H()这里。

有多种型号可供选择torchaudio.pipelines.请查看文档 他们如何接受训练的细节。

bundle 对象提供了实例化 model 和其他 信息。采样率和类标签如下所示。

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

print("Sample Rate:", bundle.sample_rate)

print("Labels:", bundle.get_labels())

外:

Sample Rate: 16000
Labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')

模型可以按以下方式构建。此过程将自动 获取预先训练的权重并将其加载到模型中。

model = bundle.get_model().to(device)

print(model.__class__)

外:

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth

  0%|          | 0.00/360M [00:00<?, ?B/s]
  2%|1         | 5.41M/360M [00:00<00:06, 56.5MB/s]
  3%|3         | 11.1M/360M [00:00<00:06, 58.2MB/s]
  5%|4         | 17.7M/360M [00:00<00:05, 63.2MB/s]
  7%|6         | 25.0M/360M [00:00<00:05, 68.4MB/s]
  9%|9         | 32.9M/360M [00:00<00:04, 73.6MB/s]
 11%|#1        | 40.7M/360M [00:00<00:04, 76.1MB/s]
 13%|#3        | 48.0M/360M [00:00<00:04, 71.0MB/s]
 15%|#5        | 54.8M/360M [00:00<00:04, 69.0MB/s]
 18%|#7        | 63.6M/360M [00:00<00:04, 75.9MB/s]
 20%|##        | 72.0M/360M [00:01<00:03, 79.3MB/s]
 22%|##2       | 80.4M/360M [00:01<00:03, 81.9MB/s]
 25%|##4       | 88.5M/360M [00:01<00:03, 82.4MB/s]
 27%|##6       | 96.4M/360M [00:01<00:03, 81.9MB/s]
 29%|##8       | 104M/360M [00:01<00:03, 80.0MB/s]
 31%|###1      | 112M/360M [00:01<00:03, 79.8MB/s]
 33%|###3      | 119M/360M [00:01<00:03, 76.8MB/s]
 35%|###5      | 127M/360M [00:01<00:03, 76.0MB/s]
 37%|###7      | 134M/360M [00:01<00:03, 71.7MB/s]
 40%|###9      | 142M/360M [00:01<00:02, 76.3MB/s]
 42%|####1     | 150M/360M [00:02<00:02, 74.7MB/s]
 44%|####3     | 158M/360M [00:02<00:02, 77.9MB/s]
 46%|####5     | 166M/360M [00:02<00:02, 71.7MB/s]
 48%|####8     | 173M/360M [00:02<00:02, 75.0MB/s]
 51%|#####     | 182M/360M [00:02<00:02, 79.4MB/s]
 53%|#####2    | 191M/360M [00:02<00:02, 82.2MB/s]
 55%|#####5    | 199M/360M [00:02<00:02, 84.3MB/s]
 58%|#####7    | 207M/360M [00:02<00:01, 82.5MB/s]
 60%|#####9    | 215M/360M [00:02<00:01, 80.6MB/s]
 62%|######2   | 224M/360M [00:03<00:01, 84.2MB/s]
 65%|######4   | 233M/360M [00:03<00:01, 84.6MB/s]
 67%|######7   | 241M/360M [00:03<00:01, 86.5MB/s]
 69%|######9   | 250M/360M [00:03<00:01, 85.9MB/s]
 72%|#######1  | 258M/360M [00:03<00:01, 82.6MB/s]
 74%|#######3  | 266M/360M [00:03<00:01, 82.2MB/s]
 76%|#######6  | 274M/360M [00:03<00:01, 80.1MB/s]
 78%|#######8  | 282M/360M [00:03<00:00, 82.3MB/s]
 81%|########  | 290M/360M [00:03<00:00, 74.9MB/s]
 83%|########2 | 298M/360M [00:04<00:00, 76.4MB/s]
 85%|########4 | 306M/360M [00:04<00:00, 77.6MB/s]
 87%|########7 | 315M/360M [00:04<00:00, 82.0MB/s]
 90%|########9 | 322M/360M [00:04<00:00, 81.8MB/s]
 92%|#########1| 330M/360M [00:04<00:00, 79.6MB/s]
 94%|#########4| 339M/360M [00:04<00:00, 82.1MB/s]
 97%|#########6| 348M/360M [00:04<00:00, 85.5MB/s]
 99%|#########8| 356M/360M [00:04<00:00, 87.1MB/s]
100%|##########| 360M/360M [00:04<00:00, 79.1MB/s]
<class 'torchaudio.models.wav2vec2.model.Wav2Vec2Model'>

加载数据

我们将使用来自 VOiCES 的语音数据 数据集,该数据集在 创意 Commos BY 4.0。

IPython.display.Audio(SPEECH_FILE)


为了加载数据,我们使用torchaudio.load().

如果采样率与管道预期的采样率不同,则 我们可以使用torchaudio.functional.resample()进行重新采样。

注意

waveform, sample_rate = torchaudio.load(SPEECH_FILE)
waveform = waveform.to(device)

if sample_rate != bundle.sample_rate:
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

提取声学特征

下一步是从音频中提取声学特征。

注意

针对 ASR 任务进行微调的 Wav2Vec2 模型可以执行功能 提取和分类一步完成,但为了 教程中,我们还在此处展示了如何执行特征提取。

with torch.inference_mode():
    features, _ = model.extract_features(waveform)

返回的 features 是一个 tensor 列表。每个张量都是 transformer 层。

fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))
for i, feats in enumerate(features):
    ax[i].imshow(feats[0].cpu())
    ax[i].set_title(f"Feature from transformer layer {i+1}")
    ax[i].set_xlabel("Feature dimension")
    ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
变压器层 1 的特征, 变压器层 2 的特征, 变压器层 3 的特征, 变压器层 4 的特征, 变压器层 5 的特征, 变压器层 6 的特征, 变压器层 7 的特征, 变压器层 8 的特征, 变压器层 9 的特征, 变压器层 10 的特征, 变压器层 11 的特征, 变压器层 12 的特征

特征分类

提取声学特征后,下一步是进行分类 它们被归类为一组类别。

Wav2Vec2 模型提供了执行特征提取和 分类一步完成。

输出采用 logits 的形式。它不是以 概率。

让我们想象一下。

plt.imshow(emission[0].cpu().T)
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
print("Class labels:", bundle.get_labels())
分类结果

外:

Class labels: ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')

我们可以看到,某些标签有很强的迹象 时间线。

生成转录

根据标签概率序列,现在我们要生成 成绩单。生成假设的过程通常称为 “解码”。

解码比简单分类更复杂,因为 解码在某个时间步会受到周围环境的影响 观察。

例如,以 和 之类的单词为例。即使他们的 先验概率分布是不同的(在典型的对话中,会比 ) 更频繁地发生 ),以准确 生成成绩单,例如 , 解码过程必须推迟最终决定,直到它看到 足够的背景。nightknightnightknightknighta knight with a sword

提出了许多解码技术,它们需要外部 资源,例如单词词典和语言模型。

在本教程中,为简单起见,我们将执行贪婪 解码,它不依赖于此类外部组件,并且只需 在每个时间步长中选取最佳假设。因此,上下文 信息,并且只能生成一个成绩单。

我们首先定义贪婪解码算法。

class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string
        Args:
          emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
          str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
        return "".join([self.labels[i] for i in indices])

现在创建 decoder 对象并解码转录文本。

decoder = GreedyCTCDecoder(labels=bundle.get_labels())
transcript = decoder(emission[0])

让我们检查结果并再次收听音频。

print(transcript)
IPython.display.Audio(SPEECH_FILE)

外:

I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|


ASR 模型使用称为连接主义时间分类 (CTC) 的损失函数进行微调。 CTC 丢失的详细信息在此处解释。在 CTC 中,空白令牌 (ε) 是 特殊标记,它表示前一个符号的重复。在 decoding,这些都会被简单地忽略。

结论

在本教程中,我们了解了如何使用torchaudio.pipelines自 执行声学特征提取和语音识别。构建 一个模型并获得 emission 短至两行。

model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
emission = model(waveforms, ...)

脚本总运行时间:(0 分 10.757 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源