BartAttention

Ikko

Experimental Results

Running inference on 10,000 examples from the CNN/DailyMail test set, the total time spent in each part of the attention computation was recorded as follows:

Q Proj Total Time: 396.274400s
K V Proj Total Time: 445.651598s
Attention Reshape Total Time: 130.337801s
Attn Weights Compute Total Time: 165.823140s
Attn Mask Total Time: 0.015915s
Softmax Total Time: 129.015671s
Head Mask Dropout Total Time: 53.706548s
Attn Output Compute Total Time: 211.477478s
Out Proj Total Time: 297.960706s
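To make the table easier to read, the short script below (my own addition) computes each stage's share of the total attention time from the numbers above:

# Share of total attention time per stage, computed from the totals logged above.
totals = {
    "q_proj": 396.274400,
    "k_v_proj": 445.651598,
    "attention_reshape": 130.337801,
    "attn_weights_compute": 165.823140,
    "attn_mask": 0.015915,
    "softmax": 129.015671,
    "head_mask_dropout": 53.706548,
    "attn_output_compute": 211.477478,
    "out_proj": 297.960706,
}
grand_total = sum(totals.values())
for name, t in sorted(totals.items(), key=lambda kv: -kv[1]):
    print(f"{name:>22s}: {t:10.2f}s  ({100 * t / grand_total:5.1f}%)")

# The three linear projections (q/k+v/out) take roughly 62% of the time,
# while the two batched matmuls (attn_weights + attn_output) take about 21%.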

Overview

BartAttention implements multi-head attention (MHA): through projection, reshaping, and parallel computation it splits the input tensor into per-head representations and processes them efficiently in the attention computation. This document analyzes how the tensor dimensions change from the input through the attention-weight computation, in two parts:

  1. General analysis: based on generic parameters (batch size $b$, sequence length $L$, embedding dimension $d$) and the BartAttention code, describe the general rules for how the dimensions change.
  2. Concrete verification: use the logged shapes (e.g. hidden_states of [1, 694, 1024]) to confirm the general analysis.

Analysis

General Parameter Definitions

  • Batch size: $b$ (bsz).
  • Sequence length: $L$ (for self-attention, the target and source lengths are equal, $L_{tgt} = L_{src} = L$; for cross-attention, $L_{tgt}$ is the decoder length and $L_{src}$ the encoder length).
  • Model embedding dimension: $d$ (embed_dim, 1024 in the logs).
  • Number of attention heads: $h$ (num_heads, 16 in the logs).
  • Per-head dimension: $d_h = d / h$ (head_dim, $1024 / 16 = 64$).

Code Snippet

The BartAttention implementation of the attention mechanism from the original paper ("Attention Is All You Need"), instrumented here with timing and shape logging:

import torch
from torch import nn
from typing import Optional, Tuple
from transformers import BartConfig


class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        import time
        import logging as loggingg

        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()
        loggingg.info(f"Input hidden_states shape: {hidden_states.shape}")
        if is_cross_attention:
            loggingg.info(f"Key_value_states shape: {key_value_states.shape}")

        # Time query projection
        start_time = time.time()
        query_states = self.q_proj(hidden_states) * self.scaling
        loggingg.info(f"q_proj time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Query_states shape after q_proj: {query_states.shape}")

        # Time key/value projections
        start_time = time.time()
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            key_states = past_key_value[0]
            value_states = past_key_value[1]
            loggingg.info(f"Using past_key_value - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}")
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            loggingg.info(f"Cross-attention - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}")
        elif past_key_value is not None:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            loggingg.info(f"Self-attention with past - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}")
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            loggingg.info(f"Self-attention - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}")
        loggingg.info(f"k_v_proj time: {time.time() - start_time:.6f}s")

        if self.is_decoder:
            past_key_value = (key_states, value_states)

        # Time attention reshaping
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        start_time = time.time()
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)
        loggingg.info(f"attention_reshape time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Reshaped Query_states shape: {query_states.shape}")
        loggingg.info(f"Reshaped Key_states shape: {key_states.shape}")
        loggingg.info(f"Reshaped Value_states shape: {value_states.shape}")

        src_len = key_states.size(1)

        # Time attention weights computation
        start_time = time.time()
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        loggingg.info(f"attn_weights_compute time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Attn_weights shape: {attn_weights.shape}")

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Time attention mask application
        start_time = time.time()
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            loggingg.info(f"Attention_mask shape: {attention_mask.shape}")
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        loggingg.info(f"attn_mask time: {time.time() - start_time:.6f}s")

        # Time softmax
        start_time = time.time()
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        loggingg.info(f"softmax time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Attn_weights shape after softmax: {attn_weights.shape}")

        # Time layer head mask and dropout
        start_time = time.time()
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            loggingg.info(f"Layer_head_mask shape: {layer_head_mask.shape}")
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        loggingg.info(f"head_mask_dropout time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Attn_probs shape after dropout: {attn_probs.shape}")

        # Time attention output computation
        start_time = time.time()
        attn_output = torch.bmm(attn_probs, value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        loggingg.info(f"attn_output_compute time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Attn_output shape: {attn_output.shape}")

        # Time output projection
        start_time = time.time()
        attn_output = self.out_proj(attn_output)
        loggingg.info(f"out_proj time: {time.time() - start_time:.6f}s")
        loggingg.info(f"Final attn_output shape: {attn_output.shape}")

        return attn_output, attn_weights_reshaped, past_key_value
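For reference, a minimal usage sketch (my own addition, not from the original post) that drives the BartAttention module defined above with dummy tensors whose shapes match the logged encoder self-attention case:

import torch

embed_dim, num_heads = 1024, 16          # d = 1024, h = 16, so d_h = 64
bsz, tgt_len = 1, 694                    # matches "Input hidden_states shape: [1, 694, 1024]"

attn = BartAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.0)
attn.eval()

hidden_states = torch.randn(bsz, tgt_len, embed_dim)
with torch.no_grad():
    attn_output, attn_weights, past_kv = attn(hidden_states, output_attentions=True)

print(attn_output.shape)    # torch.Size([1, 694, 1024])
print(attn_weights.shape)   # torch.Size([1, 16, 694, 694])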

Dimension Change Steps

1. Input tensor

  • Shape: $(b, L, d)$
  • Meaning:
    • First dimension: batch size $b$
    • Second dimension: sequence length $L$
    • Third dimension: embedding dimension $d$

2. Query projection (q_proj)

  • Shape: $(b, L, d)$, unchanged (a linear map $d \to d$, then scaled by $d_h^{-0.5}$)
  • Compute: $b \cdot L \cdot d^2$ multiply-adds

3. Key and value projections (k_proj, v_proj)

  • Shape: $(b, h, L_{src}, d_h)$ each, after _shape
  • Compute: $2 \cdot b \cdot L_{src} \cdot d^2$ multiply-adds (two linear maps)

4. Reshape to batched form

  • Shape: $(b \cdot h, L, d_h)$ for query, key, and value

5. Attention weight computation (attn_weights_compute)

  • Shape: $(b \cdot h, L_{tgt}, L_{src})$
  • Compute: $b \cdot h \cdot L_{tgt} \cdot L_{src} \cdot d_h = b \cdot L_{tgt} \cdot L_{src} \cdot d$ multiply-adds; these shape transitions are reproduced in the sketch below
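The following minimal sketch (my own illustration, not from the original post) reproduces the shape transitions above with dummy tensors, using the logged values $b = 1$, $L = 694$, $d = 1024$, $h = 16$:

import torch

b, L, d, h = 1, 694, 1024, 16                  # values taken from the logs below
d_h = d // h                                   # 64

hidden_states = torch.randn(b, L, d)           # 1. input: (b, L, d)
q_proj = torch.nn.Linear(d, d)
query = q_proj(hidden_states) * d_h ** -0.5    # 2. q_proj: still (b, L, d)

# 3./4. split the last dimension into heads, then fold heads into the batch dim
query = query.view(b, L, h, d_h).transpose(1, 2)       # (b, h, L, d_h)
query = query.contiguous().view(b * h, L, d_h)         # (b*h, L, d_h)
key = torch.randn(b * h, L, d_h)                       # key after the same treatment

# 5. attention weights: one batched matmul over all heads at once
attn_weights = torch.bmm(query, key.transpose(1, 2))   # (b*h, L, L)
print(attn_weights.shape)                              # torch.Size([16, 694, 694])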

Concrete Verification

Log Data

  • Input: hidden_states shape: [1, 694, 1024]
  • After projection: Query_states shape after q_proj: [1, 694, 1024]
  • After reshaping:
    • Reshaped Query_states: [16, 694, 64]
    • Reshaped Key_states: [16, 694, 64]
  • Attention weights: Attn_weights shape: [16, 694, 694]

Dimension Change Verification

✅ Actual log output: Query_states shape after q_proj: [1, 694, 1024]
✅ Expected: $(b, L, d) = [1, 694, 1024]$, matches

Key_states: [1, 16, 694, 64] matches the expected $(b, h, L, d_h)$
Reshaped Query_states: [16, 694, 64] matches $(b \cdot h, L, d_h)$
Attn_weights: [16, 694, 694] matches $(b \cdot h, L, L)$
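As a quick sanity check (my own arithmetic, not from the logs), the reshape only regroups elements and never changes their count, because $d = h \cdot d_h$:

$$
b \cdot L \cdot d = 1 \cdot 694 \cdot 1024 = 710{,}656
\qquad
(b \cdot h) \cdot L \cdot d_h = 16 \cdot 694 \cdot 64 = 710{,}656 .
$$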

Dimensions vs. Computation Time

If the input sequence length $L$ (the number of tokens processed) grows, the time spent on the attention-weight computation ($QK^\top$, i.e. attn_weights_compute) increases sharply. Below, the effect of a larger $L$ on this step is analyzed both theoretically and against the logged data, with supporting formulas and evidence.


1. Theoretical Analysis

Computational Cost of attn_weights_compute

  • Dimensions:
    • Inputs: $Q$ of shape $(b \cdot h, L_{tgt}, d_h)$ and $K^\top$ of shape $(b \cdot h, d_h, L_{src})$
    • Output: $(b \cdot h, L_{tgt}, L_{src})$
  • Cost formula:
    • Per head: $L_{tgt} \cdot L_{src} \cdot d_h$ multiply-adds
    • Total: $b \cdot h \cdot L_{tgt} \cdot L_{src} \cdot d_h = b \cdot L_{tgt} \cdot L_{src} \cdot d$
  • Self-attention ($L_{tgt} = L_{src} = L$): $b \cdot L^2 \cdot d$
  • Key point:
    • The cost is proportional to $L^2$: as the sequence length $L$ grows, the compute grows quadratically (worked out with the logged values just below).
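Plugging in the logged values ($b = 1$, $h = 16$, $L = 694$, $d_h = 64$):

$$
1 \cdot 16 \cdot 694^2 \cdot 64 = 16 \cdot 481{,}636 \cdot 64 \approx 4.9 \times 10^8 \ \text{multiply-adds}.
$$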

Comparison with the Projections

  • Projection (e.g. q_proj):
    • Input: $(b, L, d)$, weight: $(d, d)$
    • Cost: $b \cdot L \cdot d^2$
    • Linear in $L$.
  • attn_weights_compute:
    • Cost: $b \cdot L^2 \cdot d$
    • Quadratic in $L$.
  • Trend:
    • When $L$ is small ($L < d$, e.g. $L = 694 < d = 1024$), the projections dominate and attn_weights_compute is comparatively cheap.
    • As $L$ grows, the $L^2$ term grows faster, so attn_weights_compute gradually overtakes the projections; the sketch below compares the two costs for several sequence lengths.
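A short sketch (my own illustration, using the parameters from the logs above) comparing the two per-layer costs for a range of sequence lengths; the crossover sits exactly at $L = d$:

# Compare multiply-add counts of a single linear projection vs. Q @ K^T
# for one attention layer, with d = 1024 and h = 16 (so d_h = 64).
b, d, h = 1, 1024, 16
d_h = d // h

for L in (128, 694, 1024, 2048, 4096):
    proj = b * L * d * d          # one projection, e.g. q_proj: (b, L, d) @ (d, d)
    attn = b * h * L * L * d_h    # Q @ K^T over all heads = b * L^2 * d
    print(f"L={L:5d}  proj={proj / 1e9:6.2f} GMAC  attn={attn / 1e9:6.2f} GMAC  ratio={attn / proj:4.2f}")

# ratio = attn / proj = L / d, so the two costs are equal exactly at L = d = 1024.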

2. Evidence from the Log Data

Current Data

  • Parameters: $b = 1$, $L = 694$, $d = 1024$, $h = 16$, $d_h = 64$
  • attn_weights_compute:
    • Shape: $(16, 694, 694)$
    • Cost: $1 \cdot 16 \cdot 694^2 \cdot 64 \approx 4.9 \times 10^8$ multiply-adds
    • Time: 165.82 s accumulated over the test run (Attn Weights Compute Total Time above)
  • Projection (q_proj):
    • Cost: $1 \cdot 694 \cdot 1024^2 \approx 7.3 \times 10^8$ multiply-adds
    • Time: 396.27 s accumulated over the test run (Q Proj Total Time above)
  • Current ratios:
    • Cost: $4.9 \times 10^8 / 7.3 \times 10^8 \approx 0.68 = L / d$
    • Time: $165.82 / 396.27 \approx 0.42$
    • At $L = 694 < d = 1024$, the projection still takes more time than attn_weights_compute.

Extrapolating to a Larger $L$

  • Assume, for illustration, that $L$ grows to 2048:
    • attn_weights_compute:
      • Growth factor: $(2048 / 694)^2 \approx 8.7\times$
      • Time (proportional estimate): $165.8 \times 8.7 \approx 1440$ s
    • Projection:
      • Growth factor: $2048 / 694 \approx 3.0\times$
      • Time (proportional estimate): $396.3 \times 3.0 \approx 1190$ s
    • Ratios:
      • Cost: $L / d = 2048 / 1024 = 2$
      • Time: $\approx 1440 / 1190 \approx 1.2$
      • At $L = 2048$, attn_weights_compute catches up with and likely overtakes the projection time.

3. Formulas

  • Projection:
    • Time complexity: $O(b \cdot L \cdot d^2)$
  • attn_weights_compute:
    • Time complexity: $O(b \cdot L^2 \cdot d)$
  • Crossover point:
    • The two costs are equal when $b \cdot L \cdot d^2 = b \cdot L^2 \cdot d$.
    • Solving gives $L = d = 1024$.
    • For $L > 1024$, attn_weights_compute starts to dominate (derivation spelled out below).
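Spelled out, the crossover is a one-line rearrangement of the two cost formulas above:

$$
b \cdot L \cdot d^2 = b \cdot L^2 \cdot d
\;\Longrightarrow\;
L = d
\;\Longrightarrow\;
L_{\text{critical}} = 1024 .
$$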

4. Conclusion

  • As the sequence length grows:
    • The attn_weights_compute cost grows with $L^2$, so its time increases sharply.
    • The projection cost grows linearly with $L$, so its time changes much more slowly.
  • Evidence from the logs:
    • At $L = 694$, attn_weights_compute ($\approx 4.9 \times 10^8$ multiply-adds, 165.8 s total) sits well below the q_proj projection ($\approx 7.3 \times 10^8$ multiply-adds, 396.3 s total).
    • At $L = 2048$, the extrapolated attn_weights_compute cost ($\approx 4.3 \times 10^9$ multiply-adds) is about twice a projection's ($\approx 2.1 \times 10^9$), and its time would likely overtake the projection's.
  • Formula check:
    • Once $L > d = 1024$, the share of time taken by attn_weights_compute rises significantly.

Appendix: the Decoder Stage

Assume the generated summary is 100 tokens long; then the decoder self-attention is called 100 times per layer. In the logs, the q_proj input at this stage has shape [4, 1, 1024], i.e. only 1 token is processed per step (autoregressive generation) with a batch size of 4.
In addition, the decoder uses cross-attention, which differs from self-attention: its defining feature is that the query comes from one sequence while the keys and values come from a different sequence. In NLP, cross-attention is used wherever information from two different sources must interact, most notably in encoder-decoder architectures.
Here the query comes from one sequence (the decoder's current token) and the keys/values come from another sequence (the encoder output); their logged shape is [4, 16, 694, 64]. A sketch of the per-step shapes follows.
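A minimal sketch (my own illustration, reusing the BartAttention module defined above) of the shapes involved in one decoding step, assuming a batch size of 4, a 694-token encoder output, and 9 previously cached tokens:

import torch

bsz, d, h = 4, 1024, 16
d_h = d // h
enc_len = 694                                    # encoder output length from the logs

# Decoder self-attention with a KV cache: one new token per step.
self_attn = BartAttention(embed_dim=d, num_heads=h, is_decoder=True)
step_input = torch.randn(bsz, 1, d)              # q_proj input: [4, 1, 1024]
past = (torch.randn(bsz, h, 9, d_h),             # cached keys for 9 earlier tokens
        torch.randn(bsz, h, 9, d_h))             # cached values for 9 earlier tokens
out, _, past = self_attn(step_input, past_key_value=past)
print(past[0].shape)                             # torch.Size([4, 16, 10, 64]) - cache grew by 1

# Cross-attention: queries from the decoder token, keys/values from the encoder output.
cross_attn = BartAttention(embed_dim=d, num_heads=h, is_decoder=True)
encoder_out = torch.randn(bsz, enc_len, d)
out, _, cross_past = cross_attn(step_input, key_value_states=encoder_out)
print(cross_past[0].shape)                       # torch.Size([4, 16, 694, 64])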

Conclusion

In BartAttention, the tensor dimensions evolve from the input $(b, L, d)$, through the per-head form $(b \cdot h, L, d_h)$, to attention weights of shape $(b \cdot h, L_{tgt}, L_{src})$. The log data confirms this process and validates the efficiency of the multi-head attention implementation.

  • Title: BartAttention
  • Author: Ikko
  • Created at : 2025-03-20 18:16:49
  • Updated at : 2025-03-28 15:29:22
  • Link: http://ikko-debug.github.io/2025/03/20/tra/
  • License: This work is licensed under CC BY-NC-SA 4.0.