torchtext.nn¶
MultiheadAttentionContainer¶
- class torchtext.nn.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]¶
- __init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) None[source]¶
一个多头注意力容器
- Parameters:
nhead – 多头注意力模型中的头数
in_proj_container – 多头内投影线性层的容器(又称 nn.Linear)。
attention_layer – 自定义注意力层。从MHA容器发送到注意力层的输入在形状上为(…, L, N * H, E / H)用于查询,(…, S, N * H, E / H)用于键/值, 而注意力层的输出形状预期为(…, L, N * H, E / H)。 如果用户希望整体MultiheadAttentionContainer支持广播,则注意力层需要支持广播。
out_proj – 多头输出投影层(即 nn.Linear)。
batch_first – 如果
True,则输入和输出张量为(…, N, L, E)。默认值:False
- Examples::
>>> import torch >>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct >>> embed_dim, num_heads, bsz = 10, 5, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)) >>> query = torch.rand((21, bsz, embed_dim)) >>> key = value = torch.rand((16, bsz, embed_dim)) >>> attn_output, attn_weights = MHA(query, key, value) >>> print(attn_output.shape) >>> torch.Size([21, 64, 10])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor][source]¶
- Parameters:
查询 (张量) – 注意力函数的查询。 有关更多详细信息,请参阅“注意力机制全靠它”。
键 (张量) – 注意力函数的键。 有关更多详细信息,请参阅“注意力机制全靠它”。
值 (张量) – 注意力函数的值。 有关更多详细信息,请参阅“注意力机制全靠它”。
attn_mask (BoolTensor, 可选) – 3D 遮罩,用于阻止对某些位置的注意力。
bias_k (张量, 可选) – 在序列维度(dim=-3)上要添加的一个额外键和值序列。这些用于增量解码。用户应提供
bias_v。bias_v (张量, 可选) – 在序列维度(dim=-3)上要添加的一个额外键和值序列。这些用于增量解码。用户还应提供
bias_k。
Shape:
Inputs:
query: \((..., L, N, E)\)
key: \((..., S, N, E)\)
value: \((..., S, N, E)\)
attn_mask, bias_k and bias_v: same with the shape of the corresponding args in attention layer.
Outputs:
attn_output: \((..., L, N, E)\)
attn_output_weights: \((N * H, L, S)\)
Note: It’s optional to have the query/key/value inputs with more than three dimensions (for broadcast purpose). The MultiheadAttentionContainer module will operate on the last three dimensions.
where where L is the target length, S is the sequence length, H is the number of attention heads, N is the batch size, and E is the embedding dimension.
InProjContainer¶
- class torchtext.nn.InProjContainer(query_proj, key_proj, value_proj)[source]¶
- __init__(query_proj, key_proj, value_proj) None[source]¶
一个多头注意力中的投影容器,用于在 MultiheadAttention 中投影查询/键/值。此模块发生在将投影后的查询/键/值重塑为多个头之前。参见《Attention Is All You Need》论文图 2 中多头注意力的线性层(底部)。另请查看 torchtext.nn.MultiheadAttentionContainer 的用法示例。
- Parameters:
query_proj – 用于查询的投影层。典型的投影层是 torch.nn.Linear。
key_proj – 用于键的投影层。典型的投影层是 torch.nn.Linear。
value_proj – 一个值投影层。典型的投影层是 torch.nn.Linear。
- forward(query: Tensor, key: Tensor, value: Tensor) Tuple[Tensor, Tensor, Tensor][source]¶
通过 in-proj 层投影输入序列。query、key 和 value 分别简单地传递给 query/key/value_proj 的前向函数。
- Parameters:
查询 (张量) – 需要被投影的查询。
键 (张量) – 需要投影的键。
值 (张量) – 需要投影的值。
- Examples::
>>> import torch >>> from torchtext.nn import InProjContainer >>> embed_dim, bsz = 10, 64 >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim), torch.nn.Linear(embed_dim, embed_dim)) >>> q = torch.rand((5, bsz, embed_dim)) >>> k = v = torch.rand((6, bsz, embed_dim)) >>> q, k, v = in_proj_container(q, k, v)
ScaledDotProduct¶
- class torchtext.nn.ScaledDotProduct(dropout=0.0, batch_first=False)[source]¶
- __init__(dropout=0.0, batch_first=False) None[source]¶
处理一个投影的查询和键值对以应用缩放点积注意力。
- Parameters:
dropout (float) – 概率值,表示丢弃注意力权重的概率。
batch_first – 如果
True,则输入和输出张量为(batch, seq, feature)。默认值:False
- Examples::
>>> import torch, torchtext >>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1) >>> q = torch.randn(21, 256, 3) >>> k = v = torch.randn(21, 256, 3) >>> attn_output, attn_weights = SDP(q, k, v) >>> print(attn_output.shape, attn_weights.shape) torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) Tuple[Tensor, Tensor][source]¶
使用缩放点积和投影的键值对来更新投影的查询。
- Parameters:
查询 (张量) – 投影查询
键 (张量) – 投影键
值 (张量) – 投影值
attn_mask (BoolTensor, 可选) – 3D 遮罩,用于阻止对某些位置的注意力。
attn_mask – 3D 网罩,用于阻止对某些位置的注意力。
bias_k (张量, 可选) – 在序列维度(dim=-3)上要添加的一个额外键和值序列。这些用于增量解码。用户应提供
bias_v。bias_v (张量, 可选) – 在序列维度(dim=-3)上要添加的一个额外键和值序列。这些用于增量解码。用户还应提供
bias_k。
- Shape:
查询: \((..., L, N * H, E / H)\)
键: \((..., S, N * H, E / H)\)
值: \((..., S, N * H, E / H)\)
- attn_mask: \((N * H, L, S)\), positions with
Trueare not allowed to attend 当
False值不变。
- attn_mask: \((N * H, L, S)\), positions with
偏置_k 和 偏置_v: 偏置: \((1, N * H, E / H)\)
输出:\((..., L, N * H, E / H)\),\((N * H, L, S)\)
- Note: It’s optional to have the query/key/value inputs with more than three dimensions (for broadcast purpose).
ScaledDotProduct 模块将在最后三个维度上进行运算。
其中 L 是目标长度,S 是源长度,H 是注意力头的数量,N 是批量大小,E 是嵌入维度。