import torch.nn as nn
import torch
import torch.nn.functional as F
class PagFM(nn.Module):
    """Selective (pixel-attention-guided) fusion of two feature maps.

    ``x`` is expected to be the detail feature map and ``y`` the
    low-frequency context feature map. A per-pixel gate ``sim_map`` in
    (0, 1) is computed from 1x1-conv embeddings of both inputs and used to
    blend them: ``out = (1 - sim_map) * x + sim_map * y_upsampled``.

    Original author's note: directly fusing detail and low-frequency context
    tends to lose detail.

    Args:
        in_channels: Channel count of both ``x`` and ``y``.
        mid_channels: Channel count of the embeddings used to build the gate.
        after_relu: If True, apply ReLU to both inputs before embedding.
        with_channel: If True, the gate has ``in_channels`` channels
            (produced by an extra 1x1 conv + norm); otherwise it is a single
            channel obtained by summing the embedding product over channels.
        BatchNorm: Normalisation layer class, called with a channel count.
    """

    def __init__(self, in_channels, mid_channels, after_relu=False,
                 with_channel=False, BatchNorm=nn.BatchNorm2d):
        super(PagFM, self).__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu

        # 1x1 projection of the detail branch (x).
        self.f_x = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            BatchNorm(mid_channels),
        )
        # 1x1 projection of the context branch (y).
        self.f_y = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            BatchNorm(mid_channels),
        )
        if with_channel:
            # Maps the embedding product back to in_channels so the gate is
            # channel-wise instead of a single shared map.
            self.up = nn.Sequential(
                nn.Conv2d(mid_channels, in_channels, kernel_size=1, bias=False),
                BatchNorm(in_channels),
            )
        if after_relu:
            # Not in-place: forward() applies this directly to the caller's
            # tensors, and an in-place ReLU would overwrite them (and can
            # raise autograd in-place modification errors).
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x, y):
        """Fuse ``x`` and ``y`` and return a tensor with the shape of ``x``.

        Args:
            x: 4-D tensor with ``in_channels`` channels (detail branch).
            y: 4-D tensor with ``in_channels`` channels (context branch); it
                is bilinearly resized to the spatial size of ``x``.

        Returns:
            The gated blend ``(1 - sim_map) * x + sim_map * y_resized``.
        """
        target_size = [x.size(2), x.size(3)]

        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        # Embed the context branch and bring it to x's resolution.
        y_q = self.f_y(y)
        y_q = F.interpolate(y_q, size=target_size,
                            mode='bilinear', align_corners=False)
        x_k = self.f_x(x)

        if self.with_channel:
            sim_map = torch.sigmoid(self.up(x_k * y_q))
        else:
            # Channel-wise dot product -> one gate value per pixel.
            sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))

        y = F.interpolate(y, size=target_size,
                          mode='bilinear', align_corners=False)
        x = (1 - sim_map) * x + sim_map * y
        return x


if __name__ == '__main__':
    # Use the GPU when present, otherwise fall back to CPU so the demo
    # still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn((1, 4, 9, 9)).to(device)  # x should be the detail features
    y = torch.randn((1, 4, 4, 4)).to(device)  # y should be the low-frequency context
    expand_ratio = 4
    model = PagFM(in_channels=4, mid_channels=4 * expand_ratio).to(device)
    out = model(x, y)
    print(out.shape)