神经网络图像压缩代码

news/2025/2/27 6:06:15/文章来源:https://www.cnblogs.com/gotoplay/p/18740127

参加第五届人工智能竞赛,选的图像编码赛道(钱多),纯记录下,这神经网络结构打榜分数也不高,我觉得重要的在于找到一种合适于图像压缩任务的结构,训练倒是其次,主办方让完全采用AI的方式去做,我觉得在网络结构的选取上,势必要加入一些自己对图像的专业理解的,只是这种理解不能以传统的方式表现出来。

点击查看代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import numpy as np
import lpips
from pytorch_msssim import ms_ssim
from torchvision.transforms.functional import normalizeclass ImageCompressor(nn.Module):def __init__(self, compression_ratio=8):super(ImageCompressor, self).__init__()# 编码器 - 减少通道数以提高压缩率# 编码器部分self.encoder = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(64, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 16, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(16, 4, kernel_size=3, padding=1),nn.ReLU())# 解码器部分相应修改self.decoder = nn.Sequential(nn.ConvTranspose2d(4, 16, kernel_size=3, padding=1),nn.ReLU(),nn.Upsample(scale_factor=2),nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1),nn.ReLU(),nn.Upsample(scale_factor=2),nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.Upsample(scale_factor=2),nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1),nn.Sigmoid())def forward(self, x):encoded = self.encoder(x)decoded = self.decoder(encoded)return decodeddef compress_image(image_path, output_path, model_path=None):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ImageCompressor().to(device)# 定义多个损失函数mse_criterion = nn.MSELoss()  # MSE损失(用于PSNR)lpips_criterion = lpips.LPIPS(net='alex').to(device)  # LPIPS损失optimizer = optim.Adam(model.parameters(), lr=0.001)# 准备数据transform = transforms.Compose([transforms.ToTensor(),# 添加归一化处理transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])img = Image.open(image_path)img_tensor = transform(img).unsqueeze(0).to(device)# 训练模型model.train()for epoch in range(300):optimizer.zero_grad()output = model(img_tensor)# 计算多个损失# 1. MSE损失(用于PSNR)mse_loss = mse_criterion(output, img_tensor)psnr = -10 * torch.log10(mse_loss)# 2. MS-SSIM损失ms_ssim_loss = 1 - ms_ssim(output, img_tensor, data_range=1.0)# 3. LPIPS损失lpips_loss = lpips_criterion(output, img_tensor).mean()# 组合损失,使用权重平衡各项total_loss = (1.0 * mse_loss +        # 基础重建损失0.3 * ms_ssim_loss +    # 结构相似性损失0.1 * lpips_loss        # 感知损失)total_loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch+1}/300]')print(f'PSNR: {psnr.item():.2f}')print(f'MS-SSIM Loss: {ms_ssim_loss.item():.4f}')print(f'LPIPS: {lpips_loss.item():.4f}')print('------------------------')# 保存模型if model_path:torch.save(model.state_dict(), model_path)# 压缩图像model.eval()with torch.no_grad():img = Image.open(image_path)img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)compressed = model(img_tensor)# 将结果转换回图像output_img = transforms.ToPILImage()(compressed.squeeze(0).cpu())output_img.save(output_path)def decompress_image(compressed_path, output_path, model_path):# 加载模型device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ImageCompressor().to(device)if not os.path.exists(model_path):raise ValueError("未找到模型文件!")model.load_state_dict(torch.load(model_path))model.eval()# 解压缩图像with torch.no_grad():img = Image.open(compressed_path)img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)decompressed = model(img_tensor)# 将结果转换回图像output_img = transforms.ToPILImage()(decompressed.squeeze(0).cpu())output_img.save(output_path)if __name__ == "__main__":import os# 示例使用input_path = "input.jpg"compressed_path = "compressed.jpg"decompressed_path = "decompressed.jpg"model_path = "compressor_model.pth"# 压缩流程print("正在压缩图像...")compress_image(input_path, compressed_path, model_path)print(f"压缩完成,已保存至 {compressed_path}")# 解压缩流程print("正在解压缩图像...")decompress_image(compressed_path, decompressed_path, model_path)print(f"解压缩完成,已保存至 {decompressed_path}")
目前效果如下:
点击查看代码
Epoch [300/300]
PSNR: 10.01
MS-SSIM Loss: 0.3211
LPIPS: 0.2903

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/890443.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web开发 辅助学习管理系统开发日记 day3

