【opencv】示例-train_svmsgd.cpp 随机梯度下降支持向量机(SVMSGD)对二维数据进行分类的UI...

6389938ae1f28455547f39da35bee25c.png

493b36e63a1aa84805bfa9215f553ff3.png

#include "opencv2/core.hpp"                     // 引入opencv2的核心头文件
#include "opencv2/video/tracking.hpp"           // 引入opencv2视频跟踪相关功能的头文件
#include "opencv2/imgproc.hpp"                  // 引入opencv2的图像处理相关功能的头文件
#include "opencv2/highgui.hpp"                  // 引入opencv2的GUI界面和图像显示相关功能的头文件
#include "opencv2/ml.hpp"                       // 引入opencv2的机器学习模块的头文件using namespace cv;                            // 使用opencv命名空间
using namespace cv::ml;                        // 使用opencv机器学习模块的命名空间struct Data                                     
{Mat img;                                   // 存放图像Mat samples;                              // 训练样本集,其中包含图像上的点Mat responses;                            // 训练样本的响应集Data()                                    // Data 的构造函数{const int WIDTH = 841;                // 图像宽度const int HEIGHT = 594;               // 图像高度img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); // 创建全零值(黑色)图像,3个颜色通道imshow("Train svmsgd", img);          // 显示图像,窗口标题为 "Train svmsgd"}
};
// 使用 SVMSGD 算法进行训练的函数
// samples 和 responses 是训练集
// weights 是 SVMSGD 算法的决策函数所需的向量
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);// 找到画线(wx = 0)的两个点的函数
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height);// 找到线(wx = 0)和边界的交点的函数,边界为 (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT)
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);// 初始化边界线段的函数,分为 (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT)
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);// 重绘图像中的点集合和线(wx = 0)
void redraw(Data data, const Point points[2]);// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);// 实际执行训练的函数实现
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();   // 创建 SVMSGD 类的实例cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); // 创建训练数据svmsgd->train(trainData); // 训练 SVMSGD 模型if (svmsgd->isTrained())  // 如果模型训练成功{weights = svmsgd->getWeights();  // 获取模型权重shift = svmsgd->getShift();      // 获取模型偏移return true;                     // 返回训练成功}return false;                        // 如果训练失败,返回 false
}// 初始化边界线段的函数实现
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{std::pair<Point,Point> currentSegment;currentSegment.first = Point(width, 0);    // 右边界线段currentSegment.second = Point(width, height);segments.push_back(currentSegment);currentSegment.first = Point(0, height);   // 底边界线段currentSegment.second = Point(width, height);segments.push_back(currentSegment);currentSegment.first = Point(0, 0);        // 顶边界线段currentSegment.second = Point(width, 0);segments.push_back(currentSegment);currentSegment.first = Point(0, 0);        // 左边界线段currentSegment.second = Point(0, height);segments.push_back(currentSegment);
}// 函数findCrossPointWithBorders用于计算给定权重和偏移量下,直线与图像边界的交点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{// 初始化交点的坐标int x = 0;int y = 0;// 获取线段端点的横纵坐标极值int xMin = std::min(segment.first.x, segment.second.x);int xMax = std::max(segment.first.x, segment.second.x);int yMin = std::min(segment.first.y, segment.second.y);int yMax = std::max(segment.first.y, segment.second.y);// 权重矩阵必须是单精度浮点类型CV_Assert(weights.type() == CV_32FC1);// 检查线段水平还是垂直CV_Assert(xMin == xMax || yMin == yMax);// 如果是垂直线段并且权重矩阵的第二个元素不为0if (xMin == xMax && weights.at<float>(1) != 0){x = xMin;// 根据直线方程计算交点的y值y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));// 检测交点是否在图像的边界内if (y >= yMin && y <= yMax){crossPoint.x = x;crossPoint.y = y;return true;}}// 如果是水平线段并且权重矩阵的第一个元素不为0else if (yMin == yMax && weights.at<float>(0) != 0){y = yMin;// 根据直线方程计算交点的x值x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));// 检测交点是否在图像的边界内if (x >= xMin && x <= xMax){crossPoint.x = x;crossPoint.y = y;return true;}}return false;
}// 函数findPointsForLine用于计算用于绘制直线的两个点
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{// 如果权重矩阵为空,则返回失败if (weights.empty()){return false;}// 记录已找到的有效交点数量int foundPointsCount = 0;// 用于存储图像边框的4条线段(轮廓线)std::vector<std::pair<Point,Point> > segments;// 初始化线段(轮廓线)fillSegments(segments, width, height);// 遍历所有的边框线for (uint i = 0; i < segments.size(); i++){// 如果找到与边框的交点if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))foundPointsCount++; //增加有效交点的数量// 如果已找到了两个有效交点,则可以构成一条线,跳出循环if (foundPointsCount >= 2)break;}return true; //成功找到两个点
}
// 重新绘制数据集和分割线的函数实现
void redraw(Data data, const Point points[2])
{data.img.setTo(0);                             // 将图像设置为全黑Point center;int radius = 3;                               // 点的半径Scalar color;                                 // 颜色CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1)); // 断言样本和响应数据类型正确for (int i = 0; i < data.samples.rows; i++)   // 遍历所有样本{center.x = static_cast<int>(data.samples.at<float>(i,0));center.y = static_cast<int>(data.samples.at<float>(i,1));color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); // 根据响应值设定颜色circle(data.img, center, radius, color, 5); // 绘制圆形点}line(data.img, points[0], points[1],cv::Scalar(1,255,1)); // 绘制分割线imshow("Train svmsgd", data.img);             // 显示图像
}// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数实现
// 函数addPointRetrainAndRedraw用于添加新的训练点,并重新训练SVMSGD算法,然后重绘图形
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{// 创建一个1行2列的单精度浮点数矩阵,用于存储一个样本点Mat currentSample(1, 2, CV_32FC1);// 设定样本点的横纵坐标currentSample.at<float>(0,0) = (float)x;currentSample.at<float>(0,1) = (float)y;// 将新样本点加入到样本集中data.samples.push_back(currentSample);// 将样本点的响应(类别)加入到响应集中data.responses.push_back(static_cast<float>(response));// 创建权重矩阵和偏移量Mat weights(1, 2, CV_32FC1);float shift = 0;// 如果训练成功if (doTrain(data.samples, data.responses, weights, shift)){// 创建Points数组用于存储线的两个点Point points[2];// 找到用于绘制直线的两个点findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);// 使用找到的两个点重新绘制图形redraw(data, points);}
}// 鼠标回调函数,用于在图像上添加正负样本点并重新训练模型
static void onMouse( int event, int x, int y, int, void* pData)
{Data &data = *(Data*)pData;                  // 从pData转换获取Data结构体引用switch( event )                              // 根据事件类型{case EVENT_LBUTTONUP:                        // 左键松开事件addPointRetrainAndRedraw(data, x, y, 1); // 添加正样本点并重新训练绘制break;case EVENT_RBUTTONDOWN:                      // 右键按下事件addPointRetrainAndRedraw(data, x, y, -1);// 添加负样本点并重新训练绘制break;}
}// 主函数
int main()
{Data data;                                   // 创建Data结构体实例setMouseCallback( "Train svmsgd", onMouse, &data ); // 设置鼠标回调函数waitKey();                                   // 等待按键return 0;                                    // 程序结束
}

