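KL Divergence
KL divergence measures how much one probability distribution differs from another. It is not a true distance: it is always non-negative and equals 0 only when the two distributions coincide, but in general KL(P || Q) ≠ KL(Q || P).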
$$KL(P\,\|\,Q)=\int_{-\infty}^{+\infty}p(x)\log\frac{p(x)}{q(x)}\,dx\tag{1}$$
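where P and Q are two probability distributions of the random variable X, and p and q are their probability density functions.
If P and Q are both Gaussian distributions, i.e.: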
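When both densities are known, the integral in (1) can be approximated numerically. This is a minimal sketch under stated assumptions: the helpers `gaussian_pdf` and `kl_numeric` are hypothetical names introduced only for illustration, the infinite integration range is truncated to [-20, 20], and the trapezoidal rule stands in for the exact integral. The example pair P = N(1, 1), Q = N(0, 2²) is arbitrary.

```python
import numpy as np


def gaussian_pdf(x: np.ndarray, mu: float, sigma: float) -> np.ndarray:
    """Density of N(mu, sigma^2) evaluated at every point of x."""
    return np.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)


def kl_numeric(p: np.ndarray, q: np.ndarray, x: np.ndarray) -> float:
    """Approximate KL(P || Q) = ∫ p(x) log(p(x) / q(x)) dx with the trapezoidal rule.

    Assumption: p and q are densities sampled on the same grid x, and q > 0
    wherever p > 0. Points with p == 0 contribute 0 (the 0 * log 0 = 0 convention).
    """
    integrand = np.zeros_like(p)
    mask = p > 0
    integrand[mask] = p[mask] * np.log(p[mask] / q[mask])
    # Trapezoidal rule: sum of (f_i + f_{i+1}) / 2 * (x_{i+1} - x_i)
    return float(np.sum((integrand[1:] + integrand[:-1]) / 2 * np.diff(x)))


# Truncate the infinite range to [-20, 20]; both densities are negligible outside it
x = np.linspace(-20, 20, 40001)
p = gaussian_pdf(x, mu=1.0, sigma=1.0)
q = gaussian_pdf(x, mu=0.0, sigma=2.0)
print(f"{kl_numeric(p, q, x):.4f}")  # ≈ 0.4431
```

For this pair the closed-form Gaussian result gives log 2 + 2/8 − 1/2 ≈ 0.4431, so the grid approximation can be checked against it.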
$$P=\mathcal{N}(\mu_1,\sigma_1^2),\qquad Q=\mathcal{N}(\mu_2,\sigma_2^2)\tag{2}$$
then (1) simplifies to:
$$KL\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}$$
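The simplification follows from substituting the two Gaussian densities into (1): the normalizing constants produce the $\log\frac{\sigma_2}{\sigma_1}$ term, and the quadratic terms are replaced by their expectations under P. A short derivation sketch:

$$
\begin{aligned}
\log\frac{p(x)}{q(x)} &= \log\frac{\sigma_2}{\sigma_1}-\frac{(x-\mu_1)^2}{2\sigma_1^2}+\frac{(x-\mu_2)^2}{2\sigma_2^2}\\
\mathbb{E}_{P}\left[(x-\mu_1)^2\right] &= \sigma_1^2,\qquad \mathbb{E}_{P}\left[(x-\mu_2)^2\right]=\sigma_1^2+(\mu_1-\mu_2)^2\\
KL(P\,\|\,Q) &= \mathbb{E}_{P}\left[\log\frac{p(x)}{q(x)}\right]=\log\frac{\sigma_2}{\sigma_1}-\frac{1}{2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}
\end{aligned}
$$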
Below, matplotlib is used to animate how the KL divergence between two Gaussian distributions changes:
Red line (P):
- $\mu$ varies from -10 to 10
- $\sigma$ is fixed at 1
Blue line (Q, the standard normal distribution):
- $\mu$ is fixed at 0
- $\sigma$ is fixed at 1
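In this setting ($\sigma_1=\sigma_2=1$, $\mu_2=0$) the log term vanishes and the closed form reduces to $\mu_1^2/2$, which is the quantity the animation code computes as `(1 + mean**2) / 2 - 0.5`. The sketch below checks this with a hypothetical helper `kl_gaussian` that simply transcribes the closed-form expression; it also illustrates that KL divergence stops being symmetric once the variances differ.

```python
import math


def kl_gaussian(mu1: float, sigma1: float, mu2: float, sigma2: float) -> float:
    """Closed-form KL(N(mu1, sigma1^2) || N(mu2, sigma2^2))."""
    return (
        math.log(sigma2 / sigma1)
        + (sigma1**2 + (mu1 - mu2) ** 2) / (2 * sigma2**2)
        - 0.5
    )


# Animation setting: red N(mu, 1) against blue N(0, 1); KL should equal mu^2 / 2
for mu in (-10.0, -2.0, 0.0, 3.0):
    print(mu, kl_gaussian(mu, 1.0, 0.0, 1.0), mu**2 / 2)

# With different variances the two directions disagree
print(kl_gaussian(1.0, 1.0, 0.0, 2.0))  # KL(N(1, 1) || N(0, 4)) ≈ 0.4431
print(kl_gaussian(0.0, 2.0, 1.0, 1.0))  # KL(N(0, 4) || N(1, 1)) ≈ 1.3069
```

With equal variances the formula is symmetric in $\mu_1$ and $\mu_2$, so the animated curve would look the same whichever line is treated as P.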
The code is as follows:
```python
import functools
from typing import List

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch


# Update function: moves the mean of the red normal distribution in each frame
def update(
    mean: float,
    pdf_line: Line2D,
    kl_div_line: Line2D,
    connection_line: ConnectionPatch,
    points: List[Line2D],
    x_min: float,
    x_max: float,
    step: float,
):
    # PDF of N(mean, 1) over the whole x range
    x1 = np.arange(x_min, x_max + step, step)
    y1 = 1 / np.sqrt(2 * np.pi) * np.exp(-((x1 - mean) ** 2) / 2)
    # KL(N(mu, 1) || N(0, 1)) for mu from x_min up to the current mean;
    # with sigma1 = sigma2 = 1 and mu2 = 0 the closed form is (1 + mu^2) / 2 - 1/2
    x2 = np.arange(x_min, mean + step, step)
    y2 = (1 + x2**2) / 2 - 0.5

    # Update the line data
    pdf_line.set_data(x1, y1)
    kl_div_line.set_data(x2, y2)

    # Move the dotted connector: from the peak of the red PDF to the current KL value
    connection_line.xy1 = (mean, 1 / np.sqrt(2 * np.pi))
    connection_line.xy2 = (mean, (1 + mean**2) / 2 - 0.5)
    points[0].set_data([connection_line.xy1[0]], [connection_line.xy1[1]])
    points[1].set_data([connection_line.xy2[0]], [connection_line.xy2[1]])


if __name__ == "__main__":
    # Create the figure and axes
    fig, ax = plt.subplots(2, 1, sharex=True, figsize=(8, 6.6))
    fig.suptitle("Animation of the KL divergence between two Gaussian distributions", x=0.50, y=0.92)

    x_min, x_max, step = -10, 10, 0.10
    x_ords = np.linspace(x_min, x_max, 200, endpoint=False)

    ax[0].set_xlim(x_min, x_max)
    ax[0].set_ylim(-0.2, 0.6)
    ax[0].set_ylabel("Probability Density")
    ax[1].set_xlim(x_min, x_max)
    ax[1].set_ylim(-5, 100)
    ax[1].set_xlabel("Mean")
    ax[1].set_ylabel("KL Divergence")

    # Plot the standard normal distribution N(0, 1) (blue line)
    ax[0].plot(
        x_ords,
        1 / np.sqrt(2 * np.pi) * np.exp(-(x_ords**2) / 2),
        label="mean=0 & std=1",
    )

    # Lines that are updated during the animation
    (pdf_line,) = ax[0].plot([], [], color="red", label="mean=[-10, 10] & std=1")
    (kl_div_line,) = ax[1].plot([], [], color="purple")
    (point1,) = ax[0].plot([-10], [0], color="cyan", marker="o")
    (point2,) = ax[1].plot([-10], [0], color="cyan", marker="o")

    ax[0].legend()
    ax[0].grid()
    ax[1].grid()

    # Dotted line connecting the two subplots at the current mean
    connection = ConnectionPatch(
        [-10, 0], [-10, 0], "data", "data",
        axesA=ax[0], axesB=ax[1],
        ls="dotted", lw=2, color="pink",
    )
    fig.add_artist(connection)

    # Create the animation
    animation = FuncAnimation(
        fig,
        func=functools.partial(
            update,
            pdf_line=pdf_line,
            kl_div_line=kl_div_line,
            connection_line=connection,
            points=[point1, point2],
            x_min=x_min,
            x_max=x_max,
            step=step,
        ),
        frames=np.arange(x_min, x_max, step),
        interval=50,
    )

    # Save before show(): once the window is closed the figure may no longer be usable for saving
    animation.save("KL.gif", writer="pillow")
    plt.show()
```