残差网络ResNet原理深度解读:连小学生都能看懂的近路哲学
残差网络 ResNet 原理深度解读连小学生都能看懂的近路哲学2015年He Kaiming 等人发表了一篇论文标题只有两个词《Deep Residual Learning》。同年 ImageNet 挑战赛他们用这个方法拿下了图像分类、检测、定位三项冠军错误率低至 3.6%首次低于人类的直观判断。这篇解释不需要你懂高等数学只需要你有过这些生活体验搬家时漏了东西怎么办、抄近路比绕路快、组队完成任务比一个人扛更稳。一、深度学习为什么会力不从心普通网络像什么一条越来越窄的流水线想象你在组装手机。第1个人负责焊接零件第2个人负责装屏幕第3个人负责测试…流水线越来越长每个工位都在加工前一个人传过来的东西。问题来了如果第1个人焊错了后面所有人都在用错误零件作业越往后错得越离谱。这就是普通深层网络的问题信息从第一层传到最后一层每层都要加工一次传到最后层时原始信号已经变味了。梯度消失像什么老板的消息在群里传丢了你发消息给老板“明早9点开会”。老板转发给主管“明早9点开会”。主管转发给组长“明早开会”。组长转发给员工“开会”。员工不知道几点来。老板原始消息在传递过程中消失了每层都改一点最后面目全非。这就是梯度消失反向传播时梯度信号每经过一层就衰减一点传到第一层时几乎归零。最前面的参数根本收不到老板的指令不知道怎么调整。退化问题不是过拟合是连拟合都做不到有人会问深度网络至少应该和浅层网络性能一样吧最差的情况我让后面的层什么都不做把前面学好的结果直接传过去不行吗不行。因为网络有非线性层。让后面的层什么都不做恒等映射比让它学一个新的变换更难。这听起来违反直觉举个例子你想抄近路直接走到对面 → 你知道目的地路径清晰一句话的事你被迫绕路先把地图背下来然后在脑子里模拟走一遍路线 → 信息必须经过理解-编码-存储流程多了好几层处理网络里的非线性层就是那个强迫你绕路的流程。它让简单的直接传过去变得很麻烦。这就是退化问题Degradation Problem网络越深性能反而下降。不是因为过拟合而是因为连抄近路都做不到。二、残差连接给信息开一条近路H(x) F(x) x 是什么先别被公式吓跑我们把它翻译成人话H(x)这个残差块最终要输出的东西最终产品x原始输入原材料F(x)这个残差块真正做的变换加工过程 x把加工后的东西和原材料加在一起最终出厂还是搬家公司的例子你有10个箱子工人负责把箱子从A点搬到B点。普通做法工人全部从A搬到B走同样的路线最前面的人如果搬错后面的人都在用错的箱子继续搬。残差做法工人先搬同时留一个人直接拎箱子走过去近路。最后B点收到的是工人搬的箱子 直接拎过来的箱子两者合并。近路Shortcut保证了就算工人搬错了最原始的箱子还是有一条通路能到B点。F(x) x 的直观理解生活中找类比老师批作业。普通网络老师把学生作业收上来撕掉原题只留学生写的答案然后从头推算学生原来做了什么。结果你根本不知道学生原来写了什么。残差网络老师把学生作业和标准答案一起看只批改学生写错的部分。改完的分数 错的分数 基准分。基准分就是 x原始输入改错的分数就是 F(x)残差部分。老师只需要专注改错不用从头算整道题。跳跃连接Skip Connection是什么跳跃 跳过某些层直接把前面学到的信息传过来。生活类比微信群里的回复这条消息功能。你在一堆聊天记录里看到一条消息被标记为重要回复。你直接点进去跳过了中间几百条废话直接看这条重点。跳跃连接就是这样后面的层可以直接点进前面的信息不用一层层爬楼。三、从代码理解残差块最简单的残差块importtorchimporttorch.nnasnnclassSimpleResidualBlock(nn.Module): 最基础的残差块F(x) x def__init__(self,channels):super().__init__()self.conv1nn.Conv2d(channels,channels,3,padding1)self.bn1nn.BatchNorm2d(channels)self.conv2nn.Conv2d(channels,channels,3,padding1)self.bn2nn.BatchNorm2d(channels)self.relunn.ReLU(inplaceTrue)defforward(self,x):residualx# 第一步把原材料单独留一份xoutself.relu(self.bn1(self.conv1(x)))# 第二步F(x) 第一层outself.bn2(self.conv2(out))# 第三步F(x) 第二层outoutresidual# 第四步加工完的 原材料outself.relu(out)# 第五步激活出厂returnout这5步翻译成人话留一份原材料过一个非线性层再过一个非线性层把加工结果和原材料合并激活出厂核心就是第4步把 x原材料和 F(x)加工品加在一起。维度不匹配怎么办近路也有窄路和宽路有时候从A点到B点近路是条小路走不了大卡车。这时候你需要投影Projection把大卡车的东西装到小车上走小路运过去再卸到卡车上。classResidualBlockWithProjection(nn.Module): 输入输出维度不同时的残差块需要投影近路 def__init__(self,in_channels,out_channels,stride1):super().__init__()# 主路加工线self.conv1nn.Conv2d(in_channels,out_channels,3,stridestride,padding1)self.bn1nn.BatchNorm2d(out_channels)self.conv2nn.Conv2d(out_channels,out_channels,3,padding1)self.bn2nn.BatchNorm2d(out_channels)self.relunn.ReLU(inplaceTrue)# 近路Shortcut维度不同时需要投影成一样大self.shortcutnn.Sequential()ifstride!1orin_channels!out_channels:self.shortcutnn.Sequential(nn.Conv2d(in_channels,out_channels,1,stridestride),nn.BatchNorm2d(out_channels))defforward(self,x):residualself.shortcut(x)# 近路先过一遍把尺寸对齐outself.relu(self.bn1(self.conv1(x)))outself.bn2(self.conv2(out))outoutresidualreturnself.relu(out)什么时候维度会变下采样时图片从224×224缩小到112×112通道数通常会增加。就像搬家时把所有小箱子合并成几个大箱子近路必须跟着调整。四、梯度高速公路为什么残差网络不会失联普通网络的梯度传播爬楼梯停电了想象你在一栋100层的楼里电梯坏了楼梯停电了。你要从100层往下走每走一层手机就掉一格电。走到第50层电没了你困在中间上不去下不来。普通深层网络的梯度传播就是这样经过100层每层衰减一点传到第1层时梯度几乎归零第一层的参数困住了不知道该往哪个方向调。残差网络的梯度传播有备用电源同样是100层楼残差网络在每10层留了一条直接到地面的滑梯。就算走到第20层没电了你可以滑到10层继续往下走滑到地面。梯度也是同样的道理梯度 主路梯度 近路梯度恒为1就算主路梯度衰减到接近零近路梯度永远等于1直接传回去。数学上∂L/∂x ∂L/∂H · (∂H/∂F · ∂F/∂x 1)不管 ∂H/∂F · ∂F/∂x 多小1保证梯度不会消失。为什么加法能让梯度跳过层生活类比老板让你做报告你不是一个人写完交差而是你先写初稿秘书帮你校对格式财务帮你核实数据法务帮你检查合规每一步都有**“原始需求”**作为参照不是每个人从零理解老板的意思。残差连接就是那个让每层都能看到老板原始需求的机制——原始输入 x 作为基准始终存在梯度可以顺着这个基准快速回传。五、ResNet 的整体结构4个车间层层递进ResNet 像一条工厂流水线ResNet-50 的结构可以这样理解图片进来 → 粗加工车间(conv1) → 细加工车间1(stage1,3个残差块) → 细加工车间2(stage2,4个残差块) → 细加工车间3(stage3,6个残差块) → 细加工车间4(stage4,3个残差块) → 打包出厂(fc)每个车间之间会发生什么图片尺寸缩小从224→112→56→28→14→7通道数增加从64→128→256→512。这就是下采样。工厂需要把原材料逐步缩小、压缩、提纯。瓶颈残差块Bottleneck先压缩再加工ResNet-50 和 ResNet-101 用的不是基础残差块而是瓶颈块。类比你要把一张高清照片压缩成表情包。普通做法原图 → 模糊化 → 压缩 → 表情包直接压缩信息损失大瓶颈块做法原图 → 压缩成小图1×1降维→ 精细处理3×3卷积→ 放大回原尺寸1×1升维→ 表情包瓶颈块结构1×1卷积压缩→ 3×3卷积加工→ 1×1卷积还原这样可以用更少的计算量达到同样的效果所以 ResNet-50 可以有50层而 ResNet-18 只有18层18层用基础块就够了。各型号 ResNet 对比模型层数残差块数参数量ImageNet top-1 错误率ResNet-18188 (基础块)11.7M~30%ResNet-343416 (基础块)21.8M~26%ResNet-505016 (瓶颈块)25.6M~24%ResNet-10110133 (瓶颈块)44.5M~23%ResNet-15215250 (瓶颈块)60.2M~22%规律层数越深 → 错误率越低但参数量也在涨。六、训练 ResNet 的踩坑指南坑1BatchNorm 和 ReLU 的顺序错误做法Conv → ReLU → BN正确做法Conv → BN → ReLU类比工厂质量检测应该在发货前做而不是发货后做。BN 放在激活函数前面让激活函数在归一化后的分布上工作减少信息损失。# ✅ 正确outself.bn1(self.conv1(x))outself.relu(out)# ❌ 错误outself.relu(out)outself.bn1(self.conv1(x))坑2学习率别照抄普通网络普通网络训练时学习率 0.1 可能合适。ResNet 因为梯度流动更强参数更新幅度更大如果用同样的学习率容易震荡。建议用余弦退火让学习率平缓下降不要断崖式跌落。optimizertorch.optim.SGD(model.parameters(),lr0.05,# ResNet 建议从 0.05 开始momentum0.9,weight_decay1e-4)schedulertorch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max90)坑3权重初始化用 He 初始化普通网络可以用 Xavier 初始化。ResNet 用 He 初始化Kaiming Normal专门针对 ReLU 设计definit_weights(m):ifisinstance(m,nn.Conv2d):nn.init.kaiming_normal_(m.weight,modefan_out,nonlinearityrelu)elifisinstance(m,nn.BatchNorm2d):nn.init.constant_(m.weight,1)nn.init.constant_(m.bias,0)model.apply(init_weights)七、ResNet 之后发生了什么残差连接的思想催生了后续一整套模型进化路线1让残差更密集 → DenseNet残差连接是 Add加法信息是融合。DenseNet 是 Concat拼接信息是堆叠。类比ResNet把A同学做的笔记和B同学批注加在一起 → 两人合作的产出DenseNet把A的笔记、B的笔记、C的笔记全部钉在一起 → 三本笔记的完整集合classDenseBlock(nn.Module):DenseNet 的密集连接每一层的输出都和所有前面的层拼接def__init__(self,in_channels,growth_rate,num_layers):super().__init__()self.layersnn.ModuleList()foriinrange(num_layers):self.layers.append(self._make_layer(in_channelsi*growth_rate,growth_rate))def_make_layer(self,in_channels,growth_rate):returnnn.Sequential(nn.BatchNorm2d(in_channels),nn.ReLU(inplaceTrue),nn.Conv2d(in_channels,growth_rate,3,padding1))defforward(self,x):features[x]forlayerinself.layers:new_featurelayer(torch.cat(features,dim1))features.append(new_feature)returntorch.cat(features,dim1)DenseNet 的优势参数利用率更高相同性能下参数量更少。缺点是显存占用大所有层输出都堆在一起。路线2让残差关注重要通道 → SE-ResNetSqueeze-and-ExcitationSE模块让网络学会判断哪些通道重要、哪些可以忽略。类比老师在批改作业时有些学生字迹清晰重要通道有些学生写得很潦草噪声通道。SE 模块就是那个帮老师快速识别重点的工具。classSEBlock(nn.Module):通道注意力让网络自己决定哪些通道值得重点关注def__init__(self,channels,reduction16):super().__init__()self.avg_poolnn.AdaptiveAvgPool2d(1)self.fcnn.Sequential(nn.Linear(channels,channels//reduction),nn.ReLU(inplaceTrue),nn.Linear(channels//reduction,channels),nn.Sigmoid())defforward(self,x):b,c,_,_x.size()# 全局压缩把每个通道的信息压成一个数yself.avg_pool(x).view(b,c)# 重新赋权每个通道的重要性打分yself.fc(y).view(b,c,1,1)returnx*y.expand_as(x)路线3Transformer 里的残差 → ViT2020 年Vision TransformerViT把残差连接带入了 Transformer 架构。在 Transformer 里残差连接以LayerNorm Add的形式存在classTransformerBlock(nn.Module):Transformer 中的残差LayerNorm → 算子 → Add → LayerNorm → FFN → Adddef__init__(self,embed_dim,num_heads):super().__init__()self.norm1nn.LayerNorm(embed_dim)self.attnnn.MultiheadAttention(embed_dim,num_heads)self.norm2nn.LayerNorm(embed_dim)self.ffnnn.Sequential(nn.Linear(embed_dim,embed_dim*4),nn.GELU(),nn.Linear(embed_dim*4,embed_dim))defforward(self,x):# 残差连接原始 x 和注意力输出相加xxself.attn(self.norm1(x))[0]# 残差连接原始 x 和 FFN 输出相加xxself.ffn(self.norm2(x))returnxResNet 的跳跃连接思想在 Transformer 时代以 LayerNorm Add 的形式继续发光发热。八、生产环境实战用 PyTorch 跑起来加载预训练模型5行代码importtorchvision.modelsasmodels# 加载 ResNet-50用 ImageNet 预训练的权重modelmodels.resnet50(weightsIMAGENET1K_V2)# 替换最后的分类头原来分1000类我们分10类model.fcnn.Linear(model.fc.in_features,10)微调Fine-tune只训练最后几层# 冻结前面的层只训练分类头forparaminmodel.parameters():param.requires_gradFalse# 只解冻分类头forparaminmodel.fc.parameters():param.requires_gradTrueoptimizertorch.optim.Adam(model.fc.parameters(),lr1e-3)完整训练流程importtorchvision.transformsasT# 数据增强这步很关键决定了模型泛化能力train_transformT.Compose([T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ColorJitter(0.3,0.3,0.2),# 颜色抖动让模型不认颜色T.ToTensor(),T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])train_dstorchvision.datasets.ImageFolder(data/train,train_transform)train_loadertorch.utils.data.DataLoader(train_ds,batch_size64,shuffleTrue,num_workers4)# 训练forepochinrange(90):model.train()forimages,labelsintrain_loader:images,labelsimages.cuda(),labels.cuda()optimizer.zero_grad()lossnn.CrossEntropyLoss()(model(images),labels)loss.backward()optimizer.step()scheduler.step()九、总结残差的本质是什么ResNet 的核心公式只有一行H(x) F(x) x但这个简单的加法解决了一个根本问题把学什么变成了改什么。普通网络学习完整的 H(x)从零画一幅画残差网络学习 H(x) - x只画错了的部分近路的价值让梯度有了高速公路100层的训练成为可能残差思想的影响从 ResNet 到 DenseNet、SE-Net、ViT跳跃连接无处不在。记住这个比喻老师批作业不是从头重做一遍而是只改错的地方。近路Shortcut保证了就算改错了原文还在。参考资料He Kaiming et al., “Deep Residual Learning for Image Recognition”, CVPR 2016He Kaiming et al., “Identity Mappings in Deep Residual Networks”, ECCV 2016Gao Huang et al., “Densely Connected Convolutional Networks”, CVPR 2017Jie Hu et al., “Squeeze-and-Excitation Networks”, CVPR 2018