该段代码是一个关于OpenCV和机器学习算法SVMSGD(支持向量机随机梯度下降)的简单示例,用于创建一个可交互的界面,在上面添加样本点,进行实时的线性分类器训练,并且通过绘制决策边界来显示分类结果。通过鼠标左键添加正样本,右键添加负样本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/618979.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是 MVVM、mvc 模型

mvc模型 MVC: MVC 即 model-view-controller&#xff08;模型-视图-控制器)是项目的一种分层架构思想&#xff0c;它把复杂的业务逻辑&#xff0c; 抽离为职能单一的小模块&#xff0c;每个模块看似相互独立&#xff0c;其实又各自有相互依赖关系。它的好处是&#xff1a;保证了…

微服务相关

1. 微服务主要七个模块 中央管理平台&#xff1a;生产者、消费者注册&#xff0c;服务发现&#xff0c;服务治理&#xff0c;调用关系生产者消费者权限管理流量管理自定义传输协议序列化反序列化 2. 中央管理平台 生产者A在中央管理平台注册后&#xff0c;中央管理平台会给他…

移动Web学习06-移动端适配Less预处理器项目案例

项目目标&#xff1a;实现在不同宽度设备中等比缩放的网页效果 Less代码 import ./base; import ./normalize;// 变量: 存储37.5 rootSize: 37.5rem; *{margin: 0;padding: 0; } body {background-color: #F0F0F0; }// 主体内容 .main {// padding-bottom: (50 / 37.5rem);pa…

