StratifiedGroupKFold解释和代码实现

StratifiedGroupKFold解释和代码实现

文章目录

  • 一、StratifiedGroupKFold解释和代码实现是什么?
  • 二、 实验数据设置
    • 2.1 实验数据生成代码
    • 2.2 代码结果
  • 三、实验代码
    • 3.1 实验代码
    • 3.2 实验结果
    • 3.3 结果解释
  • 四、样本类别类别不平衡


一、StratifiedGroupKFold解释和代码实现是什么?

0,1,2,3:每一行表示测试集和训练集的划分的一种方式。
class:表示类别的个数(下图显示的是3类),有些交叉验证根据类别的比例划分测试集和训练集(例三)。
group:表示从不同的组采集到的样本,颜色的个数表示组的个数(有些时候我们关注在一组特定组上训练的模型是否能很好地泛化到看不见的组)。举个例子(解释“组”的意思):我们有10个人,我们想要希望训练集上所用的数据来自(1,2,3,4,5,6,7,8),测试集上的数据来自(9,10),也就是说我们不希望测试集上的数据和训练集上的数据来自同一个人(如果来自同一个人的话,训练集上的信息泄漏到测试集上了,模型的泛化性能会降低,测试结果会偏好)。
在这里插入图片描述
StratifiedGroupKFold 是一种交叉验证方案,结合了 StratifiedKFold 和 GroupKFold 两种方法。这个想法是尝试保留每个拆分中类(class)的分布,同时将每个组(group)保持在单个拆分中(拆分指的是训练集和测试集的拆分)。当您有一个不平衡的数据集时,这可能很有用,因此仅使用GroupKFold 可能会产生倾斜的拆分(类别的倾斜)。

二、 实验数据设置

2.1 实验数据生成代码

X, y = np.arange(0, 60).reshape((30, 2)), np.hstack(([0] * 9, [1] * 9, [2] * 12))
groups = np.hstack((["a"] * 4, ["b"] * 3, ["c"] * 5, ["d"] * 4, ["e"] * 5, ["f"] * 4, ["g"] * 5))
print("数据:", end=" ")
for l in X:print(l, end=' ')
print("")
print("标签:", y)
print("组别:", groups)

