
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)

Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024

Author: Driss Guessous

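Summary

In this tutorial, we want to highlight a new torch.nn.functional function that can be helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.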

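Overview

At a high level, this PyTorch function calculates the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention is all you need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.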
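As a point of reference, the "naive implementation" mentioned above can be sketched with ordinary PyTorch operations. This is a minimal illustration that assumes no mask, no dropout, and the default scale of 1/sqrt(head_dim); the helper name naive_sdpa is hypothetical and not part of PyTorch.

```python
import math

import torch
import torch.nn.functional as F


def naive_sdpa(query, key, value):
    """Unfused reference: softmax(Q K^T / sqrt(d)) V, with no mask or dropout."""
    scale = 1.0 / math.sqrt(query.size(-1))
    # Attention scores have shape (..., seq_len_q, seq_len_kv)
    scores = query @ key.transpose(-2, -1) * scale
    weights = torch.softmax(scores, dim=-1)
    return weights @ value


torch.manual_seed(0)
# Small float32 CPU tensors so the comparison runs anywhere
q, k, v = (torch.randn(2, 3, 8) for _ in range(3))

reference = naive_sdpa(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)
max_abs_diff = (reference - fused).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results should agree up to floating-point rounding. The naive version materializes the full attention score matrix, which is the intermediate that fused kernels are designed to avoid.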

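Fused implementations

For CUDA tensor inputs, the function will dispatch into one of the following implementations: FlashAttention, memory-efficient attention, or a math implementation. These correspond to the SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, and SDPBackend.MATH backends benchmarked later in this tutorial.

Note

This tutorial requires PyTorch 2.0.0 or later.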

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691,
          -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820,
          -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607,
          -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,
           0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,
           0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,
           0.1037]]], device='cuda:0')

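Explicit Dispatcher Control

While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through the implementations and measure their performance.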

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel


with sdpa_kernel(SDPBackend.MATH):
    math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"The math implementation runs in {math_time:.3f} microseconds")

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2326.418 microseconds
The math implementation runs in 87382.506 microseconds
The flash attention implementation runs in 2328.379 microseconds
The memory efficient implementation runs in 4305.558 microseconds

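Hardware dependence

Depending on what machine you ran the above cell on and what hardware is available, your results might be different.

- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, flash attention or memory efficient might have failed.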

Causal Self Attention

Below is an example implementation of a multi-headed causal self attention block inspired by Andrej Karpathy's NanoGPT repository.

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
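To make the effect of the is_causal flag used in the module above more concrete, the sketch below compares it with an explicit lower-triangular boolean mask for square attention (query and key of equal length). The tensor shapes are arbitrary small values chosen for illustration, and everything runs in float32 on CPU.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim); small CPU tensors for illustration
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))

# Boolean mask where True means "may attend": position i sees positions j <= i
seq_len = q.size(-2)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# The first query position can only attend to the first key,
# so its output is exactly the first value vector.
first_token_matches_value = torch.allclose(out_flag[..., 0, :], v[..., 0, :], atol=1e-5)
print(torch.allclose(out_flag, out_mask, atol=1e-5), first_token_matches_value)
```

Passing the flag rather than a materialized mask lets the function apply the causal structure without needing an explicit mask tensor as input.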

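NestedTensor and Dense tensor support

SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable length sequences without needing to pad each sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors Tutorial.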

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:226: UserWarning:

The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)

Random NT runs in 561.846 microseconds
Random Dense runs in 948.365 microseconds
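To illustrate what the NestedTensor batch avoids, the following sketch builds a tiny nested tensor from three sequences of different lengths and compares the number of stored elements with the padded equivalent. The sequence lengths and embedding size are arbitrary illustrative values, and the nested tensor API is still in prototype stage, as the warning above notes.

```python
import torch

# Three sequences of different lengths, each with embedding size 4
seqs = [torch.randn(n, 4) for n in (2, 5, 3)]

# A NestedTensor stores each sequence at its own length
nt = torch.nested.nested_tensor(seqs)

# The dense equivalent pads every sequence to the longest one (5 here)
padded = torch.nested.to_padded_tensor(nt, 0.0)

stored_elems = sum(s.numel() for s in seqs)       # (2 + 5 + 3) * 4
padding_fraction = 1 - stored_elems / padded.numel()
print(nt.is_nested, tuple(padded.shape), f"{padding_fraction:.0%} of the padded batch is padding")
```

In this toy batch a third of the padded tensor is padding. In the benchmark above, the nested batch is generated with pad_percentage=0.5, so roughly half of a padded batch would be padding.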

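Using SDPA with torch.compile

With the release of PyTorch 2.0, a new feature called torch.compile() has been introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.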

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in  {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


# Wrap the module with torch.compile; the first call below triggers compilation
compiled_model = torch.compile(model)
compiled_model(x)
print(
    f"The compiled module runs in  {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in  415.514 microseconds
The compiled module runs in  513.798 microseconds

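The exact execution time depends on the machine; the results on mine were:

The non compiled module runs in 166.616 microseconds
The compiled module runs in 166.726 microseconds

That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.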

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use chrome://tracing to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.515ms       101.39%      10.515ms      10.515ms             1
                         Non-Compilied Causal Attention        20.72%       2.282ms        75.23%       8.284ms       8.284ms       0.000us         0.00%      10.371ms      10.371ms             1
                                           aten::linear         1.14%     126.000us        28.07%       3.091ms      61.823us       0.000us         0.00%       7.749ms     154.980us            50
                                           aten::matmul         2.18%     239.531us        24.11%       2.655ms      53.096us       0.000us         0.00%       7.749ms     154.980us            50
                                               aten::mm        15.35%       1.690ms        19.65%       2.164ms      43.274us       7.749ms        74.72%       7.749ms     154.980us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.544ms        53.46%       5.544ms     221.767us            25
                     aten::scaled_dot_product_attention         1.91%     210.851us        17.21%       1.896ms      75.823us       0.000us         0.00%       2.622ms     104.893us            25
              aten::_scaled_dot_product_flash_attention         2.87%     316.371us        15.30%       1.685ms      67.389us       0.000us         0.00%       2.622ms     104.893us            25
                         aten::_flash_attention_forward         3.42%     377.071us        10.69%       1.177ms      47.081us       2.622ms        25.28%       2.622ms     104.893us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.622ms        25.28%       2.622ms     104.893us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.012ms
Self CUDA time total: 10.371ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         0.00%       0.000us         0.00%       0.000us       0.000us      10.569ms       101.84%      10.569ms      10.569ms             1
                              Compiled Causal Attention         8.70%     984.016us        73.88%       8.360ms       8.360ms       0.000us         0.00%      10.378ms      10.378ms             1
                                  Torch-Compiled Region         8.42%     952.756us        62.97%       7.126ms     285.053us       0.000us         0.00%      10.378ms     415.117us            25
                                       CompiledFunction        26.38%       2.985ms        54.55%       6.174ms     246.943us       0.000us         0.00%      10.378ms     415.117us            25
                                               aten::mm         9.38%       1.061ms        14.07%       1.592ms      31.842us       7.758ms        74.75%       7.758ms     155.157us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.552ms        53.50%       5.552ms     222.092us            25
              aten::_scaled_dot_product_flash_attention         2.12%     239.381us        14.10%       1.596ms      63.844us       0.000us         0.00%       2.620ms     104.803us            25
                         aten::_flash_attention_forward         3.46%     391.220us        10.29%       1.165ms      46.582us       2.620ms        25.25%       2.620ms     104.803us            25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.620ms        25.25%       2.620ms     104.803us            25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.206ms        21.25%       2.206ms      88.222us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.316ms
Self CUDA time total: 10.378ms

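The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis reveals that the majority of time spent on the GPU is concentrated on the same set of functions for both modules. The reason for this here is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden.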
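The observation that both modules spend their GPU time in the same kernels can be checked with a little arithmetic on the two profiler tables. The sketch below copies the aten::mm and flash attention CUDA times from the tables above and computes each one's share of the total; the dictionary layout and variable names are illustrative only.

```python
# CUDA times in milliseconds, copied from the two profiler tables above
profiles = {
    "non-compiled": {"aten::mm": 7.749, "flash_attention": 2.622, "total": 10.371},
    "compiled": {"aten::mm": 7.758, "flash_attention": 2.620, "total": 10.378},
}

shares = {}
for name, times in profiles.items():
    shares[name] = {
        op: round(100 * times[op] / times["total"], 2)
        for op in ("aten::mm", "flash_attention")
    }
    print(name, shares[name])
```

In both runs the matrix multiplications account for roughly three quarters of the GPU time and the flash attention kernel for the remaining quarter, so compilation leaves little GPU work to remove for this single block.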

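In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT training on the Shakespeare dataset.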

Using SDPA with attn_bias subclasses

As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses designed to be used with torch.nn.functional.scaled_dot_product_attention. The module is named torch.nn.attention.bias and contains the following two utilities for generating causal attention variants:

- torch.nn.attention.bias.causal_upper_left
- torch.nn.attention.bias.causal_lower_right

Note

The current argument is_causal in torch.nn.functional.scaled_dot_product_attention is the same as using torch.nn.attention.bias.causal_upper_left.

from torch.nn.attention.bias import causal_lower_right, causal_upper_left

batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)

upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)

print(type(upper_left_bias))
print(type(lower_right_bias))

assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)

# As you can see from the previous output, both are the same type ``torch.nn.attention.bias.CausalBias``
# and subclass ``torch.Tensor``

# Let's see what these tensors look like
print(upper_left_bias)
print(lower_right_bias)

# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Another way of thinking about this: with upper left bias, the 0th token in the query is aligned
# to the 0th token in the key. Assuming the attention score matrix is two dimensional,
# ``attn_score[0][0]`` is the attention score between the 0th token in the query and the 0th token in the key.
# With lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is True, since the last token in q is at the same position as the last token in k)
# even if the sequence lengths of q and k are different.

# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)

# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False]])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])
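The two alignments printed above can be reproduced with a few lines of plain Python, which may help make the offset explicit. This is only an illustration of the masking rule, not how CausalBias is implemented; the helper name causal_mask is hypothetical.

```python
def causal_mask(seq_len_q, seq_len_kv, lower_right=False):
    """Boolean mask (True = may attend) as nested lists.

    Upper-left alignment: query i sees keys j <= i.
    Lower-right alignment: the diagonal is shifted by (seq_len_kv - seq_len_q),
    so the last query lines up with the last key.
    """
    offset = seq_len_kv - seq_len_q if lower_right else 0
    return [[j <= i + offset for j in range(seq_len_kv)] for i in range(seq_len_q)]


upper_left = causal_mask(2, 10)
lower_right = causal_mask(2, 10, lower_right=True)

for name, mask in (("upper left", upper_left), ("lower right", lower_right)):
    print(name)
    for row in mask:
        # Print 1 for True and 0 for False to keep the rows compact
        print("  " + " ".join("1" if allowed else "0" for allowed in row))
```

With 2 queries and 10 keys the offset is 8, which is why the first query row of the lower-right mask allows nine keys and the second allows all ten. When the two sequence lengths are equal the offset is zero and both alignments coincide.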

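Conclusion

In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. We also built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process, we have shown how to use the profiling tools to explore the performance characteristics of a user-defined module.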

Total running time of the script: (0 minutes 7.698 seconds)