【uniapp踩坑记】——微信小程序转发保存图片

关于微信小程序转发&保存图片 微信小程序图片转发保存简单说明网络图片的转发保存base64流形式图片转发保存 已经好多年没写博客了&#xff0c;最近使用在用uniapp开发一个移动版管理后台&#xff0c;记录下自己踩过的一些坑 吃相别太难看&#xff0c;搞一堆下头僵尸号来点…

Spring Cloud学习笔记:Eureka集群搭建样例

这是本人学习的总结&#xff0c;主要学习资料如下 - 马士兵教育 1、项目架构2、Dependency3、项目启动类4、application.yml5、启动项目 1、项目架构 因为这是单机模拟集群搭建&#xff0c;为了方便管理就都放在了一个项目中。这次准备搭建三个项目server1, server2, server3 …

Linux网络基础 (二) ——(IP、MAC、端口号、TCPUDP协议、网络字节序)

文章目录 IP 地址基本概念源IP地址 & 目的IP地址 MAC 地址基本概念源MAC地址 & 目的MAC地址 端口号基本概念源端口号 & 目的端口号 TCP & UDP 协议基本概念TCP 与 UDP 的抉择 网络字节序大端、小端字节序 &#x1f396; 博主的CSDN主页&#xff1a;Ryan.Alask…

攻防世界---Web_php_include

1.题目链接 2.补充知识&#xff1a; 3.构造&#xff1a;执行成功 /?pagedata://text/plain,<?php phpinfo()?> 4.构造下面url&#xff0c;得到目录路径 /?pagedata://text/plain,<?php echo $_SERVER[DOCUMENT_ROOT]?> 5构造下面url&#xff0c;读取该路径的…

Linux的学习之路:10、进程(2)

摘要 本章主要是说一下fork的一些用法、进程状态、优先级和环境变量。 目录 摘要 一、fork 1、fork的基本用法 2、分流 二、进程状态 三、优先级 四、环境变量 1、常见环境变量 2、和环境变量相关的命令 3、通过代码如何获取环境变量 五、导图 一、fork 1、fork…

微信小程序实现预约生成二维码

业务需求&#xff1a;点击预约按钮即可生成二维码凭码入校参观~ 一.创建页面 如下是博主自己写的wxml&#xff1a; <swiper indicator-dots indicator-color"white" indicator-active-color"blue" autoplay interval"2000" circular > &…

Zookeeper和Kafka的部署

目录 一、Zookeeper的基本概念 1. Zookeeper定义 2. Zookeeper工作机制 3. Zookeeper特点 4. Zookeeper数据结构 5. Zookeeper应用场景 5.1 统一命名服务 5.2 统一配置管理 5.3 统一集群管理 5.4 服务器动态上下线 5.5 软负载均衡 6. Zookeeper 选举机制 6.1 第一…

Cortex-M3/M4处理器的bit-band(位带)技术

ARM Cortex-M3/M4的位带&#xff08;Bit-Band&#xff09;技术是一种内存映射技术&#xff0c;它允许对单个位进行直接操作&#xff0c;而不需要对整个字&#xff08;通常是32位&#xff09;进行操作。这项技术主要用于对特定的位进行高效的读写&#xff0c;特别是在需要对GPIO…

python-numpy(3)-线性代数

一、方程求解 参考资料 对于Ax b 这种方程&#xff1a; np.linalg.inv(A).dot(B)np.linalg.solve(A,b) 1.1 求解多元一次方程一个直观的例子 # AXB # X A^(-1)*B A np.array([[7, 3, 0, 1], [0, 1, 0, -1], [1, 0, 6, -3], [1, 1, -1, -1]]) B np.array([8, 6, -3, 1]…