python基于VGG19实现图像风格迁移

目录

1、原理

2、代码实现


1、原理

图像风格迁移是一种将一张图片的内容与另一张图片的风格进行合成的技术。

风格(style)是指图像中不同空间尺度的纹理、颜色和视觉图案,内容(content)是指图像的高级宏观结构。

实现风格迁移背后的关键概念与所有深度学习算法的核心思想是一样的:定义一个损失函数来指定想要实现的目标,然后将这个损失最小化。你知道想要实现的目标是什么,就是保存原始图像的内容,同时采用参考图像的风格。

在Python中,我们可以使用基于深度学习的模型来实现这一技术。​神经风格迁移可以用任何预训练卷积神经网络来实现。我们这里将使用    Gatys等人所使用的 VGG19网络。

2、代码实现

​以下是一个基于VGG19模型的简单图像风格迁移的实现过程:

(1)创建一个网络,它能够同时计算风格参考图像、目标图像和生成图像的 VGG19层激活。

(2)使用这三张图像上计算的层激活来定义之前所述的损失函数,为了实现风格迁移,需要将这个损失函数最小化。

(3)设置梯度下降过程来将这个损失函数最小化

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
​
import torch
import torch.optim as optim
from torchvision import transforms, models
​
vgg = models.vgg19(pretrained=True).features
​
for param in vgg.parameters():param.requires_grad_(False)
​
​
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
vgg.to(device)
​
​
def load_image(img_path, max_size=400):
​image = Image.open(img_path)if max(image.size) > max_size:size = max_sizeelse:size = max(image.size)image_transform = transforms.Compose([transforms.Resize(size),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
​image = image_transform(image).unsqueeze(0)return image
​
​
content = load_image('dogs_and_cats.jpg').to(device)
style = load_image('picasso.jpg').to(device)
​
​
assert style.size() == content.size(), "输入的风格图片和内容图片大小需要一致"
​
​
plt.ion()
def imshow(tensor,title=None):image = tensor.cpu().clone().detach()image = image.numpy().squeeze()image = image.transpose(1,2,0)image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))plt.imshow(image)if title is not None:plt.title(title)plt.pause(0.1)
​
plt.figure()
imshow(style, title='Style Image')
​
plt.figure()
imshow(content, title='Content Image')
​
​
​
​
def get_features(image, model, layers=None):if layers is None:layers = {'0': 'conv1_1','5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1','21': 'conv4_2',  '28': 'conv5_1'}features = {}x = imagefor name, layer in model._modules.items():x = layer(x)if name in layers:features[layers[name]] = xreturn features
​
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
​
def gram_matrix(tensor):_, d, h, w = tensor.size() tensor = tensor.view(d, h * w)gram = torch.mm(tensor, tensor.t())return gram
​
style_grams={}
for layer in style_features:style_grams[layer] = gram_matrix(style_features[layer])
​
import torch.nn.functional as F
​
def ContentLoss(target_features,content_features):content_loss = F.mse_loss(target_features['conv4_2'],content_features['conv4_2'])return content_loss
​
def StyleLoss(target_features,style_grams,style_weights):style_loss = 0for layer in style_weights:target_feature = target_features[layer]target_gram = gram_matrix(target_feature)_, d, h, w = target_feature.shapestyle_gram = style_grams[layer]layer_style_loss = style_weights[layer] * F.mse_loss(target_gram,style_gram)style_loss += layer_style_loss / (d * h * w)
​return style_loss
​
​
style_weights = {'conv1_1': 1.,'conv2_1': 0.75,'conv3_1': 0.2,'conv4_1': 0.2,'conv5_1': 0.2}
​
alpha = 1  # alpha
beta = 1e6  # beta
​
​
show_every = 100
steps = 2000 
​
target = content.clone().requires_grad_(True).to(device)
optimizer = optim.Adam([target], lr=0.003)
​
​
for ii in range(1, steps+1):target_features = get_features(target, vgg)content_loss = ContentLoss(target_features,content_features)style_loss = StyleLoss(target_features,style_grams,style_weights)total_loss = alpha * content_loss + beta * style_lossoptimizer.zero_grad()total_loss.backward()optimizer.step()#print(ii)if  ii % show_every == 0:print('Total loss: ', total_loss.item())plt.figure()      imshow(target)
​
plt.figure()
imshow(target,"Target Image")
plt.ioff()
plt.show()
​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/155626.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis 哨兵

