Fisher矩阵与自然梯度法

文章目录

    • Fisher矩阵及自然梯度法
      • Fisher矩阵
      • 自然梯度法
      • 总结
      • 参考链接

Fisher矩阵及自然梯度法

自然梯度法相比传统的梯度下降法具有以下优势:

更好的适应性:自然梯度法通过引入黎曼流形上的梯度概念,能够更好地适应参数空间的几何结构。这使得自然梯度法在求解具有复杂几何结构的参数优化问题时具有更高的效率。
更高的收敛速度:由于自然梯度法考虑了参数空间的几何结构,因此它在参数更新过程中能够更准确地找到下降方向。这有助于加快算法的收敛速度,提高优化效率。
避免局部最优解:自然梯度法通过调整参数空间的几何结构,有助于避免陷入局部最优解。这使得自然梯度法在求解全局优化问题时具有更好的性能。

Fisher矩阵

我们使用迭代法求解问题时,计算每个参数相对于损失函数的导数,即雅可比矩阵。这些导数表示可以更新参数,以获得损失函数最大变化的方向,称为梯度。

得到梯度后,我们在梯度方向的负方向上,更新参数,从而减少损失函数。但是由于参数所带的噪声的不同,不同的参数应该据此调整步长。

在这里插入图片描述
在这里插入图片描述
如上两个图,同样的期望差异,但是明显参数的差异对符合第二图中的分布的函数影响更大。

因此,定义Fisher矩阵用来衡量参数空间的曲率,即参数对函数模型的敏感性。

定义一个最大似然问题: max ⁡ p ( x ∣ θ ) \max p(x|\theta) maxp(xθ)

再设立一个得分函数 s ( θ ) s(\theta) s(θ),定义为似然函数的对数梯度,
s ( θ ) = ∇ θ log ⁡ p ( x ∣ θ ) s(\theta)=\nabla_\theta\log p(x\mid\theta) s(θ)=θlogp(xθ)
用来评估我们最大化似然函数的好坏以及参数对似然函数的敏感度。其期望值为0, E p ( x ∣ θ ) [ s ( θ ) ] = 0 \mathbb{E}_{p(x|\theta)}[s(\theta)]=0 Ep(xθ)[s(θ)]=0,证明过程见参考文献。

Fisher矩阵定义为函数 s ( θ ) s(\theta) s(θ)的方差: F ( θ ) = E p ( x ∣ θ ) [ ( s ( θ ) − 0 ) ( s ( θ ) − 0 ) T ] = E p ( x ∣ θ ) [ ∇ θ log ⁡ p ( x ∣ θ ) ∇ θ log ⁡ p ( x ∣ θ ) T ] F(\theta) =\mathbb{E}_{p(x|\theta)}\left[(s(\theta)-0)(s(\theta)-0)^\mathrm{T}\right] =\mathbb{E}_{p(x|\theta)}\left[\nabla_\theta\log p(x\mid\theta)\nabla_\theta\log p(x\mid\theta)^\mathrm{T}\right] F(θ)=Ep(xθ)[(s(θ)0)(s(θ)0)T]=Ep(xθ)[θlogp(xθ)θlogp(xθ)T]

= E [ ( ∂ log ⁡ p ( x , θ ) ∂ θ ) ( ∂ log ⁡ p ( x , θ ) ∂ θ ) T ] = E\left[\left(\frac{\partial \log p(x, \theta)}{\partial \theta}\right)\left(\frac{\partial \log p(x, \theta)}{\partial \theta}\right)^{\rm T} \right] =E[(θlogp(x,θ))(θlogp(x,θ))T]

其中 E E E表示期望操作。Fisher矩阵可以帮助我们理解参数对模型的影响,以及在优化过程中如何调整参数以更有效地学习函数。

自然梯度法

同样使用传统的Euclidean 距离来衡量参数的差异时,参数的噪声分布并未被考虑。使用KL散度用来衡量,然而,同样观察文中的两图,如果我们只在参数空间中工作,我们就不能考虑关于参数实现的分布的这些信息。但在分布空间中,即当我们考虑高斯的形状时,第一和第二图像中的距离是不同的。在第一幅图像中,KL散度应该更低,因为这些高斯之间有更多的重叠。因而,我们使用KL散度来衡量参数的正确性:
D K L ( p ∥ q ) = ∑ i = 1 n p ( x i ) log ⁡ ( p ( x i ) q ( x i ) ) D_{KL}\left(p\|q\right)=\sum_{i=1}^{n}p\left(x_{i}\right)\log\left(\frac{p(x_{i})}{q(x_{i})}\right) DKL(pq)=i=1np(xi)log(q(xi)p(xi))

θ \theta θ为参数期望, θ ′ \theta' θ为参数估计,则定义其损失函数(KL散度)如下:
L ( θ ) = K L [ p ( x ∣ θ ) ∥ p ( x ∣ θ ′ ) ] = E ⁡ p ( x ∣ θ ) [ log ⁡ p ( x ∣ θ ) ] − E ⁡ p ( x ∣ θ ) [ log ⁡ p ( x ∣ θ ′ ) ] \mathcal{L}(\theta)=\mathrm{KL}\left[p(x\mid\theta)\|p\left(x\mid\theta^{\prime}\right)\right]=\underset{p(x\mid\theta)}{\operatorname*{\mathbb{E}}}\left[\log p(x\mid\theta)\right]-\underset{p(x\mid\theta)}{\operatorname*{\mathbb{E}}}\left[\log p\left(x\mid\theta^{\prime}\right)\right] L(θ)=KL[p(xθ)p(xθ)]=p(xθ)E[logp(xθ)]p(xθ)E[logp(xθ)]

