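Fusing adjacent convolution and batch norm layers together is typically an inference-time optimization to improve run-time. It is usually achieved by eliminating the batch norm layer entirely and updating the weight and bias of the preceding convolution [0]. However, this technique is not applicable for training models.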
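For contrast, the inference-time fusion mentioned above can be sketched as follows. This is a minimal sketch, not the method of this tutorial. It assumes an `nn.Conv2d` followed by an `nn.BatchNorm2d` in eval mode that tracks running statistics, CPU float32 tensors, and the default zero padding mode. `fold_bn_into_conv` is a hypothetical helper, not a PyTorch API. In eval mode, batch norm is a fixed per-channel affine map, so it can be absorbed into the convolution's weight and bias.

```python
import torch
import torch.nn as nn


def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: fold an eval-mode BatchNorm2d into the preceding Conv2d.

    Assumes bn tracks running statistics and conv uses the default padding mode.
    """
    C = conv.out_channels
    gamma = bn.weight if bn.affine else torch.ones(C)
    beta = bn.bias if bn.affine else torch.zeros(C)
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = gamma / torch.sqrt(bn.running_var + bn.eps)

    fused = nn.Conv2d(conv.in_channels, C, conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, dilation=conv.dilation,
                      groups=conv.groups, bias=True)
    with torch.no_grad():
        # W' = W * scale, applied to each output channel's filter
        fused.weight.copy_(conv.weight * scale[:, None, None, None])
        # b' = (b - running_mean) * scale + beta, with b = 0 if the conv has no bias
        b = conv.bias if conv.bias is not None else torch.zeros(C)
        fused.bias.copy_((b - bn.running_mean) * scale + beta)
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(3, 4, 3, bias=False)
bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)   # non-trivial affine parameters
    bn.bias.uniform_(-0.5, 0.5)
    for _ in range(3):             # training-mode passes populate the running stats
        bn(conv(torch.randn(8, 3, 6, 6)))
bn.eval()

x = torch.randn(2, 3, 6, 6)
fused = fold_bn_into_conv(conv, bn)
print(torch.allclose(fused(x), bn(conv(x)), atol=1e-5))
```

The folded layer has no batch norm at all, which is exactly why this approach cannot be used while training: the batch statistics change with every batch.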
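In this tutorial, we are going to show a different technique to fuse the two layers that can be applied during training. Rather than improved run-time, the objective of this optimization is to reduce memory usage.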
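The idea behind this optimization is that both convolution and batch norm (as well as many other ops) need to save a copy of their input during forward for the backward pass. For large batch sizes, these saved inputs are responsible for most of your memory usage, so being able to avoid allocating another input tensor for every convolution-batch norm pair can be a significant reduction.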
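One way to observe this is to count the bytes that autograd saves for backward. The sketch below is an illustration only, not part of the tutorial's method. It uses `torch.autograd.graph.saved_tensors_hooks` (available in recent PyTorch releases) to add up the sizes of saved tensors. `saved_tensor_bytes` is a hypothetical helper. Comparing a convolution alone with a convolution followed by batch norm shows that the pair saves at least one extra buffer the size of the convolution output.

```python
import torch
import torch.nn as nn


def saved_tensor_bytes(fn, *inputs):
    """Hypothetical helper: total bytes of tensors autograd saves while running fn."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor unchanged; we only measure it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn(*inputs)
    return total


conv = nn.Conv2d(3, 16, 3, bias=False)
bn = nn.BatchNorm2d(16, affine=False, track_running_stats=False)
x = torch.randn(32, 3, 28, 28, requires_grad=True)

conv_only = saved_tensor_bytes(conv, x)
conv_bn = saved_tensor_bytes(lambda t: bn(conv(t)), x)
out_bytes = conv(x).numel() * 4  # float32 conv output, 32 x 16 x 26 x 26
print(conv_only, conv_bn, out_bytes)
```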
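In this tutorial, we avoid this extra allocation by combining convolution and batch norm into a single layer (as a custom function). In the forward of this combined layer, we perform normal convolution and batch norm as-is, with the only difference being that we only save the inputs to the convolution. To obtain the input of batch norm, which is necessary to backward through it, we recompute the convolution forward again during the backward pass.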
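The recomputation idea does not depend on convolution or batch norm. As a minimal sketch, the hypothetical op below computes `tanh(X @ W)`. Its forward saves only the inputs `X` and `W`, and its backward recomputes the intermediate `X @ W` instead of keeping it alive between the two passes. `torch.utils.checkpoint` offers a generic form of the same trade-off.

```python
import torch


class RecomputedLinearTanh(torch.autograd.Function):
    """Hypothetical example: out = tanh(X @ W), saving only the inputs."""

    @staticmethod
    def forward(ctx, X, W):
        ctx.save_for_backward(X, W)  # the intermediate X @ W is NOT saved
        return torch.tanh(X @ W)

    @staticmethod
    def backward(ctx, grad_out):
        X, W = ctx.saved_tensors
        H = X @ W  # recompute the intermediate during backward
        grad_H = grad_out * (1 - torch.tanh(H) ** 2)  # d/dh tanh(h) = 1 - tanh(h)^2
        return grad_H @ W.T, X.T @ grad_H


X = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
W = torch.rand(3, 2, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(RecomputedLinearTanh.apply, (X, W)))
```

The cost is one extra matrix multiply in backward in exchange for not storing `H`; the fused layer below makes the same trade with a convolution.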
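It is important to note that the usefulness of this optimization is situational. Though (by avoiding one saved buffer) we always reduce the memory allocated at the end of the forward pass, there are cases when the *peak* memory allocated may not actually be reduced. See the final section for more details.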
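For simplicity, in this tutorial we hardcode bias=False, stride=1, padding=0, dilation=1, and groups=1 for Conv2D. For BatchNorm2D, we hardcode eps=1e-3, momentum=0.1, affine=False, and track_running_stats=False. Another small difference is that we add epsilon in the denominator outside of the square root in the computation of batch norm.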
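Concretely, writing $N$ for the number of elements per channel (batch size × height × width, `X.numel() / X.size(1)` in the code below), the forward pass implemented in this tutorial computes, for each channel $c$, the following. Note that the variance is the unbiased estimate, since the code calls `X.var(unbiased=True)`:

```latex
\begin{aligned}
\mu_c &= \frac{1}{N} \sum_{n,h,w} x_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{N-1} \sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2 \\
y_{n,c,h,w} &= \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2} + \epsilon},
\qquad \epsilon = 10^{-3}
\end{aligned}
```

By contrast, `nn.BatchNorm2d` in training mode computes $(x_{n,c,h,w} - \mu_c)/\sqrt{\hat\sigma_c^2 + \epsilon}$ with the biased estimate $\hat\sigma_c^2 = \frac{1}{N}\sum_{n,h,w}(x_{n,c,h,w} - \mu_c)^2$, so the outputs of the two differ slightly.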
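Implementing a custom function requires us to implement the backward ourselves. In this case, we need backward formulas for both Conv2D and BatchNorm2D. Eventually we chain them together in our unified backward function, but below we first implement them as their own custom functions so we can validate their correctness individually.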
import torch
from torch.autograd.function import once_differentiable
import torch.nn.functional as F

def convolution_backward(grad_out, X, weight):
    # Gradient w.r.t. the weight: convolve the input with grad_out, swapping the
    # batch and channel dimensions so that the batch dimension is summed over
    grad_weight = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient w.r.t. the input: transposed convolution of grad_out with the weight
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_weight

class Conv2D(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight):
        ctx.save_for_backward(X, weight)
        return F.conv2d(X, weight)

    # Use @once_differentiable by default unless we intend to double backward
    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, weight = ctx.saved_tensors
        return convolution_backward(grad_out, X, weight)
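When testing with gradcheck, it is important to use double precision, because gradcheck compares our analytical gradients against finite-difference estimates, which are not accurate enough in single precision.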
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Conv2D.apply, (X, weight))
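To see why the transposes in `convolution_backward` produce the weight gradient, it can help to write the same quantity out explicitly. The sketch below assumes the tutorial's hardcoded stride=1, padding=0, dilation=1 and groups=1. It computes the weight gradient with `F.unfold` and `torch.einsum` (`conv_weight_grad_via_unfold` is a hypothetical name) and compares both it and the transpose trick against autograd.

```python
import torch
import torch.nn.functional as F


def conv_weight_grad_via_unfold(grad_out, X, kernel_size):
    """Hypothetical helper: dL/dW for a stride-1, unpadded, ungrouped conv2d.

    dL/dW[o, i, p, q] = sum over n, h, w of grad_out[n, o, h, w] * X[n, i, h + p, w + q]
    """
    N, C_out = grad_out.shape[:2]
    cols = F.unfold(X, kernel_size)      # (N, C_in * k * k, H_out * W_out)
    g = grad_out.reshape(N, C_out, -1)   # (N, C_out, H_out * W_out)
    grad_w = torch.einsum('nol,nkl->ok', g, cols)
    return grad_w.reshape(C_out, X.size(1), kernel_size, kernel_size)


torch.manual_seed(0)
X = torch.rand(4, 3, 7, 7, dtype=torch.double)
W = torch.rand(5, 3, 3, 3, dtype=torch.double, requires_grad=True)
out = F.conv2d(X, W)
grad_out = torch.rand_like(out)

(expected,) = torch.autograd.grad(out, W, grad_out)
# The transpose trick: the batch dimension plays the role of the summed channel
trick = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
print(torch.allclose(conv_weight_grad_via_unfold(grad_out, X, 3), expected),
      torch.allclose(trick, expected))
```

The unfold version materializes every input patch, so it is far more memory hungry; the single `F.conv2d` call computes the same sum without that buffer.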
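Backward Formula Implementation for Batch Norm
Batch Norm has two modes: training and eval mode. In training mode, the sample statistics are a function of the inputs. In eval mode, we use the saved running statistics, which are not a function of the inputs. This makes the eval-mode backward significantly simpler. Below, we implement and test only the training-mode case.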
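To illustrate why eval mode is simpler: there the normalization is a fixed per-channel affine map of the input, so the input gradient is just `grad_out` rescaled per channel. The sketch below follows PyTorch's own `F.batch_norm` convention (ε inside the square root, no affine parameters) rather than this tutorial's. `eval_batch_norm_backward` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def eval_batch_norm_backward(grad_out, running_var, eps):
    """Hypothetical helper: input gradient of eval-mode batch norm (no affine params).

    out = (X - running_mean) / sqrt(running_var + eps) is affine in X, so
    d out / d X is a per-channel constant.
    """
    return grad_out / torch.sqrt(running_var + eps)[None, :, None, None]


torch.manual_seed(0)
X = torch.randn(4, 3, 5, 5, dtype=torch.double, requires_grad=True)
running_mean = torch.randn(3, dtype=torch.double)
running_var = torch.rand(3, dtype=torch.double) + 0.5
out = F.batch_norm(X, running_mean, running_var, training=False, eps=1e-5)
grad_out = torch.randn_like(out)
(expected,) = torch.autograd.grad(out, X, grad_out)
print(torch.allclose(eval_batch_norm_backward(grad_out, running_var, 1e-5), expected))
```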
def unsqueeze_all(t):
    # Helper function to ``unsqueeze`` all the dimensions that we reduce over
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # We use the formula: ``out = (X - mean(X)) / (sqrt(var(X)) + eps)``
    # in batch norm 2D forward. To simplify our derivation, we follow the
    # chain rule and compute the gradients as follows before accumulating
    # them all into a final grad_input.
    #  1) ``grad of out wrt var(X)`` * ``grad of var(X) wrt X``
    #  2) ``grad of out wrt mean(X)`` * ``grad of mean(X) wrt X``
    #  3) ``grad of out wrt X in the numerator`` * ``grad of X wrt X``
    # We then rewrite the formulas to use as few extra buffers as possible
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2  # ``d_denom = -num / denom**2``
    # It is useful to delete tensors when you no longer need them with ``del``
    # For example, we could've done ``del tmp`` here because we won't use it later
    # In this case, it's not a big difference because ``tmp`` only has size of (C,)
    # The important thing is to avoid allocating NCHW-sized tensors unnecessarily
    d_var = d_denom / (2 * sqrt_var)  # ``denom = torch.sqrt(var) + eps``
    # Compute ``d_mean_dx`` before allocating the final NCHW-sized grad_input buffer
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    # ``d_mean_dx`` has already been reassigned to a C-sized buffer so no need to worry

    # (1) ``unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)``
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    # (2) mean (see above)
    grad_input += d_mean_dx
    # (3) Add 'grad_out / <factor>' without allocating an extra buffer
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)  # ``sqrt_var + eps > 0!``
    return grad_input

class BatchNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, eps=1e-3):
        # Don't save ``keepdim`` values for backward
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        # The input is needed by batch_norm_backward
        ctx.save_for_backward(X)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        X, = ctx.saved_tensors
        # One gradient per forward input: X and eps (None for eps)
        return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps), None
a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)
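For reference, the same training-mode gradient can be written without the in-place buffer reuse, with one term per chain-rule path listed in the comments of `batch_norm_backward`. This is a readable but memory-hungry sketch (it allocates several NCHW-sized temporaries) under the tutorial's conventions (unbiased variance, ε outside the square root). `batch_norm_backward_reference` is a hypothetical name, and it is checked against autograd applied to the forward formula.

```python
import torch


def batch_norm_backward_reference(grad_out, X, eps=1e-3):
    """Hypothetical helper: readable training-mode gradient for
    out = (X - mean) / (sqrt(unbiased_var) + eps), reducing over dims (0, 2, 3).
    Allocates several NCHW-sized temporaries, unlike the in-place version above.
    """
    dims = (0, 2, 3)
    N = X.numel() / X.size(1)                 # elements per channel
    centered = X - X.mean(dim=dims, keepdim=True)
    s = X.var(dim=dims, unbiased=True, keepdim=True).sqrt()
    d = s + eps                               # the denominator
    # (3) out depends on X directly through the numerator
    term_num = grad_out / d
    # (2) ... and through the mean
    term_mean = -grad_out.sum(dim=dims, keepdim=True) / (N * d)
    # (1) ... and through the variance in the denominator
    term_var = (-(grad_out * centered).sum(dim=dims, keepdim=True)
                * centered / (d ** 2 * s * (N - 1)))
    return term_num + term_mean + term_var


torch.manual_seed(0)
X = torch.rand(2, 3, 4, 5, dtype=torch.double, requires_grad=True)
dims = (0, 2, 3)
out = (X - X.mean(dim=dims, keepdim=True)) / (
    X.var(dim=dims, unbiased=True, keepdim=True).sqrt() + 1e-3)
grad_out = torch.randn_like(out)
(expected,) = torch.autograd.grad(out, X, grad_out)
print(torch.allclose(batch_norm_backward_reference(grad_out, X.detach()), expected))
```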
Fusing Convolution and BatchNorm
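Now that the bulk of the work has been done, we can combine them together. Note that in (1) we only save a single buffer for backward, but this also means we recompute the convolution forward in (5). Also note that (2), (3), (4), and (6) are exactly the same code as in the examples above.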
class FusedConvBN2DFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        assert X.ndim == 4  # N, C, H, W
        # (1) Only need to save this single buffer for backward!
        ctx.save_for_backward(X, conv_weight)

        # (2) Exact same Conv2D forward from example above
        X = F.conv2d(X, conv_weight)
        # (3) Exact same BatchNorm2D forward from example above
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        # Try to do as many things in-place as possible
        # Instead of `out = (X - a) / b`, doing `out = X - a; out /= b`
        # avoids allocating one extra NCHW-sized buffer here
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        X, conv_weight, = ctx.saved_tensors
        # (4) Batch norm backward
        # (5) We need to recompute conv
        X_conv_out = F.conv2d(X, conv_weight)
        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,
                                       ctx.N, ctx.eps)
        # (6) Conv2d backward
        grad_X, grad_weight = convolution_backward(grad_out, X, conv_weight)
        # One gradient per forward input: X, conv_weight and eps (None for eps)
        return grad_X, grad_weight, None
The next step is to wrap our functional variant in a stateful nn.Module.
import torch.nn as nn
import math

class FusedConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,
                 eps=1e-3, device=None, dtype=None):
        super(FusedConvBN, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Conv parameters
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
        # Batch norm parameters
        num_features = out_channels
        self.num_features = num_features
        self.eps = eps
        # Initialize (torch.empty leaves the weight uninitialized)
        self.reset_parameters()

    def forward(self, X):
        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))
Testing out our new Layer
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

# Record memory allocated at the end of the forward pass
# (index 0: fused model, index 1: unfused model)
memory_allocated = [[],[]]

class Net(nn.Module):
    def __init__(self, fused=True):
        super(Net, self).__init__()
        self.fused = fused
        if fused:
            self.convbn1 = FusedConvBN(1, 32, 3)
            self.convbn2 = FusedConvBN(32, 64, 3)
        else:
            self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
            self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)
            self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
        self.fc1 = nn.Linear(9216, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if self.fused:
            x = self.convbn1(x)
        else:
            x = self.conv1(x)
            x = self.bn1(x)
        F.relu_(x)
        if self.fused:
            x = self.convbn2(x)
        else:
            x = self.conv2(x)
            x = self.bn2(x)
        F.relu_(x)
        x = F.max_pool2d(x, 2)
        x = x.flatten(1)
        x = self.fc1(x)
        x = self.dropout(x)
        F.relu_(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        memory_allocated[0 if self.fused else 1].append(torch.cuda.memory_allocated())
        return output
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    # Use inference mode instead of no_grad, for free improved test-time performance
    with torch.inference_mode():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': 2048}
test_kwargs = {'batch_size': 2048}

if use_cuda:
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                          transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                          transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
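If CUDA is enabled, print out the memory usage for both fused=True and fused=False. For an example run on an NVIDIA GeForce RTX 3070 with the NVIDIA CUDA® Deep Neural Network library (cuDNN) 8.0.5: fused peak memory: 1.56GB, unfused peak memory: 2.68GB.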
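It is important to note that the peak memory usage for this model may vary depending on the specific cuDNN convolution algorithm used. For shallower models, it may be possible for the peak memory allocated of the fused model to exceed that of the unfused model! This is because the memory allocated to compute certain cuDNN convolution algorithms can be high enough to "hide" the typical peak you would expect to be near the start of the backward pass.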
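For this reason, we also record the memory allocated at the end of the forward pass as an approximation, which demonstrates that we indeed allocate one fewer buffer per fused conv-bn pair.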
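As a rough, back-of-the-envelope estimate of how large the avoided buffers are in this network, the sketch below computes the size of the two convolution outputs that the unfused batch norm layers save for backward and the fused layers do not. It assumes float32 activations and a full batch of 2048 28×28 MNIST images. The smaller last batch, allocator rounding and all other saved tensors are ignored, so measured numbers need not match it. `nchw_bytes` is a hypothetical helper.

```python
def nchw_bytes(n, c, h, w, bytes_per_elem=4):
    """Size in bytes of a dense N x C x H x W tensor (float32 by default)."""
    return n * c * h * w * bytes_per_elem


batch = 2048
# convbn1: 1 -> 32 channels, 3x3 kernel, no padding: 28x28 -> 26x26
conv1_out = nchw_bytes(batch, 32, 26, 26)
# convbn2: 32 -> 64 channels, 3x3 kernel, no padding: 26x26 -> 24x24
conv2_out = nchw_bytes(batch, 64, 24, 24)

saved = conv1_out + conv2_out
print(f"{saved / 1024**3:.2f} GiB")  # prints 0.45 GiB
```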
from statistics import mean

torch.backends.cudnn.enabled = True

if use_cuda:
    peak_memory_allocated = []

    for fused in (True, False):
        # Same seed for both runs so the two models start from matching weights
        torch.manual_seed(123456)
        torch.cuda.reset_peak_memory_stats()

        model = Net(fused=fused).to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=1.0)
        scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

        for epoch in range(1):
            train(model, device, train_loader, optimizer, epoch)
            test(model, device, test_loader)
            scheduler.step()
        peak_memory_allocated.append(torch.cuda.max_memory_allocated())

    print("cuDNN version:", torch.backends.cudnn.version())
    print("Peak memory allocated:")
    print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
    print("Memory allocated at end of forward pass:")
    print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")
Train Epoch: 0 [0/60000 (0%)] Loss: 2.348735
Train Epoch: 0 [4096/60000 (7%)] Loss: 7.435781
Train Epoch: 0 [8192/60000 (13%)] Loss: 5.540894
Train Epoch: 0 [12288/60000 (20%)] Loss: 2.274223
Train Epoch: 0 [16384/60000 (27%)] Loss: 1.618885
Train Epoch: 0 [20480/60000 (33%)] Loss: 1.515203
Train Epoch: 0 [24576/60000 (40%)] Loss: 1.329276
Train Epoch: 0 [28672/60000 (47%)] Loss: 1.184942
Train Epoch: 0 [32768/60000 (53%)] Loss: 1.140154
Train Epoch: 0 [36864/60000 (60%)] Loss: 1.174118
Train Epoch: 0 [40960/60000 (67%)] Loss: 1.057965
Train Epoch: 0 [45056/60000 (73%)] Loss: 0.976334
Train Epoch: 0 [49152/60000 (80%)] Loss: 0.842555
Train Epoch: 0 [53248/60000 (87%)] Loss: 0.690169
Train Epoch: 0 [57344/60000 (93%)] Loss: 0.656998
Test set: Average loss: 0.4197, Accuracy: 8681/10000 (87%)
Train Epoch: 0 [0/60000 (0%)] Loss: 2.349030
Train Epoch: 0 [4096/60000 (7%)] Loss: 7.435158
Train Epoch: 0 [8192/60000 (13%)] Loss: 5.443529
Train Epoch: 0 [12288/60000 (20%)] Loss: 2.457773
Train Epoch: 0 [16384/60000 (27%)] Loss: 1.739528
Train Epoch: 0 [20480/60000 (33%)] Loss: 1.448555
Train Epoch: 0 [24576/60000 (40%)] Loss: 1.311784
Train Epoch: 0 [28672/60000 (47%)] Loss: 1.149165
Train Epoch: 0 [32768/60000 (53%)] Loss: 1.513479
Train Epoch: 0 [36864/60000 (60%)] Loss: 1.243767
Train Epoch: 0 [40960/60000 (67%)] Loss: 1.079315
Train Epoch: 0 [45056/60000 (73%)] Loss: 0.896300
Train Epoch: 0 [49152/60000 (80%)] Loss: 0.839771
Train Epoch: 0 [53248/60000 (87%)] Loss: 0.729098
Train Epoch: 0 [57344/60000 (93%)] Loss: 0.748637
Test set: Average loss: 0.4340, Accuracy: 8715/10000 (87%)
cuDNN version: 90100
Peak memory allocated:
fused: 2.30GB, unfused: 1.77GB
Memory allocated at end of forward pass:
fused: 0.59GB, unfused: 0.96GB
