最优传输学习及问题总结

文章目录

  • 参考内容
  • lam=0.1
  • lam=3
  • lam=10
  • lam=50
  • lam=100
  • lam=300
  • 画图
  • 线性规划
    • matlab
    • python代码

参考内容

https://blog.csdn.net/qq_41129489/article/details/128830589
https://zhuanlan.zhihu.com/p/542379144

我主要想强调的是这个例子的解法存在的一些细节问题

lam=0.1

lam = 0.1P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

结果如下
在这里插入图片描述

lam=3

lam = 3P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))
print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=10

lam = 10P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=50

lam = 50P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=100

lam = 100P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)
PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个很接近了

在这里插入图片描述

lam=300

lam = 300P, d = compute_optimal_transport(M,r,c, lam=lam)partition = pd.DataFrame(P, index=np.arange(1, 9), columns=np.arange(1, 6))
ax = partition.plot(kind='bar', stacked=True)
print('Sinkhorn distance: {}'.format(d))
ax.set_ylabel('portions')
ax.set_title('Optimal distribution ($\lambda={}$)'.format(lam))print(d)
PP = np.around(P,3) 
print(PP)print("*"*100)
print(np.sum(PP,axis=0))
print(np.sum(PP,axis=1))
print("*"*100)
## 这个就不接近了,之前的求和都是相差在0.001左右,可以近似看作相等
## 但是这个行和是 [2.    1.714 3.75  2.286 2.5   2.5   4.    1.25 ]
## 很明显是 [3. 3. 3. 4. 2. 2. 2. 1.]这个是不对的,所以lam=300时这个值已经发散了,
## 虽然此时的Sinkhorn distance是小于24的,但也不起作用

在这里插入图片描述

画图

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltdef compute_optimal_transport(M=None, r=None, c=None, lam=None, eplison=1e-8):"""Computes the optimal transport matrix and Slinkhorn distance using theSinkhorn-Knopp algorithmInputs:- M : cost matrix (n x m)- r : vector of marginals (n, )- c : vector of marginals (m, )- lam : strength of the entropic regularization- epsilon : convergence parameterOutputs:- P : optimal transport matrix (n x m)- dist : Sinkhorn distance"""r = np.array([3, 3, 3, 4, 2, 2, 2, 1])c = np.array([4, 2, 6, 4, 4])M = np.array([[2, 2, 1, 0, 0], [0, -2, -2, -2, -2], [1, 2, 2, 2, -1], [2, 1, 0, 1, -1],[0.5, 2, 2, 1, 0], [0, 1, 1, 1, -1], [-2, 2, 2, 1, 1], [2, 1, 2, 1, -1]],dtype=float) M = -M # 将M变号,从偏好转为代价n, m = M.shape  # 8, 5P = np.exp(-lam * M) # (8, 5)P /= P.sum()  # 归一化u = np.zeros(n) # (8, )# normalize this matrixwhile np.max(np.abs(u - P.sum(1))) > eplison: # 这里是用行和判断收敛# 对行和列进行缩放,使用到了numpy的广播机制,不了解广播机制的同学可以去百度一下u = P.sum(1) # 行和 (8, )P *= (r / u).reshape((-1, 1)) # 缩放行元素,使行和逼近rv = P.sum(0) # 列和 (5, )P *= (c / v).reshape((1, -1)) # 缩放列元素,使列和逼近creturn P, np.sum(P * M) # 返回分配矩阵和Sinkhorn距离lam_list=[1,5,10,20,30,40,50,60,70,80,90,100,110,120,130,140,150]cost_list=[]
for lam in lam_list:P, d = compute_optimal_transport(lam=lam)cost_list.append(d)
print(cost_list)
plt.plot(np.array(lam_list),np.array(cost_list),c="g")
plt.show()## 现在这个地方也有的

在这里插入图片描述
这个地方其实有一个画图的小问题,我待会要再写一下

可以看到大概是在lam =150的时候,就已经不稳定了,所以这个例子的问题的解的最小花费约等于24,但是我发现一个更有意思的问题,就是这个分配矩阵是唯一的吗,很显然不是的, 利用我上篇文章学到的线性规划,我发现matlab和python找到的是两个不同的解,

线性规划

matlab

