注意
单击此处下载完整的示例代码
每个样本的梯度¶
创建时间: 2023年3月15日 |上次更新时间:2024 年 4 月 24 日 |上次验证: Nov 05, 2024
这是什么?¶
每样本梯度计算是计算每个 batch 数据中的 sample。它是差分隐私中的一个有用量, 元学习和优化研究。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple CNN and loss function:
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets)
让我们生成一批虚拟数据,并假设我们正在使用 MNIST 数据集。 虚拟图像为 28 x 28,我们使用大小为 64 的小批量。
device = 'cuda'
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (64,), device=device)
在常规模型训练中,用户可以通过模型 然后调用 .backward() 来计算梯度。这将生成一个 整个小批量的 'average' 梯度:
model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model
loss = loss_fn(predictions, targets)
loss.backward() # back propagate the 'average' gradient of this mini-batch
与上述方法相反,每样本梯度计算为 相当于:
对于每个单独的数据样本,执行 FORWARD 和 BACKWARD pass 获取单个 (每个样本) 梯度。
def compute_grad(sample, target):
sample = sample.unsqueeze(0) # prepend batch dimension for processing
target = target.unsqueeze(0)
prediction = model(sample)
loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
per_sample_grads = compute_sample_grads(data, targets)
sample_grads[0]
是 model.conv1.weight 的每 sample-grad。 是;注意怎么有一个
梯度,每个样品,在批次中总共 64 个。model.conv1.weight.shape
[32, 1, 3, 3]
print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])
Per-sample-grads,使用函数转换的有效方法¶
我们可以通过使用函数变换有效地计算每个样本的梯度。
函数转换 API 对函数进行转换。
我们的策略是定义一个计算损失的函数,然后应用
transforms 来构造一个计算每个样本梯度的函数。torch.func
我们将使用该函数将 like 函数视为 like。torch.func.functional_call
nn.Module
首先,让我们将状态提取到两个字典中,
参数和缓冲区。我们将分离它们,因为我们不会使用
常规 PyTorch autograd(例如 Tensor.backward()、torch.autograd.grad)。model
from torch.func import functional_call, vmap, grad
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
接下来,让我们定义一个函数来计算模型的损失,给定一个 单个输入,而不是一批输入。重要的是,这个 function 接受参数、输入和目标,因为我们将 正在改造他们。
注意 - 因为模型最初是为了处理批处理而编写的,所以我们将
用于添加批次维度。torch.unsqueeze
def compute_loss(params, buffers, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
predictions = functional_call(model, (params, buffers), (batch,))
loss = loss_fn(predictions, targets)
return loss
现在,让我们使用转换来创建一个新函数,该函数计算
相对于 的第一个参数的梯度(即 )。grad
compute_loss
params
ft_compute_grad = grad(compute_loss)
该函数计算单个
(sample, target) 对。我们可以用它来计算梯度
在整批样品和靶标上。请注意,因为我们希望映射
data 和 targets 的第 0 个维度,并使用相同的 和
buffers 的 buffer。ft_compute_grad
vmap
in_dims=(None, None, 0, 0)
ft_compute_grad
params
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
最后,让我们使用转换后的函数来计算每个样本的梯度:
我们可以仔细检查结果是否使用 并匹配
手工处理的结果:grad
vmap
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)
快速说明:函数的类型存在限制
由 转换。最适合转换的函数是纯函数
functions:输出仅由输入决定的函数,
并且没有副作用(例如突变)。 无法处理
任意 Python 数据结构的 mutation,但它能够处理许多
就地 PyTorch 操作。vmap
vmap
性能比较¶
好奇 的性能比较如何?vmap
目前,在较新的 GPU (如 A100)上获得的最佳效果 (Ampere) 中,我们看到此示例的加速高达 25 倍,但这里是 我们的构建计算机上的一些结果:
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
second_res = second.times[0]
first_res = first.times[0]
gain = (first_res-second_res)/first_res
if gain < 0: gain *=-1
final_gain = gain*100
print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f104824fc40>
compute_sample_grads(data, targets)
102.96 ms
1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f104474fee0>
ft_compute_sample_grad(params, buffers, data, targets)
8.62 ms
1 measurement, 100 runs , 1 thread
Performance delta: 1094.4079 percent improvement with vmap
还有其他优化的解决方案(如 https://github.com/pytorch/opacus)
到 PyTorch 中计算每个样本的梯度,其性能也优于
朴素的方法。但是,创作并给我们一个
不错的加速。vmap
grad
一般来说,使用进行矢量化应该比运行函数更快
在 for 循环中,与手动批处理竞争。有一些例外
不过,就像我们没有为特定的
操作,或者底层内核未针对较旧的硬件进行优化
(GPU)。如果您看到任何这些情况,请通过打开一个 issue 来告知我们
在 GitHub 上。vmap
vmap
脚本总运行时间:(0 分 12.286 秒)