别再死记硬背Self-Attention公式了!用Python手搓一个Transformer核心模块(附完整代码)
从零实现Self-Attention用NumPy拆解Transformer核心逻辑当第一次看到Transformer论文中那个著名的Self-Attention公式时相信不少开发者都有过这样的困惑这些矩阵乘法究竟在做什么为什么需要Q、K、V三个矩阵多头机制又该如何理解本文将通过纯Python实现带你亲手构建一个可运行的Self-Attention模块用代码而非公式回答这些问题。1. 环境准备与基础概念在开始编码前我们需要明确几个关键概念。Self-Attention的本质是一种动态特征加权机制——它让模型能够根据输入序列中不同位置的重要性自动调整关注度。想象你在阅读一段文字时大脑会不自觉地对关键词给予更多注意力Self-Attention正是模拟这种认知过程。必备工具安装pip install numpy matplotlib基础实现只需要NumPy库但为了验证结果准确性我们可以准备一个对照环境import numpy as np from numpy.random import randn # 设置随机种子保证可复现性 np.random.seed(42)2. 单头注意力实现让我们从最基础的Scaled Dot-Product Attention开始。这个名称包含三个关键信息Dot-Product使用点积计算相似度Scaled对结果进行缩放防止梯度消失Attention最终形成注意力权重2.1 输入投影层首先实现将输入转换为Q、K、V的线性变换def input_projection(X, d_model64, d_k8, d_v8): X: 输入序列 [batch_size, seq_len, d_model] 返回: Q, K, V 投影矩阵 batch_size, seq_len, _ X.shape WQ randn(d_model, d_k) * 0.1 WK randn(d_model, d_k) * 0.1 WV randn(d_model, d_v) * 0.1 Q X WQ # [batch_size, seq_len, d_k] K X WK # [batch_size, seq_len, d_k] V X WV # [batch_size, seq_len, d_v] return Q, K, V2.2 注意力计算核心接下来实现注意力权重的计算过程def scaled_dot_product_attention(Q, K, V, maskNone): d_k Q.shape[-1] scores Q K.transpose(0,1,3,2) / np.sqrt(d_k) # [batch_size, seq_len, seq_len] if mask is not None: scores scores.masked_fill(mask 0, -1e9) weights softmax(scores, axis-1) # 沿最后一个维度做softmax output weights V # [batch_size, seq_len, d_v] return output, weights def softmax(x, axis-1): e_x np.exp(x - np.max(x, axisaxis, keepdimsTrue)) return e_x / e_x.sum(axisaxis, keepdimsTrue)注意实际应用中需要添加mask机制处理变长序列这里简化实现3. 多头注意力机制多头注意力就像让模型拥有多组眼睛可以从不同角度观察数据。以下是关键实现步骤3.1 多头投影class MultiHeadAttention: def __init__(self, d_model64, h8): self.d_model d_model self.h h assert d_model % h 0, d_model必须能被h整除 self.d_k d_model // h self.WQ randn(d_model, d_model) * 0.1 self.WK randn(d_model, d_model) * 0.1 self.WV randn(d_model, d_model) * 0.1 self.WO randn(d_model, d_model) * 0.1 def split_heads(self, x): batch_size x.shape[0] return x.reshape(batch_size, -1, self.h, self.d_k).transpose(0,2,1,3) def forward(self, X, maskNone): Q X self.WQ K X self.WK V X self.WV Q self.split_heads(Q) # [batch_size, h, seq_len, d_k] K self.split_heads(K) V self.split_heads(V) # 计算缩放点积注意力 attn_output, attn_weights scaled_dot_product_attention(Q, K, V, mask) # 合并多头结果 batch_size attn_output.shape[0] attn_output attn_output.transpose(0,2,1,3).reshape(batch_size, -1, self.d_model) return attn_output self.WO, attn_weights3.2 效果验证让我们用实际数据测试这个实现# 模拟输入数据 batch_size 2 seq_len 10 d_model 64 X randn(batch_size, seq_len, d_model) # 初始化多头注意力 mha MultiHeadAttention(d_modeld_model, h8) # 前向计算 output, weights mha.forward(X) print(f输入形状: {X.shape}) print(f输出形状: {output.shape}) print(f注意力权重形状: {weights.shape})典型输出结果输入形状: (2, 10, 64) 输出形状: (2, 10, 64) 注意力权重形状: (2, 8, 10, 10)4. 与框架实现对比为了验证我们的实现是否正确可以与PyTorch官方实现进行对比import torch import torch.nn as nn # 使用相同输入数据 X_torch torch.from_numpy(X) # PyTorch多头注意力层 mha_torch nn.MultiheadAttention(embed_dimd_model, num_heads8, batch_firstTrue) output_torch, _ mha_torch(X_torch, X_torch, X_torch) # 比较结果差异 diff np.mean(np.abs(output.detach().numpy() - output_torch.detach().numpy())) print(f与PyTorch实现的平均差异: {diff:.6f})提示实际差异主要来自初始化方式不同核心计算逻辑应该保持一致5. 常见问题与调试技巧在实现过程中开发者常遇到以下几个典型问题5.1 梯度消失问题当维度较大时点积结果可能变得极大导致softmax后某些位置接近0或1。解决方法使用缩放因子1/√d_k添加微小epsilon值防止数值不稳定5.2 内存占用优化注意力矩阵的大小为O(n²)对于长序列实现分块计算使用稀疏注意力模式考虑线性注意力变体5.3 训练不稳定多头注意力可能出现的训练问题# 解决方案示例添加LayerNorm class TransformerLayer: def __init__(self, d_model, h): self.self_attn MultiHeadAttention(d_model, h) self.norm1 LayerNorm(d_model) self.norm2 LayerNorm(d_model) def forward(self, x): attn_out, _ self.self_attn(x) x self.norm1(x attn_out) return x6. 完整实现与扩展将上述模块组合成完整实现class TransformerEncoderLayer: def __init__(self, d_model512, h8, d_ff2048): self.self_attn MultiHeadAttention(d_model, h) self.ffn PositionwiseFFN(d_model, d_ff) self.norm1 LayerNorm(d_model) self.norm2 LayerNorm(d_model) def forward(self, x, maskNone): # 自注意力子层 attn_out, _ self.self_attn(x, mask) x self.norm1(x attn_out) # 前馈网络子层 ffn_out self.ffn(x) x self.norm2(x ffn_out) return x class PositionwiseFFN: def __init__(self, d_model, d_ff): self.w1 randn(d_model, d_ff) * 0.1 self.w2 randn(d_ff, d_model) * 0.1 def forward(self, x): return x self.w1 self.w2在实际项目中这种从零实现的方式虽然性能不如优化后的框架代码但它带来的理解深度是无可替代的。当我在处理一个序列标注任务时正是通过这种手写实现才真正理解了为什么某些位置的注意力权重会异常偏高——原来是输入数据中存在特殊标记导致的注意力聚焦。