clc;
clear;r = [3, 3, 3, 4, 2, 2, 2, 1];
c = [4, 2, 6, 4, 4];
cost_matrix =  [2, 2, 1, 0, 0;0, -2, -2, -2, -2; 1, 2, 2, 2, -1;2, 1, 0, 1, -1;0.5, 2, 2, 1, 0;0, 1, 1, 1, -1;-2, 2, 2, 1, 1;2, 1, 2, 1, -1];cost_matrix_t = (-1)*transpose(cost_matrix);% 需要有符号
cost_vec = cost_matrix_t(:);raw_equ = zeros(8,40);
for i =1:8raw_equ(i,((i-1)*5+1):((i-1)*5+5))=1;
endcol_equ = zeros(5,40);
for i =1:5for j =1:8col_equ(i,i+(j-1)*5)=1;end
endequ = [raw_equ;col_equ];
equ_value = horzcat(r, c);
% x1,x2,x3,x4,x5
% x6,x7,x8,x9,x10
% x11,x12,x13,x14,x15
% x16,x17,x18,x19,x20
% x21,x22,x23,x24,x25
% x26,x27,x28,x29,x30
% x31,x32,x33,x34,x35
% x36,x37,x38,x39,x40% 现在我要求的变量是这样的,
f=cost_vec;			% 价值向量
a=[];	% a、b对应不等式的左边和右边
b=[];
aeq=equ;	% aeq和beq对应等式的左边和右边
beq=equ_value;
[x,y]=linprog(f,a,b,aeq,beq,zeros(40,1));arr_mat = transpose(reshape(x',5,8));

结果如下
在这里插入图片描述
分配矩阵如下在这里插入图片描述

python代码

# Define parameters
m = 8
n = 5p = np.array([3, 3, 3, 4, 2, 2, 2, 1])
q = np.array([4, 2, 6, 4, 4])C = -1*np.array([[2, 2, 1, 0, 0], [0, -2, -2, -2, -2], [1, 2, 2, 2, -1], [2, 1, 0, 1, -1],[0.5, 2, 2, 1, 0], [0, 1, 1, 1, -1], [-2, 2, 2, 1, 1], [2, 1, 2, 1, -1]],dtype=float)# Vectorize matrix C
C_vec = C.reshape((m*n, 1), order='F')# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))
A2 = np.kron(np.identity(n), np.ones((1, m)))
A = np.vstack([A1, A2])# Construct vector b
b = np.hstack([p, q])# Solve the primal problem
res = linprog(C_vec, A_eq=A, b_eq=b)# Print results
print("message:", res.message)
print("nit:", res.nit)
print("fun:", res.fun)
print("z:", res.x)
print("X:", res.x.reshape((m,n), order='F'))

结果如下
在这里插入图片描述
可以看到花费都是24,但是两者的分配矩阵并不一样哈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/419029.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

EasyDarwin计划新增将各种流协议(RTSP、RTMP、HTTP、TCP、UDP)、文件转推RTMP到其他视频直播平台,支持转码H.264、文件直播推送

之前我们尝试做过EasyRTSPLive(将RTSP流转推RTMP)和EasyRTMPLive(将各种RTSP/RTMP/HTTP/UDP流转推RTMP,这两个服务在市场上都得到了比较多的好评,其中: 1、EasyRTSPLive用的是EasyRTSPClient取流&#xff…

MySQL缓冲池(Buffer Pool)深入解析:原理、组成及其在数据操作中的核心作用

在关系型数据库管理系统(RDBMS)中,性能优化一直是数据库管理员和开发者关注的焦点。作为最流行的开源RDBMS之一,MySQL提供了多种优化手段,其中InnoDB存储引擎的缓冲池(Buffer Pool)是最为关键的…

ctfshow-反序列化(web271-web276)

目录 web271 web272-273 web274 web275 web276 为什么不用分析具体为什么能成功 ,后面会有几个专题 会对php框架进行更深入的了解 这里面会专门的研究 为什么能够实现RCE 前面作为初步的熟悉 首先知道一下他的框架 知道框架的风格 知道啥版本可以用什么来打 首先先不用太研…

Spring+SprinMVC+MyBatis注解方式简易模板

SpringSprinMVCMyBatis注解方式简易模板代码Demo GitHub访问 ssm-tpl-anno 一、数据准备 创建数据库test,执行下方SQL创建表ssm-tpl-cfg /*Navicat Premium Data TransferSource Server : 127.0.0.1Source Server Type : MySQLSource Server Version :…

2008年苏州大学837复试机试C语言

2008年苏州大学复试机试C 题目 编写程序充成以下功能: 一、从键盘上输入随机变量x的 10个取样点。X0,X1—X9 的值; 1、计算样本平均值 2、判定x是否为等差数列 3、用以下公式计算z的值(t0.63) 注。请对程序中必要地方进行注释 补充:个人觉得这个题目回…

canvas绘制正三边形,正四边形,正五边形...正N边形

查看专栏目录 canvas实例应用100专栏,提供canvas的基础知识,高级动画,相关应用扩展等信息。canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重…

使 a === 1 a === 2 a === 3 为 true 的几种“下毒“方法

前言 这算得上是近些年的前端网红题了,曾经对这种网红题非常抵触,认为非常没有意义。 看到了不少人有做分享,有各种各样的方案,有涉及到 JS 非常基础的知识点,也不得不感叹解题者的脑洞之大。 但是,拿来…

代码随想录二刷 | 二叉树 | 修剪二叉搜索树

代码随想录二刷 | 二叉树 | 修剪二叉搜索树 题目描述解题思路代码实现 题目描述 669.修剪二叉搜索树 给定一个二叉搜索树,同时给定最小边界 L 和最大边界 R。通过修剪二叉搜索树,使得所有节点的值在[L, R]中 (R>L) 。你可能需要改变树的根节点&…

立体视觉几何(一)

1.什么是立体视觉几何 立体视觉对应重建: • 对应:给定一幅图像中的点pl,找到另一幅图像中的对应点pr。 • 重建:给定对应关系(pl, pr),计算空间中相应点的3D 坐标P。 立体视觉:从图像中的投影恢复场景中点…

一台手机用4年多,国产手机从态度傲慢到跪求消费者换机

分析机构trendforce公布的数据指出,中国消费者的换机周期已延长到51个月,面对消费者对国产手机用脚投票,如今国产手机企业开始采取多方举措,祈求消费者买手机,市场的变化促使国产手机不得不改变态度。 2010年国产手机刚…

VM使用教程--SDK取图 视频笔记

本笔记均由海康机器人官网的V学院视频中记录所得,属于省流大师了[doge] 图像采集 图像采集包括1图像源,2多图采集,3输出图像,4缓存图像,5光源 1图像源 图像源包括本地图像,相机采图,SDK 本…

二维旋转公式推导+旋转椭圆的公式推导

二维旋转公式推导+旋转椭圆的公式推导 二维旋转公式推导旋转椭圆的公式推导二维旋转公式推导 x , y x,y x,y表示二维坐标系中原坐标点, x ′ , y ′ x,y x′,y′表示逆时针旋转 β \beta β之后的坐标点: x ′ = x cos ⁡ ( β ) − y sin ⁡ ( β ) y ′ = y cos ⁡ ( β )…