数据聚类:Mean-Shift和EM算法


目录

  • 1. 高斯混合分布
  • 2. Mean-Shift算法
  • 3. EM算法
  • 4. 数据聚类
  • 5. 源码地址


1. 高斯混合分布

在高斯混合分布中,我们假设数据是由多个高斯分布组合而成的。每个高斯分布被称为一个“成分”(component),这些成分通过加权和的方式来构成整个混合分布。

高斯混合分布的公式可以表示为:

p ( x ) = ∑ i = 1 K π i N ( x ∣ μ i , Σ i ) p(x) = \sum^K_{i=1} \pi_i N(x|\mu_i, \Sigma_i) p(x)=i=1KπiN(xμi,Σi)

其中:

  • p ( x ) p(x) p(x)是观测数据点 x x x的概率密度函数,
  • K K K是高斯分布的数量,
  • π i \pi_i πi是第 i i i个高斯分布的混合系数,满足 ∑ i = 1 K π i = 1 \sum^K_{i=1} \pi_i = 1 i=1Kπi=1,
  • μ i \mu_i μi是第 i i i个高斯分布的均值向量,
  • Σ i \Sigma_i Σi是第 i i i个高斯分布的协方差矩阵。

为了简单呈现结果,我们选取 K = 2 K=2 K=2个高斯分布。下图为一个高斯混合分布的采样散点图,其中两个高斯分布的 μ i \mu_i μi分别为 [ 0 , 0 ] , [ 5 , 5 ] [0,0], [5,5] [0,0],[5,5],协方差矩阵均为:
[ 1 0 0 1 ] \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} [1001]

在这里插入图片描述

Fig. 1. 高斯混合分布的采样散点图

2. Mean-Shift算法

Mean-Shift是一种非参数化的密度估计和聚类算法,用于将数据点组织成具有相似特征的群集。它是一种迭代算法,通过计算数据点的梯度信息来寻找数据点在特征空间中的密度极值点,从而确定聚类中心。

算法的核心思想是通过不断地更新数据点的位置,将它们移向密度估计函数梯度的最大方向,直到达到收敛条件。具体来说,Mean-Shift算法包括以下步骤:

  • 初始化:选择一个数据点作为初始聚类中心,或者随机选择一个点作为初始中心。
  • 确定梯度向量:对于每个数据点,计算其与其他数据点之间的距离,并根据一定的核函数(如高斯核)计算梯度向量。梯度向量的方向指向密度估计函数增加最快的方向。
  • 移动数据点:将每个数据点移动到梯度向量的方向上,即向密度估计函数增加最快的方向移动一定的步长。
  • 更新聚类中心:对于移动后的每个数据点,重新计算它们周围数据点的梯度向量,并更新它们的位置。重复这个过程直到达到收敛条件,比如聚类中心的移动距离小于某个阈值。
  • 形成聚类:最终,根据收敛后的聚类中心,将数据点分配到最近的聚类中心,形成最终的聚类结果。

Mean-Shift算法的优点是不需要事先指定聚类的个数,且能够自适应地调整聚类中心的数量和形状。它在处理非线性和非凸形状的数据集时表现出良好的聚类效果。然而,该算法对于大规模数据集的计算复杂度较高,且对初始聚类中心的选择敏感。Mean-Shift算法的具体实现见代码片:

class MeanShift:def __init__(self, bandwidth=1.0, max_iterations=100):self.min_shift = 1self.n_clusters_ = Noneself.cluster_centers_ = Noneself.labels_ = Noneself.bandwidth = bandwidthself.max_iterations = max_iterationsdef euclidean_distance(self, x1, x2):return np.sqrt(np.sum((x1 - x2) ** 2))def gaussian_kernel(self, distance, bandwidth):return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth) ** 2))def shift_point(self, point, data, bandwidth):shift_x = 0.0shift_y = 0.0total_weight = 0.0for i in range(len(data)):distance = self.euclidean_distance(point, data[i])weight = self.gaussian_kernel(distance, bandwidth)shift_x += data[i][0] * weightshift_y += data[i][1] * weighttotal_weight += weightshift_x /= total_weightshift_y /= total_weightreturn np.array([shift_x, shift_y])def fit(self, data):centroids = np.copy(data)for _ in range(self.max_iterations):shifts = np.zeros_like(centroids)for i, centroid in enumerate(centroids):distances = cdist([centroid], data)[0]weights = self.gaussian_kernel(distances, self.bandwidth)shift = np.sum(weights[:, np.newaxis] * data, axis=0) / np.sum(weights)shifts[i] = shiftshift_distances = cdist(shifts, centroids)centroids = shiftsif np.max(shift_distances) < self.min_shift:breakunique_centroids = np.unique(np.around(centroids, 3), axis=0)self.cluster_centers_ = unique_centroidsself.labels_ = np.argmin(cdist(data, unique_centroids), axis=1)self.n_clusters_ = len(unique_centroids)

