论文代码学习—HiFi-GAN(3)——模型损失函数loss解析

文章目录

    • 引言
    • 正文
      • 生成器损失函数
        • 最小二乘损失函数
        • 梅尔频谱图损失函数
        • 特征匹配损失函数
        • 生成器最终损失函数loss
        • 生成器loss对应代码
      • 鉴定器损失函数
        • 鉴定器损失函数代码
    • 总结
    • 引用

引言

  • 这里翻译了HiFi-GAN这篇论文的具体内容,具体链接。
  • 这篇文章还是学到了很多东西,从整体上说,学到了生成对抗网络的构建思路,包括生成器和鉴定器。细化到具体实现的细节,如何 实现对于特定周期的数据处理?在细化,膨胀卷积是如何实现的?这些通过文章,仅仅是了解大概的实现原理,但是对于代码的实现细节并不是很了解。如果要加深印象,还是要结合代码来具体看一下实现的细节。
  • 本文主要围绕具体的代码实现细节展开,对于相关原理,只会简单引用和讲解。因为官方代码使用的是pytorch,所以是通过pytorch展开的。
  • 关于模型其他部分的介绍,链接如下
    • 论文代码学习(1)—HiFi-GAN——生成器generator代码
    • 论文代码学习—HiFi-GAN(2)——鉴别器discriminator代码

正文

  • 关于模型的损失函数,这里总共有两部分损失函数,分别是生成器损失函数和鉴定器损失函数。其中生成器的损失函数,有分为三部分,分别是常规的对抗生成损失、针对特征匹配的损失函数和针对梅尔频谱图的损失函数,后两者是作者自己的加上去的。

生成器损失函数

  • 对于生成器损失函数,作者分成了三个部分,分别是基本损失函数、针对特征匹配的损失函数以及梅尔损失函数。

最小二乘损失函数

  • 不同于一般的GAN网络使用交叉熵损失函数,这里使用的是最小二乘损失函数,借此来避免梯度丢失的现象。

  • 最小二乘损失函数

    • 用于衡量模型预测值和真实值的差异,具体特点如下
      • 平方项:通过平方差异,扩大误差,模型更加关注于难以拟合的样本
      • 连续可微:连续可微,可以有效找到最小值
      • 非负:损失函数的值始终非负

在这里插入图片描述

  • 生成器的损失函数的目的是为了使得生成的数据,经过鉴定器判定,和真的差不多。
  • 具体的公式如下
    • s s s是梅尔频谱图,输入的条件变量
    • x x x是真实数据
    • D ( x ) D(x) D(x)是鉴定器对于输入结果的评分,越逼真越接近1
    • G ( s ) G(s) G(s)是生成器根据梅尔频谱图生成的结果

在这里插入图片描述

  • 在上式子中,损失函数越小越好,生成器的效果越好,鉴定器,会将其分辨为1,做差,越靠近零,效果越好。

梅尔频谱图损失函数

  • 除了考虑基本的损失函数,这里还增加梅尔频谱图损失函数,用来提高训练效果和生成音频的分辨率,主要是抓住了梅尔频谱图对于感知能力的重视。
  • 定义
    • 计算合成的波形图和实际波形图的对应采样点的L1距离
  • 参数说明
    • ∅ \varnothing 表示将波形图转为mel频谱图
  • 效果:
    • 帮助生成器生成和输入相关的实际波形
    • 是的对抗训练阶段能够快速稳定下来

在这里插入图片描述

特征匹配损失函数

  • 特征匹配损失函数是用来衡量真实样本和生成样本在鉴定器上提取出来的特征的差异程度。不同于上一个mel频谱图的特征衡量,这里是直接衡量鉴定器生成的中间特征的差异程度。

  • 定义

    • 计算真实样本和生成样本分别在鉴定器上生成的中间特征的L1距离
  • 参数说明

    • T T T表示为鉴定器的层数
    • D i D^i Di N i N_i Ni分别表示第i层的特征值和特征的数量。
  • 效果

    • 从鉴定器特征角度使得生成器的样本更加逼真
      在这里插入图片描述
  • 注意

    • 这里并不是单单一个层的特征,是鉴定器上每一层的输出特征的L1距离累加和的平均值。

