整体架构

image-20260616160341426

transformer整体架构由编码器和解码器构成。

  1. 编码器(Encoder):负责“理解”输入,它为每个Token生成一个包含上下文信息的向量表示。
  2. 解码器(Decoder):负责“生成”输出,它参考自己已经生成的前文,结合编码器的输出,生成下一个词。

代码实现

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
  """
  位置编码模块
  """
  def forward(self, x):
    pass
    
class MultiHeadAttention(nn.Module):
  """
  多头注意力模块
  """
  def forward(self, query, key, value, mask):
    pass
  
class PositionWiseFeedForward(nn.Module):
  """
  位置前馈网络模块
  """
  def forward(self, x):
    pass
  
class EncoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(EncoderLayer, self).__init__()
    self.self_attn = MultiHeadAttention()
    self.feed_forward = PositionWiseFeedForward()
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)
    
  def forward(self, x, mask):
    # 1.多头注意力
    attn_output = self.self_attn(x, x, x, mask);
    x = self.norm1(x + self.dropout(attn_output))
    # 2.前馈网络
    ff_output = self.feed_forward(x)
    x = self.norm2(x + self.dropout(ff_output))
    
    return x
  
class DecoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(DecoderLayer, self).__int__()
    self.self_attn = MultiHeadAttention()
    self.cross_attn = MultiHeadAttention()
    self.feed_forward = PositionWiseFeedForward()
		self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    
    self.dropout = nn.dropout(dropout)
    
  def forward(self, x, encoder_output, src_mask, tgt_mask):
    # 1.掩码多头注意力
    attn_output = self.self_attn(x, x, x, tgt_mask)
    x = self.norm1(x + self.dropout(attn_output))
    
    # 2.交叉注意力
    cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
    x = self.norm2(x + self.dropout(cross_attn_output))
    
    # 3.前馈网络
    ff_output = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_output))
    
    return x;
  

核心概念

自注意力(Self-Attention)

自注意力(Self-Attention)允许模型在处理每一个Token时,都关注到其他Token,并给每个其他Token分配不同的权重。权重越高的Token,表示其与当前Token关联性越高。

自注意力通过为每个Token引入三个可学习的向量来实现:

  • 查询(Query,Q):代表当前Token,它正在主动“查询”其他Token的信息。
  • 键(Key,K):代表句子中被查询的Token的索引。
  • 值(value,V):代表被查询的Token的信息。

这三个向量由原始的Token向量乘3个不同的可学习权重矩阵($W^Q, W^K, W^V$)得到,整个计算过程可以描述为:

  1. 计算QKV:对于句子中的每个词,由权重矩阵生成 $Q, K, V$向量。
  2. 计算相关性得分:将A与所有K向量点积运算(包括A自己),得到其他词对于A的相关性分数。
  3. 稳定化与归一化:将所有分数处以一个缩放因子 $\sqrt{d_k}$,防止梯度过小,然后用Softmax函数将Logits分数转换为概率分布。
  4. 加权求和:将上一步得到的每个权重乘每个词对应的V向量,并将所有结果相加,就得到了最终的A的向量表示。

用公式表示整个过程就是:
$\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$

多头注意力(Multi-Head-Attention)

多头注意力可以让模型学到多组注意力权重,它将原始的QKV在维度上切分成h(头数)份,每一份独立计算注意力,再将h个输出向量拼接、线性整个,得到最终输出。

image-20260616172936420

class MultiHeadAttention(nn.Module):
  """
  多头注意力机制模块
  """
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    assert d_model % num_heads == 0, "d_mdoel 需要被num_heads 整除"
    
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads
    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)
    
  def scaled_dot_product_attenttion(self, Q, K, V, mask = None):
    # 1. 计算注意力得分
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)/math.sqrt(self.d_k))
    # 2. 应用掩码
    if mask is not None:
      attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
      
    # 3.计算注意力权重
    attn_probs = torch.softmax(attn_scores, dim = -1)
    # 4.加权求和
    output = torch.matmul(attn_probs, V)
    return output
  
  def split_heads(self, x):
    # 变化输入形状(batch_size, seq_length, d_model) 为 (batch_size, num_heads, seq_length, d_k)
    batch_size, seq_length, d_model = x.size()
    return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
  
  def combine_heads(self, x):
    batch_size, num_heads, seq_length, d_k = x.size()
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
  
  def forwards(self, Q, K, V, mask = None):
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))
    
    attn_output = self.scaled_dot_product_attenttion(Q, K, V, mask)
    output = self.W_o(self.combine_heads(attn_output))
    return output