目录 ​编辑 一、哨兵原理 1、集群结构和作用 2、集群监控原理 3、集群故障恢复原理 二、搭建哨兵集群 1、集群结构 2、准备实例和配置 3、启动 三、RedisTemplate 的哨兵模式 一、哨兵原理 1、集群结构和作用 Redis提供了哨兵(Sentinel)机制…

华为NAT配置实例(含dhcp、ospf配置)

一、网络拓朴如下: 二、要求:PC1 能访问到Server1 三、思路: R2配置DHCP,R2和R1配OSPF,R1出NAT 四、主要配置: R2的DHCP和OSPF: ip pool 1gateway-list 10.1.1.1 network 10.1.1.0 mask 25…

RabbitMQ消费者的可靠性

目录 一、消费者确认 二、失败重试机制 2.1、失败处理策略 三、业务幂等性 3.1、唯一消息ID 3.2、业务判断 3.3、兜底方案 一、消费者确认 RabbitMQ提供了消费者确认机制(Consumer Acknowledgement)。即:当消费者处理消息结束后&#x…

Python构造代理IP池提高访问量

目录 前言 一、代理IP是什么 二、代理IP池是什么 三、如何构建代理 IP 池 1. 从网上获取代理 IP 地址 2. 对 IP 地址进行筛选 3. 使用筛选出来的 IP 地址进行数据的爬取 四、总结 前言 爬虫程序是批量获取互联网上的信息的重要工具,在访问目标网站时需要频…

linux系统的环境变量-搞清环境变量到底是什么

环境变量 引例环境变量常见的环境变量echoexportenvunsetset 通过代码获取环境变量使用第三个参数获取使用全局变量enviorn获取环境变量通过系统调用获取环境变量 环境变量具有全局属性main函数前两个参数的作用 引例 在linux系统中,我们使用ls命令,直接…

python 打印与去除不可见字符 \x00

# 此处不是真实的\x00 被 空格替换了 text "boot_1__normal/ " print(text.strip()"boot_1__normal/") # 打印不可见字符 print(repr(text))>>> False boot_1__normal/\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0…

Prometheus+Grafana

一、Prometheus 获取配置文件 docker run -d -p 9090:9090 --name prometheus prom/prometheusmkdir -p /app/prometheusdocker cp prometheus:/etc/prometheus/prometheus.yml /app/prometheus/prometheus.yml停止并删除旧的容器,重新启动 docker run -d --name…

一带一路10周年:爱创科技加速中国药企国际化征程

“源自中国,属于世界”。 共建“一带一路”倡议提出10周年来,中国与沿线国家经济深度融合,在共商共建共享的基本原则下,“一带一路”形成了国际合作的平台和机制,跨国经济合作已基本形成。 随着“一带一路”合作日益加…

阿里云Apsara云栖大会2023

文章目录 2023/10/312023/11/012023/11/02彩蛋1:神州十六号彩蛋2:emm… 计算,为了无法计算的价值。 2023/10/31 合规性评审 2023/11/01 暂未开始 2023/11/02 暂未开始 彩蛋1:神州十六号 彩蛋2:emm…

Mybatis—基础操作

mybatis入门后,继续学习mybatis基础操作。 目录 Mybatis基础操作准备工作删除操作日志输入预编译SQLSQL注入参数占位符 新增操作基本新增添加后返回主键 更新操作查询操作根据id查询数据封装条件查询条件查询 Mybatis基础操作 准备工作 根据下面页面原型及需求&am…

【Linux】安装与配置虚拟机及虚拟机服务器坏境配置与连接

目录 操作系统介绍 什么是操作系统 常见操作系统 UNIX操作系统 linux操作系统 mac操作系统 嵌入式操作系统 个人版本和服务器版本的区别 安装VMWare虚拟机 VMWare虚拟网卡 ​编辑 配置虚拟网络编辑器 ​编辑 安装配置Windows Server 2012 R2 安装Windows Server 2…

轻量级gif制作工具 GIFfun中文 for mac

GIFfun是一款GIF制作工具,可以帮助用户从照片和视频中创建GIF动画。该软件具有多种功能,例如GIF转视频、视频转GIF、照片转GIF、照片转视频、GIF转JPG、调整GIF大小、PDF转GIF、PDF转JPG、裁剪视频、GIF编辑等。 GIFfun还提供了专业版功能,如…