3. EM算法

EM算法是一种迭代的数值优化算法,用于求解包含隐变量的概率模型参数的最大似然估计。它在统计学和机器学习领域被广泛应用,尤其在聚类问题中有着重要的应用。其基于观测数据和隐变量之间的概率模型,通过交替进行两个步骤:E步骤(Expectation Step)和M步骤(Maximization Step)来迭代地优化模型参数。下面是EM算法的基本步骤:

  • 初始化:选择一组初始参数来开始迭代过程。
  • E步骤:根据当前的参数估计,计算隐变量的后验概率,即给定观测数据下隐变量的条件概率分布。
  • M步骤:使用在E步骤中计算得到的后验概率,对参数进行更新,以最大化对数似然函数。
  • 重复步骤2-3至收敛:重复执行E步骤和M步骤,直到参数的变化很小或满足收敛条件。

在聚类问题中,EM算法可以用于估计混合高斯模型的参数,从而实现数据的聚类。EM算法在聚类中的应用优点是能够处理具有隐变量的概率模型,适用于灵活的聚类问题。然而,EM算法对于初始参数的选择敏感,可能会陷入局部最优解,并且在处理大规模数据集时可能会面临计算复杂度的挑战。EM算法(包含正则化步骤)的具体实现见代码片:

class RegularizedEMClustering:def __init__(self, n_clusters, max_iterations=100, epsilon=1e-4, regularization=1e-6):self.labels_ = Noneself.X = Noneself.n_features = Noneself.n_samples = Noneself.cluster_probs_ = Noneself.cluster_centers_ = Noneself.n_clusters = n_clustersself.max_iterations = max_iterationsself.epsilon = epsilonself.regularization = regularizationdef fit(self, X):self.X = Xself.n_samples, self.n_features = X.shapeself.cluster_centers_ = X[np.random.choice(self.n_samples, self.n_clusters, replace=False)]self.cluster_probs_ = np.ones((self.n_samples, self.n_clusters)) / self.n_clusters# EMfor iteration in range(self.max_iterations):# E-stepprev_cluster_probs = self.cluster_probs_self._update_cluster_probs()# M-stepself._update_cluster_centers()delta = np.linalg.norm(self.cluster_probs_ - prev_cluster_probs)if delta < self.epsilon:breakself.labels_ = np.argmax(self.cluster_probs_, axis=1)def _update_cluster_probs(self):distances = np.linalg.norm(self.X[:, np.newaxis, :] - self.cluster_centers_, axis=2)# Calculate the cluster probabilities with regularizationnumerator = np.exp(-distances) + self.regularizationdenominator = np.sum(numerator, axis=1, keepdims=True)self.cluster_probs_ = numerator / denominatordef _update_cluster_centers(self):self.cluster_centers_ = np.zeros((self.n_clusters, self.n_features))for k in range(self.n_clusters):self.cluster_centers_[k] = np.average(self.X, axis=0, weights=self.cluster_probs_[:, k])def predict(self, X):distances = np.linalg.norm(X[:, np.newaxis, :] - self.cluster_centers_, axis=2)return np.argmin(distances, axis=1)

4. 数据聚类

Mean-Shift和EM算法的聚类结果分别如图2的a-b子图所示,由于MoG比较简单,两种算法均可以合理且完整地实现聚类,聚类中心也没有显著差异。

在这里插入图片描述

Fig. 2. Mean-Shift(a)和EM(b)算法的聚类结果

5. 源码地址

如果对您有用的话可以点点star哦~

https://github.com/Jurio0304/cs-math/blob/main/hw3_clustering.ipynb
https://github.com/Jurio0304/cs-math/blob/main/func.py


创作不易,麻烦点点赞和关注咯!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/645255.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cesium分屏对比功能实现,完整版代码案例

