(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Created On: Mar 15, 2023 | Last Updated: Oct 09, 2024 | Last Verified: Nov 05, 2024
Author: Driss Guessous
Summary
In this tutorial, we want to highlight a new function that is helpful for implementing transformer architectures. The function is named torch.nn.functional.scaled_dot_product_attention. For a detailed description of the function, see the PyTorch documentation. This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.
Overview
At a high level, this PyTorch function computes the scaled dot product attention (SDPA) between query, key, and value according to the definition found in the paper Attention Is All You Need. While this function can be written in PyTorch using existing functions, a fused implementation can provide large performance benefits over a naive implementation.
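To make that definition concrete, here is a minimal reference sketch (not part of the original tutorial; it mirrors the unfused math path with no masking or dropout) of what the function computes:

import math
import torch
import torch.nn.functional as F

def naive_sdpa(query, key, value):
    # attention = softmax(Q @ K^T / sqrt(d_k)) @ V
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

q, k, v = torch.randn(2, 3, 8), torch.randn(2, 3, 8), torch.randn(2, 3, 8)
# The fused function should agree with this naive version up to numerical tolerance.
print(torch.allclose(naive_sdpa(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5))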
Fused implementations
For CUDA tensor inputs, the function will dispatch into one of the following implementations:
- FlashAttention
- Memory-Efficient Attention
- A PyTorch implementation defined in C++
Note: This tutorial requires PyTorch 2.0.0 or newer.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"
# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489, 0.3015, -0.3912, 0.9867, 0.3137, -0.0691,
-1.2593],
[-1.0882, 0.2506, 0.6491, 0.1360, 0.5238, -0.2448, -0.0820,
-0.6171],
[-1.0012, 0.3990, 0.6441, -0.0277, 0.5325, -0.2564, -0.0607,
-0.6404]],
[[ 0.6091, 0.0708, 0.6188, 0.3252, -0.1598, 0.4197, -0.2335,
0.0630],
[ 0.5285, 0.3890, -0.2649, 0.3706, -0.3839, 0.1963, -0.6242,
0.2312],
[ 0.4048, 0.0762, 0.3777, 0.4689, -0.2978, 0.2754, -0.6429,
0.1037]]], device='cuda:0')
Explicit Dispatcher Control
While the function will implicitly dispatch to one of the three implementations, the user can also explicitly control the dispatch via the use of a context manager. This context manager allows users to explicitly disable certain implementations. If a user wants to ensure the function is indeed using the fastest implementation for their specific inputs, the context manager can be used to sweep through and measure performance.
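As a quick sketch before the benchmark sweep below (assuming a CUDA device and a PyTorch version in which ``sdpa_kernel`` accepts a list of backends), you could restrict dispatch to only the fused kernels, effectively disabling the math fallback:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

if torch.cuda.is_available():
    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
    # Only the fused implementations are allowed inside this context manager;
    # if neither can handle the inputs, the call raises an error instead of falling back.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)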
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6
# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# Let's explore the speed of each of the 3 implementations
from torch.nn.attention import SDPBackend, sdpa_kernel
with sdpa_kernel(SDPBackend.MATH):
math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f"The math implementation runs in {math_time:.3f} microseconds")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
try:
flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
try:
efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
except RuntimeError:
print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2326.418 microseconds
The math implementation runs in 87382.506 microseconds
The flash attention implementation runs in 2328.379 microseconds
The memory efficient implementation runs in 4305.558 microseconds
Hardware dependence
Depending on what machine you ran the above cell on and what hardware is available, your results might be different.
- If you don't have a GPU and are running on CPU, then with FP32 the context manager will have no effect and all three runs should return similar timings.
- Depending on what compute capability your graphics card supports, FlashAttention or Memory-Efficient Attention might have failed.
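If you want a rough, ahead-of-time indication of what your machine supports, a small sketch like the following can help (the compute-capability threshold here is a heuristic assumption, not an authoritative check):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: sm_{major}{minor}")
    # PyTorch's FlashAttention kernels generally target Ampere (sm_80) or newer.
    print("FlashAttention likely available:", major >= 8)
    # These flags report whether the fused backends are globally enabled.
    print("Flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())
    print("Memory-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
else:
    print("No GPU found; SDPA will use the C++/math implementation on CPU.")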
Causal Self Attention
Below is an example implementation of a multi-headed causal self-attention block inspired by Andrej Karpathy's NanoGPT repository.
class CausalSelfAttention(nn.Module):
def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):
super().__init__()
assert embed_dimension % num_heads == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
# output projection
self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
# regularization
self.dropout = dropout
self.resid_dropout = nn.Dropout(dropout)
self.num_heads = num_heads
self.embed_dimension = embed_dimension
# Perform causal masking
self.is_causal = is_causal
def forward(self, x):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query_projected = self.c_attn(x)
batch_size = query_projected.size(0)
embed_dim = query_projected.size(2)
head_dim = embed_dim // (self.num_heads * 3)
query, key, value = query_projected.chunk(3, -1)
query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
if self.training:
dropout = self.dropout
is_causal = self.is_causal
else:
dropout = 0.0
is_causal = False
y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)
y = self.resid_dropout(self.c_proj(y))
return y
num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
(c_attn): Linear(in_features=512, out_features=1536, bias=False)
(c_proj): Linear(in_features=512, out_features=512, bias=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
NestedTensor and Dense tensor support
SDPA supports both NestedTensor and Dense tensor inputs. NestedTensors handle the case where the input is a batch of variable-length sequences without needing to pad every sequence to the maximum length in the batch. For more information about NestedTensors, see torch.nested and the NestedTensors tutorial.
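Before the benchmark below, here is a minimal sketch (with assumed shapes, not taken from the original tutorial) of calling SDPA directly on nested-tensor inputs; it assumes a CUDA device so that a fused kernel can be selected:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    # Two sequences of different lengths; each entry has shape (num_heads, seq_len, head_dim).
    seqs = [torch.randn(4, length, 8, device="cuda", dtype=torch.float16) for length in (5, 9)]
    q = k = v = torch.nested.nested_tensor(seqs)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.is_nested, out.size(0))  # True 2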
import random
def generate_rand_batch(
batch_size,
max_sequence_len,
embed_dimension,
pad_percentage=None,
dtype=torch.float16,
device="cuda",
):
if not pad_percentage:
return (
torch.randn(
batch_size,
max_sequence_len,
embed_dimension,
dtype=dtype,
device=device,
),
None,
)
# Random sequence lengths
seq_len_list = [
int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
for _ in range(batch_size)
]
# Make random entry in the batch have max sequence length
seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
return (
torch.nested.nested_tensor(
[
torch.randn(seq_len, embed_dimension,
dtype=dtype, device=device)
for seq_len in seq_len_list
]
),
seq_len_list,
)
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)
# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
try:
print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
except RuntimeError:
print("FlashAttention is not supported. See warnings for reasons.")
/usr/local/lib/python3.10/dist-packages/torch/nested/__init__.py:226: UserWarning:
The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
Random NT runs in 561.846 microseconds
Random Dense runs in 948.365 microseconds
Using SDPA with torch.compile
With the release of PyTorch 2.0, a new feature called torch.compile() was introduced, which can provide significant performance improvements over eager mode. Scaled dot product attention is fully composable with torch.compile(). To demonstrate this, let's compile the CausalSelfAttention module using torch.compile() and observe the resulting performance improvements.
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
embed_dimension, device=device, dtype=dtype)
print(
f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")
compiled_model = torch.compile(model)
# The first call triggers compilation (warm-up)
compiled_model(x)
print(
f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in 415.514 microseconds
The compiled module runs in 513.798 microseconds
The exact execution time depends on the machine; these were the results on mine: the non-compiled module ran in 166.616 microseconds and the compiled module ran in 166.726 microseconds. That is not what we were expecting. Let's dig a little deeper. PyTorch comes with an amazing built-in profiler that you can use to inspect the performance characteristics of your code.
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
activities.append(ProfilerActivity.CUDA)
with profile(activities=activities, record_shapes=False) as prof:
with record_function(" Non-Compilied Causal Attention"):
for _ in range(25):
model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
with profile(activities=activities, record_shapes=False) as prof:
with record_function("Compiled Causal Attention"):
for _ in range(25):
compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
# prof.export_chrome_trace("compiled_causal_attention_trace.json").
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Non-Compilied Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.515ms 101.39% 10.515ms 10.515ms 1
Non-Compilied Causal Attention 20.72% 2.282ms 75.23% 8.284ms 8.284ms 0.000us 0.00% 10.371ms 10.371ms 1
aten::linear 1.14% 126.000us 28.07% 3.091ms 61.823us 0.000us 0.00% 7.749ms 154.980us 50
aten::matmul 2.18% 239.531us 24.11% 2.655ms 53.096us 0.000us 0.00% 7.749ms 154.980us 50
aten::mm 15.35% 1.690ms 19.65% 2.164ms 43.274us 7.749ms 74.72% 7.749ms 154.980us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.544ms 53.46% 5.544ms 221.767us 25
aten::scaled_dot_product_attention 1.91% 210.851us 17.21% 1.896ms 75.823us 0.000us 0.00% 2.622ms 104.893us 25
aten::_scaled_dot_product_flash_attention 2.87% 316.371us 15.30% 1.685ms 67.389us 0.000us 0.00% 2.622ms 104.893us 25
aten::_flash_attention_forward 3.42% 377.071us 10.69% 1.177ms 47.081us 2.622ms 25.28% 2.622ms 104.893us 25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::... 0.00% 0.000us 0.00% 0.000us 0.000us 2.622ms 25.28% 2.622ms 104.893us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 11.012ms
Self CUDA time total: 10.371ms
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Compiled Causal Attention 0.00% 0.000us 0.00% 0.000us 0.000us 10.569ms 101.84% 10.569ms 10.569ms 1
Compiled Causal Attention 8.70% 984.016us 73.88% 8.360ms 8.360ms 0.000us 0.00% 10.378ms 10.378ms 1
Torch-Compiled Region 8.42% 952.756us 62.97% 7.126ms 285.053us 0.000us 0.00% 10.378ms 415.117us 25
CompiledFunction 26.38% 2.985ms 54.55% 6.174ms 246.943us 0.000us 0.00% 10.378ms 415.117us 25
aten::mm 9.38% 1.061ms 14.07% 1.592ms 31.842us 7.758ms 74.75% 7.758ms 155.157us 50
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn 0.00% 0.000us 0.00% 0.000us 0.000us 5.552ms 53.50% 5.552ms 222.092us 25
aten::_scaled_dot_product_flash_attention 2.12% 239.381us 14.10% 1.596ms 63.844us 0.000us 0.00% 2.620ms 104.803us 25
aten::_flash_attention_forward 3.46% 391.220us 10.29% 1.165ms 46.582us 2.620ms 25.25% 2.620ms 104.803us 25
void pytorch_flash::flash_fwd_kernel<pytorch_flash::... 0.00% 0.000us 0.00% 0.000us 0.000us 2.620ms 25.25% 2.620ms 104.803us 25
ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3... 0.00% 0.000us 0.00% 0.000us 0.000us 2.206ms 21.25% 2.206ms 88.222us 25
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 11.316ms
Self CUDA time total: 10.378ms
The previous code snippet generates a report of the top 10 PyTorch functions that consumed the most GPU execution time, for both the compiled and non-compiled module. The analysis shows that the majority of GPU time is spent on the same set of functions for both modules. The reason for this is that torch.compile is very good at removing the framework overhead associated with PyTorch. If your model is launching large, efficient CUDA kernels, which in this case CausalSelfAttention is, then the overhead of PyTorch can be hidden.

In reality, your module does not normally consist of a single CausalSelfAttention block. When experimenting with Andrej Karpathy's NanoGPT repository, compiling the module took the time per train step from 6090.49ms to 3273.17ms! This was done on commit ae3a8d5 of NanoGPT, training on the Shakespeare dataset.
Using SDPA with attn_bias subclasses
# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses
# designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.
# The module is named ``torch.nn.attention.bias`` and contains the following two
# utilities for generating causal attention variants:
#
# - ``torch.nn.attention.bias.causal_upper_left``
# - ``torch.nn.attention.bias.causal_lower_right``
#
# .. note::
# The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``
# is the same as using ``torch.nn.attention.bias.causal_upper_left``.
#
from torch.nn.attention.bias import causal_lower_right, causal_upper_left
batch_size = 32
sequence_length_q = 2
sequence_length_kv = 10
num_heads = 16
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)
upper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)
lower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)
print(type(upper_left_bias))
print(type(lower_right_bias))
assert type(upper_left_bias) == type(lower_right_bias)
assert issubclass(type(upper_left_bias), torch.Tensor)
# As you can see from the previous output, both biases are of the same type, ``torch.nn.attention.bias.CausalBias``,
# and both subclass ``torch.Tensor``.
# Let's see what these tensors look like.
print(upper_left_bias)
print(lower_right_bias)
# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.
# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.
# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score
# between the 0th token in the query and the 0th token in the key.
# With upper left bias, the 0th token in the query is aligned to the 0th token in the key.
# With lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k
# (for example, ``attn_score[-1][-1]`` is all True since the last token in q is at the same position as the last token in k),
# even if the sequence lengths of q and k are different.
# These objects are intended to be used with sdpa
out_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)
out_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)
out_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
assert torch.allclose(out_upper_left, out_is_causal)
assert not torch.allclose(out_upper_left, out_lower_right)
# These attention biases should also be compatible with torch.compile
compiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)
out_upper_left = compiled_sdpa(query, key, value, upper_left_bias)
<class 'torch.nn.attention.bias.CausalBias'>
<class 'torch.nn.attention.bias.CausalBias'>
tensor([[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False]])
tensor([[ True, True, True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True, True, True]])
Conclusion
In this tutorial, we have demonstrated the basic usage of torch.nn.functional.scaled_dot_product_attention. We have shown how the sdpa_kernel context manager can be used to assert that a certain implementation is used on GPU. As well, we built a simple CausalSelfAttention module that works with NestedTensor and is torch compilable. In the process we have shown how the profiling tools can be used to explore the performance characteristics of a user-defined module.
Total running time of the script: (0 minutes 7.698 seconds)