torchtext.nn 中¶
多头注意力容器¶
- 类 torchtext.nn 中。MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[来源]¶
- __init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) 无 [来源]¶
多头注意力容器
- 参数
nhead – MultiHeadAttention 模型中的头数
in_proj_container – 多头投影内线性图层(又名 nn.线性)。
attention_layer – 自定义注意力层。从 MHA 容器发送到注意力层的输入 形状为 (..., L, N * H, E / H) 表示查询,形状为 (..., S, N * H, E / H) 表示键/值 而注意力层的输出形状应为 (..., L, N * H, E / H)。 如果用户想要整个 MultiheadAttentionContainer attention_layer 则需要支持 broadcast 与广播。
out_proj – 多头外投影层(又名 nn.线性)。
batch_first – 如果 ,则提供输入和输出张量 如 (..., N, L, E)。违约:
True
False
- 例子::
>>> import torch >>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct >>> embed_dim, num_heads, bsz = 10, 5, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: 可选[Tensor] = 无,bias_k:可选[Tensor] = 无,bias_v:可选[Tensor] = None) Tuple[Tensor, Tensor] [来源]¶
- 参数
query (Tensor) – attention 函数的查询。 有关更多详细信息,请参阅“Attention Is All You Need”。
key (Tensor) – 注意力函数的键。 有关更多详细信息,请参阅“Attention Is All You Need”。
value (Tensor) - 注意力函数的值。 有关更多详细信息,请参阅“Attention Is All You Need”。
attn_mask (BoolTensor,可选) – 阻止对某些位置的注意的 3D 掩码。
bias_k (Tensor ,可选) – 要添加到 key 的另外一个键和值序列 序列 dim (dim=-3)。这些用于增量解码。用户应提供 .
bias_v
bias_v (Tensor,可选) – 要添加到 at 的值的另一个键和值序列 序列 dim (dim=-3)。这些用于增量解码。用户还应提供 .
bias_k
形状:
输入:
查询:
钥匙:
价值:
attn_mask、bias_k 和 bias_v:与关注层中对应 args 的形状相同。
输出:
attn_output:
attn_output_weights:
注意:可以选择具有三个以上维度的 query/key/value 输入(用于广播目的)。 MultiheadAttentionContainer 模块将在最后三个维度上运行。
其中 L 是目标长度,S 是序列长度,H 是注意力头的数量, N 是批量大小,E 是嵌入维度。
InProjContainer¶
- 类 torchtext.nn 中。InProjContainer(query_proj, key_proj, value_proj)[来源]¶
- __init__(query_proj, key_proj, value_proj) 无 [来源]¶
一个项目内容器,用于在 MultiheadAttention 中投影 query/key/value。此模块发生在重塑之前 将 query/key/value 投影到多个 heads。请参阅 Multi-head Attention 的线性图层(底部) Attention Is All You Need 的图 2 纸。另请查看使用示例 在 torchtext.nn.MultiheadAttentionContainer 中。
- 参数
query_proj – 用于查询的 proj 层。典型的投影层是 torch.nn.Linear。
key_proj – Key 的 proj 层。典型的投影层是 torch.nn.Linear。
value_proj – value 的 proj 层。典型的投影层是 torch.nn.Linear。
- forward(query: Tensor, key: Tensor, value: Tensor) Tuple[张量、张量、张量] [来源]¶
使用 in-proj layers 投影输入序列。query/key/value 都简单地传递给 分别是 query/key/value_proj 的 forward func。
- 参数
query (Tensor) – 要投影的查询。
key (Tensor) – 要投影的键。
value (Tensor) (value (Tensor)) ) – 要投影的值。
- 例子::
>>> import torch >>> from torchtext.nn import InProjContainer >>> embed_dim, bsz = 10, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> q = torch.rand((5, bsz, embed_dim)) >>> k = v = torch.rand((6, bsz, embed_dim)) >>> q, k, v = in_proj_container(q, k, v)
ScaledDotProduct¶
- 类 torchtext.nn 中。ScaledDotProduct(dropout=0.0, batch_first=False)[来源]¶
- __init__(dropout=0.0, batch_first=False) None [来源]¶
处理要应用的投影查询和键值对 缩放点积注意。
- 参数
dropout (float) - 放弃注意力权重的概率。
batch_first – 如果 ,则提供输入和输出张量 as (batch, seq, feature) 中。违约:
True
False
- 例子::
>>> import torch, torchtext >>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1) >>> q = torch.randn(21, 256, 3) >>> k = v = torch.randn(21, 256, 3) >>> attn_output, attn_weights = SDP(q, k, v) >>> print(attn_output.shape, attn_weights.shape) torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: 可选[Tensor] = 无,bias_k:可选[Tensor] = 无,bias_v:可选[Tensor] = None) Tuple[Tensor, Tensor] [来源]¶
使用带有投影键值对的缩放点积进行更新 投影的查询。
- 参数
query (Tensor) – 投影查询
key (Tensor) – 投影的键
value (Tensor) – 预计值
attn_mask (BoolTensor,可选) – 阻止对某些位置的注意的 3D 掩码。
attn_mask – 阻止对某些位置的注意的 3D 蒙版。
bias_k (Tensor ,可选) – 要添加到 key 的另外一个键和值序列 序列 dim (dim=-3)。这些用于增量解码。用户应提供 .
bias_v
bias_v (Tensor,可选) – 要添加到 at 的值的另一个键和值序列 序列 dim (dim=-3)。这些用于增量解码。用户还应提供 .
bias_k
- 形状:
查询:
钥匙:
价值:
- attn_mask: , 职位不允许参加
True
while 值将保持不变。
False
- attn_mask: , 职位不允许参加
bias_k 和 bias_v:bias:
输出: ,
- 注意:可以选择具有三个以上维度的 query/key/value 输入(用于广播目的)。
ScaledDotProduct 模块将在最后三个维度上运行。
其中 L 是目标长度,S 是源长度,H 是数字 的 attention heads,N 是批量大小,E 是嵌入维度。