深度学习(超分辨率)

news/2024/12/21 21:53:43/文章来源:https://www.cnblogs.com/tiandsp/p/18611110

简单训练了一个模型,可以实现超分辨率效果。模型在这里。

模型用了一些卷积层,最后接一个PixelShuffle算子。

训练数据是原始图像resize后的亮度通道。

标签是原始图像的亮度通道。

损失函数设为MSE。

代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
from PIL import Image
from os import listdir
from os.path import join
import numpy as npcrop_size = 256
upscale_factor = 3
crop_size = crop_size - (crop_size % upscale_factor)input_transformer= Compose([CenterCrop(crop_size),Resize(crop_size // upscale_factor),ToTensor()])target_transform =Compose([CenterCrop(crop_size),ToTensor()])class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(1, 64, 5, 1, 2)self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)self.conv3 = nn.Conv2d(64, 32, 3, 1, 1)self.conv4 = nn.Conv2d(32, upscale_factor ** 2, 3, 1, 1)self.pixel_shuffle = nn.PixelShuffle(upscale_factor)def forward(self, x):x = self.relu(self.conv1(x))x = self.relu(self.conv2(x))x = self.relu(self.conv3(x))x = self.pixel_shuffle(self.conv4(x))       return xclass SRData(Dataset):def __init__(self, image_dir):self.image_filenames = [join(image_dir, x) for x in listdir(image_dir)]def __len__(self):return len(self.image_filenames)def __getitem__(self, index):image = Image.open(self.image_filenames[index]).convert('YCbCr')y, _, _ = image.split()img = input_transformer(y)lab = target_transform(y)return img, labdef train():num_epochs = 2model = Net()optimizer = optim.Adam(model.parameters(), lr=0.01)criterion = nn.MSELoss()train_dataset = SRData('./dataset')train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)model.train()for epoch in range(num_epochs):running_loss = 0.0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")torch.save(model, 'super_res.pth')def test():img = Image.open("test.jpg").convert('YCbCr')y, cb, cr = img.split()model = torch.load("super_res.pth")img_to_tensor = ToTensor()input = img_to_tensor(y).view(1, 1, y.size[1], y.size[0])model = model.cuda()input = input.cuda()out = model(input)out = out.cpu()out_img_y = out[0].detach().numpy()out_img_y *= 255.0out_img_y = out_img_y.clip(0, 255)out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')out_img.save("out.jpg")if __name__ == "__main__":#  train()test()

效果如下:

原图:

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/856555.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s阶段09 Velero备份恢复功能, 云原生的定义, k8s实践项目(Mall-MicroService)

4 基于Velero的备份和恢复Velero介绍Velero是用于备份和恢复 Kubernetes 集群资源和PV的开源项目,由VMWare-Tanzu维护◼ 基于Velero CRD创建备份(Backup)和恢复作业(Restore)◼ 可以备份或恢复集群中的所有对象,也可以按类型、名称空间或标签过滤对象◼ 可基于文件系统…

AI火灾监测报警摄像机

AI火灾监测报警摄像机,作为一种结合人工智能技术和摄像监控技术的创新产品,在火灾防控领域发挥着越来越重要的作用。这种摄像机通过先进的AI算法,能够实时监测摄像头画面,识别出火灾的特征,如火光、浓烟等。一旦检测到火灾迹象,系统会立即启动报警机制,并向相关管理人员…

AI人员入侵识别摄像机

AI人员入侵识别摄像机是一种智能监控设备,利用人工智能技术辨认并报警可能的入侵行为。这种摄像机利用深度学习算法实时分析监控画面,识别出普通行人和潜在入侵者之间的差异,从而更准确地预警可能发生的安全事件。AI人员入侵识别摄像机是一种智能监控设备,利用人工智能技术…

javaweb练习分析——2

在进行完文件的配置之后,就要按照数据库封装bean,放在pojo层中,然后创建相应的mapper.xml文件(创建时要用/间隔)之后根据项目要求,搭建主界面。 根据不同角色的功能,搭建各自的界面,以其中一个为例 <!DOCTYPE html> <html lang="en"> <head&g…

javaweb练习分析——1

首先在写项目时首先要做的是创建一个web项目,配置好pom.xml文件,mybatis.xml文件,还有创建相应的结构比如pojo、mapper、service等等。xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"><modelVersion>…

鸿蒙HarmonyOS应用开发 | 「鸿蒙技术分享」HarmonyOS NEXT元服务卡片实战体验

「鸿蒙技术分享」—HarmonyOS NEXT 元服务卡片实战体验 HarmonyOS NEXT 是华为鸿蒙系统的最新版本,带来了更为流畅、高效的体验,并以元服务卡片(Service Widget)为核心,优化了服务分发和交互体验。本文将从开发者的角度,分享如何开发和部署元服务卡片,并结合代码实例,带…

Java 基础:关键字 标识符

1. 关键字(Keyword)定义:被Java语言赋予了特殊含义,用做专门用途的字符串(或单词)HelloWorld案例中,出现的关键字有 class、public 、 static 、 void 等,这些单词已经被Java定义好了特点:全部关键字都是小写字母 关键字比较多,不需要死记硬背,学到哪里记到哪里…

三点估算

三点估算选择的三种估算值不包括如下哪项如下: 三点估算是一种常用的项目管理工具,用于估算项目的成本、工期和资源等情况。三点估算通过选择最可能值、最乐观值和最悲观值来确定估算范围,以提高估算的准确性和可信度。在进行三点估算时,选择的三种估算值通常不包括以下几项…

Wpf Prism中添加新控件的区域适配器

上节中我们讲了怎么样定义一个区域与区域引用视图,但并不是所有的组件都支持组件当作区域使用,比如StackPanel就不支持当作区域来使用: 我们自接使用会报以下错误,这时候我们就要自定义一个区域适配器: 1.首先我们创建一个StackPanelRegionAdapter的类:1 using Prism.Reg…

【专题】大模型时代的具身智能2024报告汇总PDF洞察(附原数据表)

原文链接: https://tecdat.cn/?p=38597 在当今科技飞速发展的时代,大模型的崛起如同一股强劲的浪潮,席卷了整个科技领域,而具身智能则在这浪潮中崭露头角,成为人工智能领域备受瞩目的前沿方向。随着数据的海量增长与计算能力的迅猛提升,大模型为具身智能注入了强大的智慧…

【数据库开发】小红书MySQL数据一致性校验能力探索与实践

原创 等你加入的 小红书技术REDtech 2024年12月19日 17:01 北京 图片 本文主要介绍数据一致性校验如何结合小红书的业务进行实践并落地,以及数据一致性校验在小红书内部拿到的实际收益。 如有感兴趣的同学,欢迎联系我们开展技术交流。 一、背景 1.1 什么是数据一致性校验 在数…