# The following imports are assumed to be present earlier in modeling_bart.py;
# they are repeated here only so this excerpt runs on its own.
from typing import Optional, Tuple

import torch
from torch import nn

from transformers import BartConfig


class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        # Imports kept local so the profiling instrumentation stays confined to this method.
        import time
        import logging
        # If key_value_states are provided, this layer is used as a cross-attention layer of the decoder.
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()
        logging.info(f"Input hidden_states shape: {hidden_states.shape}")
        if is_cross_attention:
            logging.info(f"Key_value_states shape: {key_value_states.shape}")
        # Project and scale the queries.
        start_time = time.time()
        query_states = self.q_proj(hidden_states) * self.scaling
        logging.info(f"q_proj time: {time.time() - start_time:.6f}s")
        logging.info(f"Query_states shape after q_proj: {query_states.shape}")
        start_time = time.time()
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # Reuse the cached cross-attention keys/values.
            key_states = past_key_value[0]
            value_states = past_key_value[1]
            logging.info(
                f"Using past_key_value - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}"
            )
        elif is_cross_attention:
            # Cross-attention: project the encoder hidden states.
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            logging.info(
                f"Cross-attention - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}"
            )
        elif past_key_value is not None:
            # Self-attention with cache: project the new tokens and append them to the cache.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
            logging.info(
                f"Self-attention with past - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}"
            )
        else:
            # Plain self-attention: project the current hidden states.
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            logging.info(
                f"Self-attention - Key_states shape: {key_states.shape}, Value_states shape: {value_states.shape}"
            )
        logging.info(f"k_v_proj time: {time.time() - start_time:.6f}s")
        if self.is_decoder:
            # Cache the projected keys/values so later decoding steps can reuse them.
            past_key_value = (key_states, value_states)
        # Collapse the head dimension into the batch dimension for batched matmul.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        start_time = time.time()
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)
        logging.info(f"attention_reshape time: {time.time() - start_time:.6f}s")
        logging.info(f"Reshaped Query_states shape: {query_states.shape}")
        logging.info(f"Reshaped Key_states shape: {key_states.shape}")
        logging.info(f"Reshaped Value_states shape: {value_states.shape}")
        src_len = key_states.size(1)
        start_time = time.time()
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        logging.info(f"attn_weights_compute time: {time.time() - start_time:.6f}s")
        logging.info(f"Attn_weights shape: {attn_weights.shape}")
        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )
        start_time = time.time()
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            logging.info(f"Attention_mask shape: {attention_mask.shape}")
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        logging.info(f"attn_mask time: {time.time() - start_time:.6f}s")
        start_time = time.time()
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        logging.info(f"softmax time: {time.time() - start_time:.6f}s")
        logging.info(f"Attn_weights shape after softmax: {attn_weights.shape}")
        start_time = time.time()
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            logging.info(f"Layer_head_mask shape: {layer_head_mask.shape}")
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        if output_attentions:
            # Keep a (bsz, num_heads, tgt_len, src_len) view to return, and a flattened
            # copy for the remaining batched ops.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        logging.info(f"head_mask_dropout time: {time.time() - start_time:.6f}s")
        logging.info(f"Attn_probs shape after dropout: {attn_probs.shape}")
        start_time = time.time()
        attn_output = torch.bmm(attn_probs, value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        # Merge the heads back into the embedding dimension: (bsz, tgt_len, embed_dim).
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        logging.info(f"attn_output_compute time: {time.time() - start_time:.6f}s")
        logging.info(f"Attn_output shape: {attn_output.shape}")
        start_time = time.time()
        attn_output = self.out_proj(attn_output)
        logging.info(f"out_proj time: {time.time() - start_time:.6f}s")
        logging.info(f"Final attn_output shape: {attn_output.shape}")
        return attn_output, attn_weights_reshaped, past_key_value
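

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of modeling_bart.py). It assumes the
# excerpt above is importable as-is and uses made-up dimensions (embed_dim=16,
# num_heads=4, batch of 2, sequence length of 5) purely to show the output
# shapes and the timing logs this instrumented layer emits.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging

    import torch

    # logging.info only prints if the root logger is configured at INFO level.
    logging.basicConfig(level=logging.INFO)

    attn = BartAttention(embed_dim=16, num_heads=4, dropout=0.0, is_decoder=False)
    attn.eval()

    hidden_states = torch.randn(2, 5, 16)  # (batch, tgt_len, embed_dim)
    with torch.no_grad():
        attn_output, attn_weights, past_key_value = attn(
            hidden_states, output_attentions=True
        )

    print(attn_output.shape)   # torch.Size([2, 5, 16])
    print(attn_weights.shape)  # torch.Size([2, 4, 5, 5])
    print(past_key_value)      # None, since is_decoder=False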