2.2 代码结果

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
标签: [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
组别: ['a' 'a' 'a' 'a' 'b' 'b' 'b' 'c' 'c' 'c' 'c' 'c' 'd' 'd' 'd' 'd' 'e' 'e''e' 'e' 'e' 'f' 'f' 'f' 'f' 'g' 'g' 'g' 'g' 'g']

数据个数、标签个数:30个
类别个数:3个(分别是0,1,2,比例是0.3:0.3:0.4和class每类对应)
组别(group):9个(分别是a-g,个数是4,3,5,4,5,4,5)

三、实验代码

3.1 实验代码

代码如下:

# Group k-fold
import numpy as np
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import StratifiedGroupKFoldX, y = np.arange(0, 60).reshape((30, 2)), np.hstack(([0] * 9, [1] * 9, [2] * 12))
groups = np.hstack((["a"] * 4, ["b"] * 3, ["c"] * 5, ["d"] * 4, ["e"] * 5, ["f"] * 4, ["g"] * 5))
print("数据:", end=" ")
for l in X:print(l, end=' ')
print("")
print("标签:", y)
print("组别:", groups)
sgkf = StratifiedGroupKFold(n_splits=3)
# for train, test in sgkf.split(X, y, groups=groups):
#     print("%s %s" % (train, test))
for i, (train, test) in enumerate(sgkf.split(X, y, groups=groups)):print("=================StratifiedGroupKFold 第%d折叠 ====================" % (i + 1))# print('train -  {}'.format(np.bincount(y[train])))print("  训练集索引:%s" % train)print("  训练集标签:", y[train])print("  训练集组别标签", groups[train])print("  训练集数据:", end=" ")for l in X[train]:print(l, end=' ')print("")# print("  训练集数据:", X[train])# print("test  -  {}".format(np.bincount(y[test])))print("  测试集索引:%s" % test)print("  测试集标签:", y[test])print("  测试集组别标签", groups[test])print("  测试集数据:", end=" ")for l in X[test]:print(l, end=' ')print("")# print("  测试集数据:", X[test])print("=============================================================")

3.2 实验结果

结果如下:

数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 
标签: [0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]
组别: ['a' 'a' 'a' 'a' 'b' 'b' 'b' 'c' 'c' 'c' 'c' 'c' 'd' 'd' 'd' 'd' 'e' 'e''e' 'e' 'e' 'f' 'f' 'f' 'f' 'g' 'g' 'g' 'g' 'g']
=================StratifiedGroupKFold 第1折叠 ====================训练集索引:[ 0  1  2  3  4  5  6 12 13 14 15 16 17 18 19 20 21 22 23 24]训练集标签: [0 0 0 0 0 0 0 1 1 1 1 1 1 2 2 2 2 2 2 2]训练集组别标签 ['a' 'a' 'a' 'a' 'b' 'b' 'b' 'd' 'd' 'd' 'd' 'e' 'e' 'e' 'e' 'e' 'f' 'f''f' 'f']训练集数据: [0 1] [2 3] [4 5] [6 7] [8 9] [10 11] [12 13] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] 测试集索引:[ 7  8  9 10 11 25 26 27 28 29]测试集标签: [0 0 1 1 1 2 2 2 2 2]测试集组别标签 ['c' 'c' 'c' 'c' 'c' 'g' 'g' 'g' 'g' 'g']测试集数据: [14 15] [16 17] [18 19] [20 21] [22 23] [50 51] [52 53] [54 55] [56 57] [58 59] 
=============================================================
=================StratifiedGroupKFold 第2折叠 ====================训练集索引:[ 4  5  6  7  8  9 10 11 12 13 14 15 25 26 27 28 29]训练集标签: [0 0 0 0 0 1 1 1 1 1 1 1 2 2 2 2 2]训练集组别标签 ['b' 'b' 'b' 'c' 'c' 'c' 'c' 'c' 'd' 'd' 'd' 'd' 'g' 'g' 'g' 'g' 'g']训练集数据: [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [50 51] [52 53] [54 55] [56 57] [58 59] 测试集索引:[ 0  1  2  3 16 17 18 19 20 21 22 23 24]测试集标签: [0 0 0 0 1 1 2 2 2 2 2 2 2]测试集组别标签 ['a' 'a' 'a' 'a' 'e' 'e' 'e' 'e' 'e' 'f' 'f' 'f' 'f']测试集数据: [0 1] [2 3] [4 5] [6 7] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] 
=============================================================
=================StratifiedGroupKFold 第3折叠 ====================训练集索引:[ 0  1  2  3  7  8  9 10 11 16 17 18 19 20 21 22 23 24 25 26 27 28 29]训练集标签: [0 0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2]训练集组别标签 ['a' 'a' 'a' 'a' 'c' 'c' 'c' 'c' 'c' 'e' 'e' 'e' 'e' 'e' 'f' 'f' 'f' 'f''g' 'g' 'g' 'g' 'g']训练集数据: [0 1] [2 3] [4 5] [6 7] [14 15] [16 17] [18 19] [20 21] [22 23] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] [50 51] [52 53] [54 55] [56 57] [58 59] 测试集索引:[ 4  5  6 12 13 14 15]测试集标签: [0 0 0 1 1 1 1]测试集组别标签 ['b' 'b' 'b' 'd' 'd' 'd' 'd']测试集数据: [8 9] [10 11] [12 13] [24 25] [26 27] [28 29] [30 31] 
=============================================================进程已结束,退出代码 0

3.3 结果解释

  • 可以看到每一折叠的测试集都有所有类别的样本,但是训练集可能只有部分类别的样本(如第3折叠)
  • 这种交叉验证只适用于类别相对不平衡的样本,但是当样本类别极不平衡时,这种交叉验证将会不具有参考价值。
  • 该种交叉验证即考虑到样本的组别(group),又考虑到样本的标签比例,是一个相对较好的交叉验证。

四、样本类别类别不平衡

X, y = np.arange(0, 60).reshape((30, 2)), np.hstack(([0] * 9, [1] * 9, [2] * 12))
改为下面的
X, y = np.arange(0, 60).reshape((30, 2)), np.hstack(([0] * 3, [1] * 6, [2] * 21))

类别个数:3个(分别是0,1,2,比例是0.1:0.3:0.7和class每类对应)

=================StratifiedGroupKFold 第2折叠 ====================训练集索引:[ 4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24]训练集标签: [1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]训练集组别标签 ['b' 'b' 'b' 'c' 'c' 'c' 'c' 'c' 'd' 'd' 'd' 'd' 'e' 'e' 'e' 'e' 'e' 'f''f' 'f' 'f']训练集数据: [8 9] [10 11] [12 13] [14 15] [16 17] [18 19] [20 21] [22 23] [24 25] [26 27] [28 29] [30 31] [32 33] [34 35] [36 37] [38 39] [40 41] [42 43] [44 45] [46 47] [48 49] 测试集索引:[ 0  1  2  3 25 26 27 28 29]测试集标签: [0 0 0 1 2 2 2 2 2]测试集组别标签 ['a' 'a' 'a' 'a' 'g' 'g' 'g' 'g' 'g']测试集数据: [0 1] [2 3] [4 5] [6 7] [50 51] [52 53] [54 55] [56 57] [58 59] 
=============================================================