使用cesium开发的小伙伴们,分屏对比功能是视图功能中比较常见的一个需求。 这篇文章我们来教会大家如何实现这个功能。 首先我们要准备一左一右2个div容器,用来挂在两个cesium实例。 其实分屏对比的关键就在于左右两个视图如何联动起来。 那么我们需要借助相机之间的参数…

JavaScript精粹:26个关键字深度解析,编写高质量代码的秘诀!

JavaScript关键字是一种特殊的标识符&#xff0c;它们在语言中有固定的含义&#xff0c;不能用作变量名或函数名。这些关键字是JavaScript的基础&#xff0c;理解它们是掌握JavaScript的关键。 今天&#xff0c;我们将一起探索JavaScript中的26个关键字&#xff0c;了解这些关…

CountDownLatch使用错误+未最终断开连接导致线程池资源耗尽

错误描述&#xff1a; 我设置了CountDownLatch对线程的协作做出了一些限制&#xff0c;但是我发现运行一段时间以后便发现定时任务不运行了。 具体代码&#xff1a; public void sendToCertainWeb() throws IOException, InterruptedException {List<String> urlList …

动手学深度学习12 Dropout丢弃法

动手学深度学习12 Dropout丢弃法 1. 丢弃法2. 代码实现源码实现简洁实现torch.rand() 和 torch.randn() 两个函数的区别 3. QA 1. 丢弃法 在层之间加入噪音&#xff0c;不对输入层做处理。不是在输入数据上加噪音。 核心&#xff1a;为什么除以1-p 以上是训练过程使用的。…

C++之STL-String

目录 一、STL简介 1.1 什么是STL 1.2 STL的版本 1.3 STL的六大组件 ​编辑 1.4 STL的重要性 二、String类 2.1 Sting类的简介 2.2 string之构造函数 2.3 string类对象的容量操作 2.3.1 size() 2.3.2 length() 2.3.3 capacity() 2.3.4 empty() 2.3.5 clear() 2.3.6…

IEEE论文Word转高清PDF

一、问题描述 简单的操作word直接导出为PDF&#xff0c;会导致图片的模糊。 甚至在高级选项里选择分辨率为"高保真"&#xff08;图1&#xff09;&#xff0c;输出PDF时选择“标准”&#xff08;图2&#xff09;&#xff0c;也无法逃避图片的模糊&#xff08;图3&am…

UDS的0x19服务

0x19读取故障码信息 0x19的子功能01 19 01 用于读取故障码的数量。 DTC SM故障码的状态掩码 DTC FID所支持的故障码状态的情况 DTC Count存储故障码格式的标识符 DTC FID&#xff08;DTC的格式标识符&#xff09;&#xff0c;如下所示 常用的为00 0x19的子功能02 19 02 用…

开源社区与开发者的故事

开源社区与开发者的故事 什么是开源社区你参加开源社区的主要目的你是否在开源社区中贡献&#xff0c;或者开源自己的项目&#xff1f;你认为个人开发者是否应该从开源中获利&#xff1f;如果是&#xff0c;该如何获利&#xff1f; 今天要谈及的主题是开源社区&#xff0c;那么…

【InternLM实战营---第六节课笔记】

一、本期课程内容概述 本节课的主讲老师是【樊奇】。教学内容主要包括以下三个部分&#xff1a; 1.大模型智能体的背景及介绍 2. Lagent&AgentLego框架介绍 3.Lagent&AgentLego框架实战 二、学习收获 智能体出现的背景 智能体的引入旨在克服大模型在应对复杂、动态任…

Unity类银河恶魔城学习记录13-5,6 p146 Delete save file,p147 Encryption of saved data源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili FileDataHandler.cs using System; using System.IO; using UnityEngine; p…

数据科学/分析党的福音—亚马逊云科技Amazon Zero ETL(零ETL)技术介绍

2023年亚马逊云科技全球大会Re:invent上&#xff0c;数据产品VP Swami博士正式推出了Amazon Zero ETL服务&#xff0c;支持业务大数据从Aurora向Redshift的实时导入、分析。 过去在亚马逊云科技上构建数据分析平台&#xff0c;最令人头疼的莫属ETL环节。遇到的挑战包括:▶️提取…

ECharts海量数据渲染解决卡顿

file模块用来写文件 我们首先使用node来生成10万条数据; 借助node的fs模块就行; 如果不会的小伙伴;也不要担心;超级简单// 引入模块 let fs = require(fs); // 数据内容 let fileCont=我是文件内容 /*** 第一个参数是文件名* 第二个参数是文件内容,这个文件的内容必须是字…