【TPAMI 2026即插即用模块】DSWA 可变形滑动窗口注意力机制,适合图像恢复(All-in-One)、图像超分辨率、语义分割与实例分割、目标检测、图像增强、图像分类等CV任务通用,涨点起飞!
一、论文信息本文目录一、论文信息二、论文摘要概况三、DSWA 可变形滑动窗口注意力机制结构图四、DSWA 模块的作用五、DSWA 模块的原理六、DSWA 模块的优势七、即插即用模块代码论文题目DSwinIR: Rethinking Window-based Attention forImage Restoration中文题目DSwinIR重新思考基于窗口的注意力机制在图像恢复中的应用所属单位哈尔滨工业大学二、论文摘要概况随着深度学习模型的发展图像修复领域取得了显著进展。基于Transformer的模型尤其是采用窗口自注意力机制的模型已成为主流技术。然而其性能受限于僵化的非重叠窗口划分方案导致窗口间特征交互不足且感受野范围有限。这凸显了开发更具自适应性和灵活性的注意力机制的必要性。本文提出了一种新型注意力机制——可变形滑动窗口TransformerDSwinIR即可变形滑动窗口DSwin注意力机制。该机制采用以标记为中心且具备内容感知能力的设计范式突破了传统网格和固定窗口划分的局限由两个互补组件构成首先用以标记为中心的滑动窗口框架替代传统划分方式有效消除边界伪影其次引入内容感知型可变形采样策略使注意力机制能够学习数据依赖性的偏移量并主动调整感受野范围以聚焦最具信息价值的图像区域。大量实验表明DSwinIR在多项评估基准中均取得优异性能在综合图像修复任务中更在三任务基准上超越最新主流模型GridFormer 0.53 dB在五任务基准上超越0.87 dB。相关代码及预训练模型可访问 https://github.com/Aitical/DSwinIR 。图1. 以锚定标记用⋆标示为参考点对特征提取机制进行对比分析。(a)传统卷积采用固定的采样模式利用邻域特征(b)可变形卷积根据内容特性引入自适应采样位置能更有效地整合相关区域的特征(c)窗口注意力机制存在边界限制位于窗口边缘尤其是角落的锚定标记其感受野范围有限(d)提出的可变形滑动窗口DSwin注意力机制通过以标记为中心的设计框架和内容感知采样方式扩展了窗口注意力机制为锚定标记提供了稳健的特征聚合能力。三、DSWA 可变形滑动窗口注意力机制结构图图3. 所提出的DSwinIR架构概述。该模型基于U形架构设计其核心组件为DSwin变换器模块DSTB。主要模块包括(a) 可变形滑动窗口注意力机制DSwin通过学习内容依赖性偏移量来自适应采样特征(b) 多尺度可变形滑动窗口注意力机制MS-DSwin可在多个注意力头之间整合多尺度DSwin注意力(c) 多尺度门控前馈网络MSG- FFN利用并行卷积分支增强特征表示能力。四、DSWA 模块的作用作为新型窗口注意力模块替代传统固定窗口自注意力专门服务于图像恢复任务的特征提取。针对性解决传统窗口注意力的两大核心缺陷窗口边界导致的上下文截断、固定方形感受野无法适配图像内容。构成 DSwinIR 骨干网络的核心组件支撑单任务、多任务统一、复合退化等多种图像恢复场景的精度提升。图4. DSwinIR中内容感知采样机制的可视化示意图。该图对我们提出的DSwin注意力机制进行了定性分析(a)输入图像红色圆点标示局部分析所需的锚点位置(b)对应的局部采样模式显示参考网格、自适应采样点及学习得到的偏移量(c)形变幅度图以热图形式呈现偏移距离d √[p²dx dy²]其中颜色越亮表示偏移量越大。五、DSWA 模块的原理以 token 为中心的滑动窗口机制放弃传统不重叠的网格窗口划分方式改为以每个查询 token 为中心点截取固定大小的邻域作为该 token 的注意力计算范围。不同 token 的注意力窗口相互重叠天然打通窗口边界的信息阻隔让边缘位置的 token 也能获取边界外的上下文特征。内容感知的可变形采样机制通过一个轻量网络根据每个查询 token 的特征为其邻域内的所有采样点预测空间偏移量。实际采样时不再遵循规则的网格位置而是沿着预测的偏移量动态调整采样点主动对齐雨线、物体轮廓、雾区等图像结构与退化区域。采用可微分的双线性插值方式获取偏移后的特征保证整个模块可以端到端训练无需额外的正则约束。多尺度扩展设计借助多头注意力的结构给不同分组的注意力头分配不同大小的滑动窗口让模块同时捕捉局部精细纹理和大范围上下文实现多尺度特征融合。六、DSWA 模块的优势根除窗口边界问题滑动重叠的设计彻底消除了传统窗口注意力的边界截断效应避免边缘特征失真减少恢复结果中的窗口伪影。动态自适应建模感受野形状和范围由图像内容自动决定复杂区域主动扩大信息搜索范围平坦区域保持局部高效运算建模精准度远高于固定模式的窗口注意力。算力性价比突出在多项任务上取得最优性能的同时参数量、计算量和推理延迟均低于多数同性能模型相比主流方案算力开销降低 15%~40%。场景适配性广既在去噪、去雨、去雾等单一任务上表现优异也能很好适配多任务统一恢复、复合退化、真实场景退化等复杂场景性能增益稳定。组件协同效应强滑动窗口与可变形采样并非简单叠加二者结合带来的性能提升大于单独改进的总和搭配多尺度门控前馈网络可进一步强化效果。表II为设置1的定量比较结果包含三项不同的降解任务。*号表示引用自先前文献[15][74]的结果。图5. 三种退化处理任务下修复效果的视觉对比噪声消除上排、雨痕消除中排及去雾处理下排。放大区域以彩色框标示表明我们的方法在细节保留和退化痕迹消除方面均表现优异。七、即插即用模块代码import torch import torch.nn as nn import torch.nn.functional as F import einops from timm.models.layers import to_2tuple, trunc_normal_ from natten.functional import na2d_qk, na2d_av FUSED True try: from natten.functional import na2d except ImportError: FUSED False print(natten 0.17 not installed, using dummy implementation) class LayerNormFunction(torch.autograd.Function): staticmethod def forward(ctx, x, weight, bias, eps): ctx.eps eps N, C, H, W x.size() mu x.mean(1, keepdimTrue) var (x - mu).pow(2).mean(1, keepdimTrue) y (x - mu) / (var eps).sqrt() ctx.save_for_backward(y, var, weight) y weight.view(1, C, 1, 1) * y bias.view(1, C, 1, 1) return y staticmethod def backward(ctx, grad_output): eps ctx.eps N, C, H, W grad_output.size() y, var, weight ctx.saved_variables g grad_output * weight.view(1, C, 1, 1) mean_g g.mean(dim1, keepdimTrue) mean_gy (g * y).mean(dim1, keepdimTrue) gx 1. / torch.sqrt(var eps) * (g - y * mean_gy - mean_g) return gx, (grad_output * y).sum(dim3).sum(dim2).sum(dim0), grad_output.sum(dim3).sum(dim2).sum( dim0), None class LayerNorm2d(nn.Module): def __init__(self, channels, eps1e-6): super(LayerNorm2d, self).__init__() self.register_parameter(weight, nn.Parameter(torch.ones(channels))) self.register_parameter(bias, nn.Parameter(torch.zeros(channels))) self.eps eps def forward(self, x): return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) class DeformableNeighborhoodAttention(nn.Module): def __init__( self, dim: int, num_heads: int, kernel_size: int, dilation: int 1, offset_range_factor1.0, stride1, use_peTrue, dwc_peTrue, no_offFalse, fixed_peFalse, is_causal: bool False, rel_pos_bias: bool False, attn_drop: float 0.0, proj_drop: float 0.0, ): super().__init__() n_head_channels dim // num_heads n_groups num_heads self.dwc_pe dwc_pe self.n_head_channels n_head_channels self.scale self.n_head_channels ** -0.5 self.n_heads num_heads self.nc n_head_channels * num_heads self.n_groups num_heads self.n_group_channels self.nc // self.n_groups self.n_group_heads self.n_heads // self.n_groups self.use_pe use_pe self.fixed_pe fixed_pe self.no_off no_off self.offset_range_factor offset_range_factor self.ksize kernel_size self.kernel_size (kernel_size, kernel_size) self.stride stride self.dilation dilation self.is_causal is_causal kk self.ksize pad_size kk // 2 if kk ! stride else 0 self.conv_offset nn.Sequential( nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groupsself.n_group_channels), LayerNorm2d(self.n_group_channels), nn.GELU(), nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, biasFalse) ) if self.no_off: for m in self.conv_offset.parameters(): m.requires_grad_(False) self.proj_q nn.Conv2d( self.nc, self.nc, kernel_size1, stride1, padding0 ) self.proj_k nn.Conv2d( self.nc, self.nc, kernel_size1, stride1, padding0 ) self.proj_v nn.Conv2d( self.nc, self.nc, kernel_size1, stride1, padding0 ) self.proj_out nn.Conv2d( self.nc, self.nc, kernel_size1, stride1, padding0 ) if rel_pos_bias: self.rpb nn.Parameter( torch.zeros( num_heads, (2 * self.kernel_size[0] - 1), (2 * self.kernel_size[1] - 1), ) ) trunc_normal_(self.rpb, std0.02, mean0.0, a-2.0, b2.0) else: self.register_parameter(rpb, None) self.proj_drop nn.Dropout(proj_drop, inplaceTrue) self.attn_drop nn.Dropout(attn_drop, inplaceTrue) self.rpe_table nn.Conv2d( self.nc, self.nc, kernel_size3, stride1, padding1, groupsself.nc) torch.no_grad() def _get_ref_points(self, H_key, W_key, B, dtype, device): ref_y, ref_x torch.meshgrid( torch.linspace(0.5, H_key - 0.5, H_key, dtypedtype, devicedevice), torch.linspace(0.5, W_key - 0.5, W_key, dtypedtype, devicedevice), indexingij ) ref torch.stack((ref_y, ref_x), -1) ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0) ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0) ref ref[None, ...].expand( B * self.n_groups, -1, -1, -1) # B * g H W 2 return ref torch.no_grad() def _get_q_grid(self, H, W, B, dtype, device): ref_y, ref_x torch.meshgrid( torch.arange(0, H, dtypedtype, devicedevice), torch.arange(0, W, dtypedtype, devicedevice), indexingij ) ref torch.stack((ref_y, ref_x), -1) ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0) ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0) ref ref[None, ...].expand( B * self.n_groups, -1, -1, -1) # B * g H W 2 return ref def forward(self, x): B, C, H, W x.size() dtype, device x.dtype, x.device q self.proj_q(x) q_off einops.rearrange( q, b (g c) h w - (b g) c h w, gself.n_groups, cself.n_group_channels) offset self.conv_offset(q_off).contiguous() # B * g 2 Hg Wg Hk, Wk offset.size(2), offset.size(3) if self.offset_range_factor 0 and not self.no_off: offset_range torch.tensor( [1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], devicedevice).reshape(1, 2, 1, 1) offset offset.tanh().mul(offset_range).mul(self.offset_range_factor) offset einops.rearrange(offset, b p h w - b h w p) reference self._get_ref_points(Hk, Wk, B, dtype, device) if self.no_off: offset offset.fill_(0.0) if self.offset_range_factor 0: pos offset reference else: pos (offset reference).clamp(-1., 1.) if self.no_off: x_sampled F.avg_pool2d( x, kernel_sizeself.stride, strideself.stride) assert x_sampled.size(2) Hk and x_sampled.size( 3) Wk, fSize is {x_sampled.size()} else: x_sampled F.grid_sample( inputx.reshape(B * self.n_groups, self.n_group_channels, H, W), gridpos[..., (1, 0)], # y, x - x, y modebilinear, align_cornersTrue) # B * g, Cg, Hg, Wg x_sampled x_sampled.reshape(B, C, H, W) residual_lepe self.rpe_table(q) if self.rpb is not None or not FUSED: q einops.rearrange(q, b (g c) h w - b g h w c, gself.n_groups, bB, cself.n_group_channels, hH, wW) k einops.rearrange(self.proj_k(x_sampled), b (g c) h w - b g h w c, gself.n_groups, bB, cself.n_group_channels, hH, wW) v einops.rearrange(self.proj_v(x_sampled), b (g c) h w - b g h w c, gself.n_groups, bB, cself.n_group_channels, hH, wW) q q*self.scale attn na2d_qk( q, k, kernel_sizeself.kernel_size, dilationself.dilation, is_causalself.is_causal, rpbself.rpb, ) attn attn.softmax(dim-1) attn self.attn_drop(attn) out na2d_av( attn, v, kernel_sizeself.kernel_size, dilationself.dilation, is_causalself.is_causal, ) out einops.rearrange(out, b g h w c - b (g c) h w) else: q einops.rearrange(q, b (g c) h w - b h w g c, gself.n_groups, bB, cself.n_group_channels, hH, wW) k einops.rearrange(self.proj_k(x_sampled), b (g c) h w - b h w g c, gself.n_groups, bB, cself.n_group_channels, hH, wW) v einops.rearrange(self.proj_v(x_sampled), b (g c) h w - b h w g c, gself.n_groups, bB, cself.n_group_channels, hH, wW) out na2d( q, k, v, kernel_sizeself.kernel_size, dilationself.dilation, is_causalself.is_causal, rpbself.rpb, scaleself.scale, ) out out.reshape(B, H, W, C).permute(0, 3, 1, 2) if self.use_pe and self.dwc_pe: out out residual_lepe y self.proj_drop(self.proj_out(out)) return y if __name__ __main__: # 超参数设置 batch_size 2 height, width 128, 128 # 输入图像大小 channels 64 # 输入通道数需能被 num_heads 整除 num_heads 8 # 注意力头数 kernel_size 7 dilation 1 # 创建输入张量形状为 (B, C, H, W) input torch.randn(batch_size, channels, height, width) # 初始化 DeformableNeighborhoodAttention 模块 model DeformableNeighborhoodAttention( dimchannels, num_headsnum_heads, kernel_sizekernel_size, dilationdilation, offset_range_factor1.0, stride1, use_peTrue, dwc_peTrue, no_offFalse, fixed_peFalse, is_causalFalse, rel_pos_biasFalse, attn_drop0.0, proj_drop0.0, ) print(model) print(CSDN:AI魔改博士) output model(input) print(DeformableNeighborhoodAttention input_size:, input.size()) print(DeformableNeighborhoodAttention output_size:, output.size())