可以看到测试集标签里面有0,但是训练集标签里没有0——这没办法做测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/318313.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AUTOSAR软件架构描述文档,AUTOSAR_EXP_LayeredSoftwareArchitecture

AUTOSAR软件架构描述文档,我们常见的经典的CP架构及OS双核等架构描述 下载链接:https://www.autosar.org/fileadmin/standards/R21-11/CP/AUTOSAR_EXP_LayeredSoftwareArchitecture.pdf

每日一练:LeeCode-503. 下一个更大元素 II (中)【单调栈】

本文是力扣LeeCode-503. 下一个更大元素 II 学习与理解过程,本文仅做学习之用,对本题感兴趣的小伙伴可以出门左拐LeeCode。 给定一个循环数组 nums ( nums[nums.length - 1] 的下一个元素是 nums[0] ),返回 nums 中每个…

【hyperledger-fabric】部署和安装

简介 对hyperledger-fabric进行安装,话不多说,直接开干。但是需要申明一点,也就是本文章全程是开着加速器进行的资源操作,所以对于没有开加速器的情况可能会由于网络原因导致下载资源失败。 资料提供 1.官方部署文档在此&#…

动手学深度学习一:环境安装与数据学习

2024,重新开始深度学习。 第一步:李沐动手学深度学习 课程网址:https://courses.d2l.ai/zh-v2/ 包含教材和视频网址链接 Jupyter notebook安装 目前在本地先使用cpu版本pytorch,我的本地已经安装好conda,跟着教材创建…

闭着眼睛都要会的Linux命令

😄作者简介: 小曾同学.com,一个致力于测试开发的博主⛽️,主要职责:测试开发、CI/CD 如果文章知识点有错误的地方,还请大家指正,让我们一起学习,一起进步。 😊 座右铭:不…

互联网加竞赛 Yolov安全帽佩戴检测 危险区域进入检测 - 深度学习 opencv

1 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 Yolov安全帽佩戴检测 危险区域进入检测 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:3分创新点:4分 该项目较为新颖&am…

从零开始了解大数据(七):总结

系列文章目录 从零开始了解大数据(一):数据分析入门篇-CSDN博客 从零开始了解大数据(二):Hadoop篇-CSDN博客 从零开始了解大数据(三):HDFS分布式文件系统篇-CSDN博客 从零开始了解大数据(四):MapReduce篇-CSDN博客 从零开始了解大…

1_并发编程_线程的基本概念和线程终止及线程问题排查

1.线程的运行状态 在Java中,线程的状态一共是6种状态,分别是 NEW:初始状态,线程被构建,但是还没有调用start方法 RUNNABLED:运行状态,JAVA线程把操作系统中的就绪和运行两种状态统一称为“运行…

c++day6

vector容器主要的功能函数&#xff1a; #include <iostream> #include <vector> using namespace std;int main() {//无参构造vector <int> v1;//有参构造vector <int> v2(5,99);//判空cout<<v1.empty()<<endl;//1cout<<v2.empty()…

c++编程要养成的好习惯

1、缩进 你说有缩进看的清楚还是没缩进看的清楚 2、i和i i运行起来和i更快 3、 n%20和n&1 不要再用n%20来判断n是不是偶数了&#xff0c;又慢又土&#xff0c;用n&10&#xff0c;如果n&10就说明n是偶数 同理&#xff0c;n&11说明n是奇数 4、*2和<<…

uniapp选择android非图片文件的方案踩坑记录

这个简单的问题我遇到下面6大坑&#xff0c;原始需求是选择app如android的excel然后读取到页面并上传表格数据json 先看看效果 uniapp 选择app excel文件读取 1.uniapp自带不支持 uniapp选择图片和视频非常方便自带已经支持可以直接上传和读取 但是选择word excel的时候就出现…

k8s之pod

pod是k8s中最小的资源管理组件 pod也是最小化运行容器化的应用的资源管理对象 pod是一个抽象的概念&#xff0c;可以理解成一个或者多个容器化应用的集合 pod可以是一个或者多个 在一个pod中运行一个容器&#xff08;最常用的方式&#xff09; 在一个pod中同时运行多个容器…