```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# `device` is assumed to be set earlier in the document
device = "cuda" if torch.cuda.is_available() else "cpu"


class Attention(nn.Module):
    def __init__(self, hidden_size, n_heads, cacheKV, max_batch_size, max_seq_len, device=device):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden_size // n_heads
        # Query/key/value/output projections; no bias, Llama-style.
        # cacheKV, max_batch_size, max_seq_len and start_pos are accepted for a
        # KV-cache variant and are not used in this version of the block.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states, rotary_emb, start_pos=0, mask=None, is_causal=True):
        bsz, seqlen, hidden_size = hidden_states.shape

        # Project the input into queries, keys and values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Split the hidden dimension into (n_heads, head_dim)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_heads, self.head_dim)

        # Apply rotary position embeddings to queries and keys
        q = rotary_emb(q)
        k = rotary_emb(k)

        # Move heads in front of the sequence dimension:
        # (bsz, n_heads, seqlen, head_dim), as expected by scaled_dot_product_attention
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Note: pass either an explicit attn_mask or is_causal=True, not both
        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

        # Merge heads back into the hidden dimension and apply the output projection
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, hidden_size)
        return self.o_proj(output)
```
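As a quick shape check, the sketch below instantiates the module and runs a causal forward pass. The hyperparameters and the identity `rotary_emb` stand-in are placeholders, not part of the original code; in the full model the rotary embedding rotates pairs of dimensions within each head.

```python
# Minimal smoke test for the Attention block (illustrative values only).
hidden_size, n_heads = 512, 8
attn = Attention(hidden_size, n_heads, cacheKV=False,
                 max_batch_size=4, max_seq_len=128).to(device)

def rotary_emb(x):
    # Placeholder: identity instead of a real rotary embedding, used only
    # to check tensor shapes through the attention block.
    return x

x = torch.randn(2, 16, hidden_size, device=device)  # (batch, seq_len, hidden)
out = attn(x, rotary_emb)                            # causal self-attention
print(out.shape)                                     # torch.Size([2, 16, 512])
```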