Q1:讲解了关于日志输出的方法:首先引入logback以来以及相关的xml文件到resources的文件夹下。然后到test目录下进行测试 可以通过调整logback的xml文件来调整日志输出的格式,以及位置日志级别记录如下Q2:通过这条外键约束可以将两表关联后避免删除误删两边关联所需要的键值产…

How to Fix Raspberry Pi Imager lost Advanced Menu problem All In One

Raspberry Pi Imager removed Advanced Menu All In One 如何修复 Raspberry Pi Imager 丢失高级菜单问题Raspberry Pi Imager removed Advanced Menu All In One如何修复 Raspberry Pi Imager 丢失高级菜单问题树莓派 bug Raspberry Pi Imager v1.8.5 删除高级菜单选项 ❓solu…

WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战

开源项目名称:leagueoflegends-OpenSilver 作者:Vicky&James leagueoflegends-opensilver:https://github.com/jamesnetgroup/leagueoflegends-opensilver leagueoflegends-wpf:https://github.com/jamesnetgroup/leagueoflegends-wpf Jamesnet个人网站:https://jamesnet…

Raspberry Pi Imager removed Advanced Menu All In One

Raspberry Pi Imager removed Advanced Menu All In One Raspberry Pi Imager 已删除高级菜单Raspberry Pi Imager removed Advanced Menu All In OneRaspberry Pi Imager 已删除高级菜单raspberry pi imager advanced options missing demos树莓派 bugRaspberry Pi Imager v1.…

作业一 自我介绍+软工5问

| 这个作业属于哪个课程 | <班级> | | ----------------- |--------------- | | 这个作业要求在哪里| <作业要求> | | 这个作业的目标 | <- 介绍自己,比如兴趣爱好,学习或者生活经历,认为有趣或者值得向别人展示的记忆快速阅读教材,提出5个想弄懂的问题。 想…

《软件开发与创新课程设计》第一次作业:软件二次开发

一、项目名称与来源 该项目名为体育新闻信息查询系统,源码来自同学。 二、目的 项目体育新闻信息查询系统的目的是基于Java和JavaScript创建一个能够查询体育新闻的web系统。本次作业的目的是基于该系统的基础上进行二次开发。 三、部分原代码 1.体育新闻网点击查看代码 <!…

如何更改 debian 系统家目录中文件夹的语言

一、当前家目录文件夹是中文 当时安装系统的时候,选择了中文,导致家目录的文件夹也是中文的。这导致在命令行中会出现中文路径,现在想把它改成英文的。二、改成英文 家目录的 .config 文件夹中有两个文件与此相关,它们分别是: # /home/xxx/.config user-dirs.dirs user-di…

C++ 超市零售系统二次开发

一、来源 本次分析与二次开发的超市零售系统项目名称为 "SimpleSupermarketManagement", 作者是 GitHub 用户 "CodeExplorer1995",项目地址为https://github.com/CodeExplorer1995/SimpleSupermarketManagement。该项目旨在为小型超市提供基础的业务管理…

IOC 和 DI 详解及其简单用法

1. IOC 详解 1.1 Bean 的声明 IOC 控制反转,就是将对象的控制权交给 Spring 的 IOC 容器,由 IOC 容器创建及管理对象。IOC 容器创建的对象称为 bean 对象。 而 Spring 框架为了更好的标识 Web 应用程序开发当中,bean 对象到底归属于哪一层,又提供了 @Component 的衍生注解:…

开源一款DDS信号发生扩展板-FreakStudio多米诺系列

信号发生扩展板通过SPI接口生成可调频率和幅度的正弦波、方波和三角波,频率小于1MHz。支持幅度调节,提供原始和6倍放大输出接口。配备5阶低通滤波器、噪声抑制功能,优化信号稳定性。原文链接: FreakStudio的博客 摘要 信号发生扩展板通过SPI接口生成可调频率和幅度的正弦波…

Windows系统更改/迁移用户目录

Windows系统更改/迁移用户目录Windows系统更改/迁移用户目录 迁移的原因C盘空间不足 不想将我的文档等放在C盘,方便重做系统 其他原因迁移有什么风险么目前没发现有什么风险迁移过程 准备工作 更改/迁移用户目录之前先自行备份当前用户的资料(下载目录、桌面文件等),以免数…

[2025.2.26 JavaWeb学习]登录校验

流程图会话技术指浏览器与服务器的一次连接,直到某一方断开,某个浏览器的一次会话可以包含多次请求和响应会话跟踪:一种维护浏览器状态的方法,服务器需要识别多次请求是否来自于同一浏览器,以便在同一次会话的多次请求间共享数据