生成器最终损失函数loss

  • 生成器最终的损失函数,是上述三个损失函数之和,并且特征匹配损失函数和mel频谱图损失函数,加上对应的权重,具体如下
    • λ f m = 2 \lambda_{fm} = 2 λfm=2 λ m e l = 45 \lambda_{mel} = 45 λmel=45
      在这里插入图片描述

生成器loss对应代码

def feature_loss(fmap_r, fmap_g):# 特征损失函数# fmap_r是真实音频信号的特征图,fmap_g是生成音频信号的特征图loss = 0for dr, dg in zip(fmap_r, fmap_g):for rl, gl in zip(dr, dg):# 遍历每一层特征图,计算特征损失,做差,求绝对值,求均值loss += torch.mean(torch.abs(rl - gl))# 根据经验,特征损失函数的权重为10return loss*2def generator_loss(disc_outputs):# 生成器的损失函数# disc_outputs是鉴定器的输出loss = 0gen_losses = []for dg in disc_outputs:l = torch.mean((1-dg)**2)gen_losses.append(l)loss += l# loss是生成器的总损失,用于反向传播来更新生成器的参数# gen_losses是生成器的损失列表,用于记录鉴定器中每一个元素对应的损失,可以用于调试设备return loss, gen_losses
  • 结合代码来看,并没有将mel频谱图损失记录在内,这里仅仅包含了两个损失函数,generator_loss实现了最小二乘损失函数,feature_loss计算了鉴定器每一层的匹配的损失函数。

在这里插入图片描述

  • 她是把mel频谱图损失定义在训练过程中了.

鉴定器损失函数

  • 我们鉴定器的训练目标:
    • 能够将真实数据鉴定为真,标记为1
    • 能够将生成器生成的数据鉴定为假,标记为0
  • 所以,鉴定器的损失函数应该从两方面进行考虑,分别是鉴定生成数据和鉴定真实数据。
  • 具体的公式如下
    • s s s是梅尔频谱图,输入的条件变量
    • x x x是真实数据
    • D ( x ) D(x) D(x)是鉴定器对于输入结果的评分,越逼真越接近1
    • G ( s ) G(s) G(s)是生成器根据梅尔频谱图生成的结果

在这里插入图片描述

鉴定器损失函数代码

def discriminator_loss(disc_real_outputs, disc_generated_outputs):# 鉴定器的损失函数# disc_real_outputs是真实音频信号的鉴定器的输出# disc_generated_outputs是生成音频信号的鉴定器的输出loss = 0r_losses = []g_losses = []for dr, dg in zip(disc_real_outputs, disc_generated_outputs):# 计算真实音频信号的损失r_loss = torch.mean((1-dr)**2)# 计算生成音频信号的损失g_loss = torch.mean(dg**2)# 将两个损失相加loss += (r_loss + g_loss)# 记录各个鉴定器的损失r_losses.append(r_loss.item())g_losses.append(g_loss.item())return loss, r_losses, g_losses
  • 这个损失函数实现起来还是比较容易的,只需要分别计算两种数据的损失,然后累加求和即可

总结

  • 总的来说,这是第一次接触对抗生成学习,知道了对于鉴定器和生成器要分别定义,损失函数也是分别定义的。除此之外,他们的损失函数也是相互调用的。值得学习。
  • 下部分将讲述关于train文件具体内容,这个是模型的具体训练文件,定义了模型的前向传播和反向传播的过程。

引用

  • chatGPT-plus
  • HiFi-GAN demo
  • HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/54600.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rabbitmq的消息确认

