SSF-CNN:空间光谱融合的卷积光谱图像超分网络

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

  • SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION
    • 简介
    • 解决问题
    • 网络框架
    • 代码实现
    • 训练部分
    • 运行结果

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值,使其空间分辨率大小和多光谱一致,然后进行concat操作。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark——光谱图像超分基准-CSDN博客
Paper: IEEE
Code:https://github.com/miraclefan777/SSFCNN

2023-11-25_16-06-09

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

class SSFCNNnet(nn.Module):def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):super(SSFCNNnet, self).__init__()self.scale_factor = scale_factorself.pdconv = pdconvself.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")if pdconv:self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")else:self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")self.relu = nn.ReLU(inplace=True)def forward(self, lr_hs, hr_ms):""":param lr_hs:LR-HSI低分辨率的高光谱图像:param hr_ms:高分辨率的多光谱图像:return:"""# 对LR-HSI低分辨率图像进行上采样,让其分辨率更高lr_hs_up = self.Upsample(lr_hs)# 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接x = torch.cat((lr_hs_up, hr_ms), dim=1)x = self.relu(self.conv1(x))if self.pdconv:x = torch.cat((x, hr_ms), dim=1)x = self.relu(self.conv2(x))x = torch.cat((x, hr_ms), dim=1)else:x = self.relu(self.conv2(x))out = self.conv3(x)return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nnif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--model', type=str, default="SSFCNNnet")parser.add_argument('--train-file', type=str, required=True)parser.add_argument('--eval-file', type=str, required=True)parser.add_argument('--outputs-dir', type=str, required=True)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-4)parser.add_argument('--batch-size', type=int, default=32)parser.add_argument('--num-workers', type=int, default=0)parser.add_argument('--num-epochs', type=int, default=400)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()assert args.model in ['SSFCNNnet', 'PDcon_SSF']outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))if not os.path.exists(outputs_dir):os.makedirs(outputs_dir)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)# 训练参数# loss_func = nn.L1Loss(reduction='mean').cuda()criterion = nn.MSELoss()#################数据集处理#################R = create_F()PSF = fspecial('gaussian', 8, 3)downsample_factor = 8training_size = 64stride = 32stride1 = 32train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)#################数据集处理################## 模型if args.model == 'SSFCNNnet':model = SSFCNNnet().cuda()else:model = SSFCNNnet(pdconv=True).cuda()best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0# 模型初始化for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.xavier_uniform_(m.weight)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)optimizer = torch.optim.Adam([{'params': model.conv1.parameters()},{'params': model.conv2.parameters()},{'params': model.conv3.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)start_epoch = 0for epoch in range(start_epoch, args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:label, lr_hs, hr_ms = datalabel = label.to(device)lr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)lr = optimizer.param_groups[0]['lr']pred = model(hr_ms, lr_hs)loss = criterion(pred, label)epoch_losses.update(loss.item(), len(label))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))t.update(len(label))# torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))if epoch % 5 == 0:model.eval()val_loss = AverageMeter()SAM = Loss_SAM()RMSE = Loss_RMSE()PSNR = Loss_PSNR()sam = AverageMeter()rmse = AverageMeter()psnr = AverageMeter()for data in eval_dataloader:label, lr_hs, hr_ms = datalr_hs = lr_hs.to(device)hr_ms = hr_ms.to(device)label = label.cpu().numpy()with torch.no_grad():preds = model(hr_ms, lr_hs).cpu().numpy()sam.update(SAM(preds, label), len(label))rmse.update(RMSE(preds, label), len(label))psnr.update(PSNR(preds, label), len(label))if psnr.avg > best_psnr:best_epoch = epochbest_psnr = psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/218510.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis面试题:redis做为缓存,mysql的数据如何与redis进行同步呢?(双写一致性)

目录 强一致性:延迟双删,读写锁。 弱一致性:使用MQ或者canal实现异步通知 面试官:redis做为缓存,mysql的数据如何与redis进行同步呢?(双写一致性) 候选人:嗯&#xff…

【单片机学习笔记】STC8H1K08参考手册学习笔记

STC8H1K08参考手册学习笔记 STC8H系列芯片STC8H1K08开发环境串口烧录 STC8H系列芯片 STC8H 系列单片机是不需要外部晶振和外部复位的单片机,是以超强抗干扰/超低价/高速/低功耗为目标的 8051 单片机,在相同的工作频率下,STC8H 系列单片机比传统的 8051约快12 倍速度…

⑨【Stream】Redis流是什么?怎么用?: Stream [使用手册]

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ ⑨Redis Stream基本操作命令汇总 一、Redis流 …

arduino的API函数

API在这里:Arduino Reference - Arduino Reference 我觉得一天是不可能学的完的,这么多呢 我现在觉得:不用去学习这些API,以后碰到再去看好了

北邮22级信通院数电:Verilog-FPGA(11)第十一周实验(2)设计一个24秒倒计时器

北邮22信通一枚~ 跟随课程进度更新北邮信通院数字系统设计的笔记、代码和文章 持续关注作者 迎接数电实验学习~ 获取更多文章,请访问专栏: 北邮22级信通院数电实验_青山如墨雨如画的博客-CSDN博客 目录 一.代码部分 1.1 counter_24.v 1.2 divid…

JVM——几种常见的对象引用

目录 1. 软引用软引用的使用场景-缓存 2.弱引用3.虚引用和终结器引用 可达性算法中描述的对象引用,一般指的是强引用,即是GCRoot对象对普通对象有引用关系,只要这层关系存在, 普通对象就不会被回收。除了强引用之外,Ja…

【限流配电开关】TPS2001C

🚩 WRITE IN FRONT 🚩 🔎 介绍:"謓泽"正在路上朝着"攻城狮"方向"前进四" 🔎🏅 荣誉:2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2222年获评百大…

高清动态壁纸软件Live Wallpaper Themes 4K mac中文版功能

Live Wallpaper & Themes 4K mac是一款提供各种高清动态壁纸和主题的应用程序。该应用程序提供了大量的动态壁纸和主题,包括自然、动物、城市、抽象等各种类别,可以满足用户不同的需求。除了壁纸和主题之外,该应用程序还提供了许多其他功…

[修订版][工控]SIEMENS S7-200 控制交通红绿灯程序编写与分析

下载地址>https://github.com/MartinxMax/Siemens_S7-200_Traffic_Light 特别鸣谢接线过程实验目的题目要求I/O分配公式公式套用示例 程序分析分割块[不是必要的,自己分析用]左侧梯形图 [B1-B5]B1 [东西绿灯亮25s]B2 B3 B23 [东西绿灯闪烁3s]B4 [东西黄灯亮2s]B5 [东西红灯…

BART 并行成像压缩感知重建:联合重建

本文使用 variavle-density possion-disc 采样的多通道膝盖数据进行并行重建和压缩感知重建。 0 数据欠采样sampling pattern 1 计算ESPIRiT maps % A visualization of k-space dataknee = readcfl(data/knee); ksp_rss = bart(rss 8, knee);ksp_rss = squeeze(ksp_rss); figu…

CSS新特性(2-2)

CSS新特性(2-2) 前言box相关box-shadow background背景rgba颜色与透明度transform:rotate(Xdeg) 2D旋转transform:tranlate 平移 前言 本文继续讲解CSS3其他的新特性,想看之前新特性点击这里,那么好本文正式开始。 box相关 box…

时间序列预测实战(十九)魔改Informer模型进行滚动长期预测(科研版本)

论文地址->Informer论文地址PDF点击即可阅读 代码地址-> 论文官方代码地址点击即可跳转下载GIthub链接 个人魔改版本地址-> 文章末尾 一、本文介绍 在之前的文章中我们已经讲过Informer模型了,但是呢官方的预测功能开发的很简陋只能设定固定长度去预测未…