#include "opencv2/core.hpp" // 引入opencv2的核心头文件
#include "opencv2/video/tracking.hpp" // 引入opencv2视频跟踪相关功能的头文件
#include "opencv2/imgproc.hpp" // 引入opencv2的图像处理相关功能的头文件
#include "opencv2/highgui.hpp" // 引入opencv2的GUI界面和图像显示相关功能的头文件
#include "opencv2/ml.hpp" // 引入opencv2的机器学习模块的头文件using namespace cv; // 使用opencv命名空间
using namespace cv::ml; // 使用opencv机器学习模块的命名空间struct Data
{Mat img; // 存放图像Mat samples; // 训练样本集,其中包含图像上的点Mat responses; // 训练样本的响应集Data() // Data 的构造函数{const int WIDTH = 841; // 图像宽度const int HEIGHT = 594; // 图像高度img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); // 创建全零值(黑色)图像,3个颜色通道imshow("Train svmsgd", img); // 显示图像,窗口标题为 "Train svmsgd"}
};
// 使用 SVMSGD 算法进行训练的函数
// samples 和 responses 是训练集
// weights 是 SVMSGD 算法的决策函数所需的向量
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);// 找到画线(wx = 0)的两个点的函数
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height);// 找到线(wx = 0)和边界的交点的函数,边界为 (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT)
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);// 初始化边界线段的函数,分为 (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT)
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);// 重绘图像中的点集合和线(wx = 0)
void redraw(Data data, const Point points[2]);// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);// 实际执行训练的函数实现
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); // 创建 SVMSGD 类的实例cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); // 创建训练数据svmsgd->train(trainData); // 训练 SVMSGD 模型if (svmsgd->isTrained()) // 如果模型训练成功{weights = svmsgd->getWeights(); // 获取模型权重shift = svmsgd->getShift(); // 获取模型偏移return true; // 返回训练成功}return false; // 如果训练失败,返回 false
}// 初始化边界线段的函数实现
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{std::pair<Point,Point> currentSegment;currentSegment.first = Point(width, 0); // 右边界线段currentSegment.second = Point(width, height);segments.push_back(currentSegment);currentSegment.first = Point(0, height); // 底边界线段currentSegment.second = Point(width, height);segments.push_back(currentSegment);currentSegment.first = Point(0, 0); // 顶边界线段currentSegment.second = Point(width, 0);segments.push_back(currentSegment);currentSegment.first = Point(0, 0); // 左边界线段currentSegment.second = Point(0, height);segments.push_back(currentSegment);
}// 函数findCrossPointWithBorders用于计算给定权重和偏移量下,直线与图像边界的交点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{// 初始化交点的坐标int x = 0;int y = 0;// 获取线段端点的横纵坐标极值int xMin = std::min(segment.first.x, segment.second.x);int xMax = std::max(segment.first.x, segment.second.x);int yMin = std::min(segment.first.y, segment.second.y);int yMax = std::max(segment.first.y, segment.second.y);// 权重矩阵必须是单精度浮点类型CV_Assert(weights.type() == CV_32FC1);// 检查线段水平还是垂直CV_Assert(xMin == xMax || yMin == yMax);// 如果是垂直线段并且权重矩阵的第二个元素不为0if (xMin == xMax && weights.at<float>(1) != 0){x = xMin;// 根据直线方程计算交点的y值y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));// 检测交点是否在图像的边界内if (y >= yMin && y <= yMax){crossPoint.x = x;crossPoint.y = y;return true;}}// 如果是水平线段并且权重矩阵的第一个元素不为0else if (yMin == yMax && weights.at<float>(0) != 0){y = yMin;// 根据直线方程计算交点的x值x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));// 检测交点是否在图像的边界内if (x >= xMin && x <= xMax){crossPoint.x = x;crossPoint.y = y;return true;}}return false;
}// 函数findPointsForLine用于计算用于绘制直线的两个点
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{// 如果权重矩阵为空,则返回失败if (weights.empty()){return false;}// 记录已找到的有效交点数量int foundPointsCount = 0;// 用于存储图像边框的4条线段(轮廓线)std::vector<std::pair<Point,Point> > segments;// 初始化线段(轮廓线)fillSegments(segments, width, height);// 遍历所有的边框线for (uint i = 0; i < segments.size(); i++){// 如果找到与边框的交点if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))foundPointsCount++; //增加有效交点的数量// 如果已找到了两个有效交点,则可以构成一条线,跳出循环if (foundPointsCount >= 2)break;}return true; //成功找到两个点
}
// 重新绘制数据集和分割线的函数实现
void redraw(Data data, const Point points[2])
{data.img.setTo(0); // 将图像设置为全黑Point center;int radius = 3; // 点的半径Scalar color; // 颜色CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1)); // 断言样本和响应数据类型正确for (int i = 0; i < data.samples.rows; i++) // 遍历所有样本{center.x = static_cast<int>(data.samples.at<float>(i,0));center.y = static_cast<int>(data.samples.at<float>(i,1));color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); // 根据响应值设定颜色circle(data.img, center, radius, color, 5); // 绘制圆形点}line(data.img, points[0], points[1],cv::Scalar(1,255,1)); // 绘制分割线imshow("Train svmsgd", data.img); // 显示图像
}// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数实现
// 函数addPointRetrainAndRedraw用于添加新的训练点,并重新训练SVMSGD算法,然后重绘图形
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{// 创建一个1行2列的单精度浮点数矩阵,用于存储一个样本点Mat currentSample(1, 2, CV_32FC1);// 设定样本点的横纵坐标currentSample.at<float>(0,0) = (float)x;currentSample.at<float>(0,1) = (float)y;// 将新样本点加入到样本集中data.samples.push_back(currentSample);// 将样本点的响应(类别)加入到响应集中data.responses.push_back(static_cast<float>(response));// 创建权重矩阵和偏移量Mat weights(1, 2, CV_32FC1);float shift = 0;// 如果训练成功if (doTrain(data.samples, data.responses, weights, shift)){// 创建Points数组用于存储线的两个点Point points[2];// 找到用于绘制直线的两个点findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);// 使用找到的两个点重新绘制图形redraw(data, points);}
}// 鼠标回调函数,用于在图像上添加正负样本点并重新训练模型
static void onMouse( int event, int x, int y, int, void* pData)
{Data &data = *(Data*)pData; // 从pData转换获取Data结构体引用switch( event ) // 根据事件类型{case EVENT_LBUTTONUP: // 左键松开事件addPointRetrainAndRedraw(data, x, y, 1); // 添加正样本点并重新训练绘制break;case EVENT_RBUTTONDOWN: // 右键按下事件addPointRetrainAndRedraw(data, x, y, -1);// 添加负样本点并重新训练绘制break;}
}// 主函数
int main()
{Data data; // 创建Data结构体实例setMouseCallback( "Train svmsgd", onMouse, &data ); // 设置鼠标回调函数waitKey(); // 等待按键return 0; // 程序结束
}
该段代码是一个关于OpenCV和机器学习算法SVMSGD(支持向量机随机梯度下降)的简单示例,用于创建一个可交互的界面,在上面添加样本点,进行实时的线性分类器训练,并且通过绘制决策边界来显示分类结果。通过鼠标左键添加正样本,右键添加负样本。