配置文件 spring:rabbitmq:publisher-confirm-type: correlated #开启确认回调publisher-returns: true #开启返回回调listener:simple:acknowledge-mode: manual #设置手动接受消息消息从生产者到交换机 无论消息是否到交换机ConfirmCallback都会触发。 Resourceprivate Rabb…

Linux(环境变量)

Linux(环境变量) 常见环境变量查看环境变量方法和环境变量相关的指令环境变量的组织方式通过代码如何获取环境变量 环境变量(environment variables)一般是指在操作系统中用来指定操作系统运行环境的一些参数如:我们在编写C/C代码的时候&…

【Nginx基础】Nginx基础及安装

目录 Nginx出现背景Nginx 概念Nginx 作用Http 代理,反向代理负载均衡:内置策略和扩展策略内置策略:轮询内置策略:加权轮询内置策略:IP hash 动静分离 安装 NginxWindows下安装(nginx-1.16.1)Lin…

安装zabbix5.0监控

官网安装手册: https://www.zabbix.com/cn/download 一、 安装zabbix a. 安装yum源 rpm -Uvh https://repo.zabbix.com/zabbix/5.0/rhel/7/x86_64/zabbix-release-5.0-1.el7.noarch.rpmyum clean allb. 安装Zabbix server,web前端,agent y…

Kubernetes关于cpu资源分配的设计

kubernetes资源 在K8s中定义Pod中运行容器有两个维度的限制: 资源需求(Requests):即运行Pod的节点必须满足运行Pod的最基本需求才能运行Pod。如 Pod运行至少需要2G内存,1核CPU。(软限制)资源限额(Limits):即运行Pod期间,可能内存使用量会增加,那最多能使用多少内存,这…

MongoDB 使用总结

🍓 简介:java系列技术分享(👉持续更新中…🔥) 🍓 初衷:一起学习、一起进步、坚持不懈 🍓 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正🙏 🍓 希望这篇文章对你有所帮助,欢…

【技能实训】DMS数据挖掘项目(完整程序)

文章目录 1. 系统需求分析1.1 需求概述1.2 需求说明 2. 系统总体设计2.1 编写目的2.2 总体设计2.2.1 功能划分2.2.2 数据库及表2.2.3 主要业务流程 3. 详细设计与实现3.1 表设计3.2 数据库访问工具类设计3.3 配置文件3.4 实体类及设计3.5 业务类及设计3.6 异常处理3.7 界面设计…

docker 安装 字体文件

先说一下我当前的 场景 及 环境,这样同学们可以先评估本篇文章是否有帮助。 环境: dockerphp8.1-fpmwindows 之所以有 php,是因为这个功能是使用 php 开发的,其他语言的同学,如果也有使用到 字体文件,那么…

zookeeper集群和kafka的相关概念就部署

目录 一、Zookeeper概述 1、Zookeeper 定义 2、Zookeeper 工作机制 3、Zookeeper 特点 4、Zookeeper 数据结构 5、Zookeeper 应用场景 (1)统一命名服务 (2)统一配置管理 (3)统一集群管理 (4&a…

vue2-vue项目中你是如何解决跨域的?

1、跨域是什么? 跨域本质是浏览器基于同源策略的一种安全手段。 同源策略(sameoriginpolicy),是一种约定,它是浏览器最核心也是最基本的安全功能。 所谓同源(即指在同一个域)具有以下三个相同点…

mysql大表的深度分页慢sql案例(跳页分页)

1 背景 有一张表,内容是 redis缓存中的key信息,数据量约1000万级, expiry列上有一个普通B树索引。 -- test.top definitionCREATE TABLE top (database int(11) DEFAULT NULL,type varchar(50) DEFAULT NULL,key varchar(500) DEFAULT NUL…

【驱动开发day8作业】

作业1&#xff1a; 应用层代码 #include <stdlib.h> #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <unistd.h> #include <string.h> #include <sys/ioctl.h>int main(int…