目录

模型集成

创建时间: 2023年3月15日 |上次更新时间:2024 年 1 月 16 日 |上次验证: Nov 05, 2024

本教程说明了如何使用 .torch.vmap

什么是模型集成?

模型集成将来自多个模型的预测组合在一起。 传统上,这是通过在某些 inputs 上单独运行每个模型来完成的 然后组合预测。但是,如果您使用 相同的架构,那么也许可以将它们组合在一起 用。 是一个函数转换,它将函数映射到 维度。它的用例之一是消除 for 循环,并通过矢量化加速它们。torch.vmapvmap

让我们演示如何使用一组简单的 MLP 来做到这一点。

注意

本教程需要 PyTorch 2.0.0 或更高版本。

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple MLP
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

让我们生成一批虚拟数据,并假装我们正在处理 一个 MNIST 数据集。因此,虚拟图像是 28 x 28,并且我们有一个 大小为 64 的小批量。此外,假设我们想要合并预测 来自 10 种不同的型号。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我们有几个选项可用于生成预测。也许我们想 为每个模型提供不同的随机小批量数据。或者 也许我们想通过每个模型运行相同的小批量数据(例如 如果我们要测试不同模型初始化的效果)。

选项 1:每个模型的不同小批量

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

选项 2:相同的小批量

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

用于矢量化融合vmap

让我们 use 来加速 for 循环。我们必须先准备模型 用于 。vmapvmap

首先,让我们通过将每个 参数。例如,具有 shape ;我们是 将 10 个模型中的每一个堆叠起来,以产生一个大的 形状的重量 .model[i].fc1.weight[784, 128].fc1.weight[10, 784, 128]

PyTorch 提供了方便的功能来执行 这。torch.func.stack_module_state

from torch.func import stack_module_state

params, buffers = stack_module_state(models)

接下来,我们需要定义一个 over 函数。该函数应 给定参数、缓冲区和输入,使用它们运行模型 参数、缓冲区和输入。我们将用于帮助:vmaptorch.func.functional_call

from torch.func import functional_call
import copy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

选项 1:为每个模型使用不同的小批量获取预测。

默认情况下,将所有输入的第一个维度的函数映射到 传入的函数。使用 后,每个 and 缓冲区num_models在 front 和 minibatches 的维度大小为 'num_models”。vmapstack_module_stateparams

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

from torch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]

选项 2:使用相同的小批量数据获取预测。

vmap具有指定要映射的维度的参数。 通过使用 ,我们告诉我们希望同一个 minibatch 应用于所有 10 款车型。in_dimsNonevmap

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

快速说明:函数的类型存在限制 由 转换。最适合转换的函数是纯函数 functions:输出仅由输入决定的函数 没有副作用(例如突变)。 无法处理突变 的任意 Python 数据结构,但它能够就地处理许多 PyTorch 操作。vmapvmap

性能

对性能数据感到好奇?以下是数字的样子。

from torch.utils.benchmark import Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fd43694c670>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  2.61 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fd43694ca90>
vmap(fmodel)(params, buffers, minibatches)
  894.99 us
  1 measurement, 100 runs , 1 thread

使用 !vmap

一般来说,使用进行矢量化应该比运行函数更快 在 for 循环中,与手动批处理竞争。有一些例外 不过,就像我们没有为特定的 操作,或者底层内核未针对较旧的硬件进行优化 (GPU)。如果您看到任何这些情况,请通过打开一个 issue 来告知我们 在 GitHub 上。vmapvmap

脚本总运行时间:(0 分 0.896 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源