Fisher信息矩阵 F F F等同于两个分布 p ( x ∣ θ ) p(x|\theta) p(xθ p ( x ∣ θ ′ ) p(x|\theta') p(xθ)之间关于 θ ′ \theta' θ的KL散度的Hessian矩阵。 证明过程见参考文献

算法伪代码:

自然梯度下降循环:
  对我们的模型进行正向传递,并计算损失 L ( θ ) \mathcal{L}(\theta) L(θ)
  计算梯度 ∇ θ L ( θ ) \nabla_\theta \mathcal{L}(\theta) θL(θ)
  计算Fisher信息矩阵 F F F或其经验版本。
  计算自然梯度 ∇ ~ θ L ( θ ) = F − 1 ∇ θ L ( θ ) \tilde{\nabla}_\theta\mathcal{L}(\theta)=\mathrm{F}^{-1}\nabla_\theta\mathcal{L}(\theta) ~θL(θ)=F1θL(θ)
  更新参数: θ = θ − α ∇ ~ θ L ( θ ) \theta = \theta - \alpha\tilde{\nabla}_\theta\mathcal{L}(\theta) θ=θα~θL(θ),其中 α \alpha α是学习率。
  直到收敛。

总结

Fisher矩阵和自然梯度法是机器学习中重要的概念和方法,用于优化问题的求解。Fisher矩阵可以帮助我们理解参数空间的曲率,而自然梯度法则利用Fisher矩阵的信息来更好地学习函数。通过结合这两个概念,我们可以更有效地优化模型参数,并提高学习的效率和性能。

参考链接

Fisher矩阵 https://agustinus.kristia.de/techblog/2018/03/11/fisher-information/
自然梯度法 https://agustinus.kristia.de/techblog/2018/03/14/natural-gradient/
自然梯度法 https://kvfrans.com/what-is-the-natural-gradient-and-where-does-it-appear-in-trust-region-policy-optimization/
https://zhuanlan.zhihu.com/p/546885304

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/511139.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何本地安装gemma

目录 通过ollama开源软件来一键安装目前主流的大模型,支持的开源模型包括以下内容: https://github.com/ollama/ollama

基于pytorch的手写体识别

一、环境搭建 链接: python与深度学习——基础环境搭建 二、数据集准备 本次实验用的是MINIST数据集,利用MINIST数据集进行卷积神经网络的学习,就类似于学习单片机的点灯实验,学习一门机器语言输出hello world。MINIST数据集,可以…

【树】【异或】【深度优先】【DFS时间戳】2322. 从树中删除边的最小分数

作者推荐 【二分查找】【C算法】378. 有序矩阵中第 K 小的元素 涉及知识点 树 异或 DFS时间戳 LeetCode2322. 从树中删除边的最小分数 存在一棵无向连通树,树中有编号从 0 到 n - 1 的 n 个节点, 以及 n - 1 条边。 给你一个下标从 0 开始的整数数组…

京东商品优惠券API获取商品到手价

item_get_app-获得JD商品详情原数据 公共参数 请求地址: jd/item_get_app 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,i…

(十五)【Jmeter】取样器(Sampler)之HTTP请求

简述 操作路径如下: HTTP请求 (HTTP Sampler): 作用:模拟发送HTTP请求并获取响应。配置:设置URL、请求方法、请求参数等参数。使用场景:测试Web应用程序的HTTP接口性能。优点:支持多种HTTP方法和请求参数,适用于大多数Web应用程序测试。缺点:功能较为基础,对于复杂…

鸿蒙实战应用开发:【拨打电话】功能

概述 本示例通过输入电话,进行电话拨打,及电话相关信息的显示。 样例展示 涉及OpenHarmony技术特性 网络通信 基础信息 拨打电话 介绍 本示例使用call相关接口实现了拨打电话并显示电话相关信息的功能 效果预览 使用说明 1.输入电话号码后&#…

11. Nginx进阶-HTTPS

简介 基本概述 SSL SSL是安全套接层。 主要用于认证用户和服务器,确保数据发送到正确的客户机和服务器上。 SSL可以加密数据,防止数据中途被窃取。 SSL也可以维护数据的完整性,确保数据在传输过程中不被改变。 HTTPS HTTPS就是基于SSL来…

#QT(串口助手-界面)

1.IDE:QTCreator 2.实验:编写串口助手 3.记录 接收框:Plain Text Edit 属性选择:Combo Box 发送框:Line Edit 广告:Group Box (1)仿照现有串口助手设计UI界面 (2)此时串口助手大…

C#插入排序算法

插入排序实现原理 插入排序算法是一种简单、直观的排序算法,其原理是将一个待排序的元素逐个地插入到已经排好序的部分中。 具体实现步骤如下 首先咱们假设数组长度为n,从第二个元素开始,将当前元素存储在临时变量temp中。 从当前元素的前一…

Windows环境MySQL全量备份+增量备份

目录 一、环境准备 1.1.安装MySQL 1.2.添加log-bin日志配置 二、创建测试数据库和表 2.1.创建测试数据库 2.2.创建测试数据表 三、全量备份恢复数据库 3.1.全量备份数据库 3.2全量恢复数据库 四、增量备份恢复数据库 4.1.增量备份数据库 4.2.增量恢复数据库 五、…

避坑——Matlab c# 联合编程——Native

相同的库,Matlab生成供.net调用的库时会有两套,也就是Native(本地),两套库各有优缺点,这这里就不说了,可以翻看网上其他博文 主要是MWStructArray,MWArray等数据交换对象有两套&…

科技云报道:阿里云降价,京东云跟进,谁能打赢云计算价格战?

科技云报道原创。 就在大家还在回味2月29日阿里云发布“史上最大降价”的惊喜时,京东云连夜发布降价消息,成为第一家跟进的云服务商,其“随便降,比到底!”的口号,颇有对垒的意味,直接吹响了云计…