文章目录
- 前言
- 一、介绍
- 1.1 原理
- 1.2 流程
- 1.3 信息熵,信息增益和基尼不纯度
- 二、构建决策树
- 2.1 特征选择
- 2.2 决策树生成
- 2.3 剪枝
- 三、经典算法
- 3.1 ID3
- 3.2 C4.5
- 3.3 CART
- 四、案例
- 4.1 Iris 数据集 鸢尾花 分类
- 4.2 基于决策树的英雄联盟游戏胜负预测
- 参考
前言
决策树(Decision Trees) 是一种基于树结构的机器学习算法,它是近年来最常见的数据挖掘算法,可以用于分类和回归问题。
它可以作为预测模型,从样本的观测数据推断出该样本的预测结果。 按预测结果的差异,决策树学习可细分两类。
- 分类树,其预测结果仅限于一组离散数值。树的每个分支对应一组由逻辑与连接的分类特征,而该分支上的叶节点对应由上述特征可以预测出的分类标签。
- 回归树,其预测结果为连续值。
决策树可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。
- if-then 规则是指一种形式化的表示方法,用于描述决策树模型中的判断过程。每个规则都由一个前提和一个结论组成。例如,如果正在使用决策树来预测一个人是否会购买某个产品,那么一个规则可能是:“如果这个人的年龄在30岁以下且收入在5万美元以上,则他会购买这个产品。”这个规则的前提是“这个人的年龄在30岁以下且收入在5万美元以上”,结论是“他会购买这个产品”。
- 特征空间是指所有样本的特征向量所构成的空间,在特征空间中,每个样本都可以表示为一个向量
- 类空间则是指所有可能的类所构成的空间,在类空间中,每个类都可以表示为一个点或一个区域
决策树算法的目标是在特征空间中找到一个划分,使得每个划分区域内的样本都属于同一类。
一、介绍
1.1 原理
决策树算法的基本原理是将数据集按照某种特定的规则进行划分,使得划分后的子集尽可能的纯,即同一子集中的样本属于同一类别。这个过程可以看作是一个递归的过程,每次选择一个最优的特征进行划分,直到所有样本都属于同一类别或者无法继续划分为止。
在构建决策树时,我们需要考虑如何选择最优的特征进行划分。常用的方法有ID3 (Iterative Dichotomiser 3)、C4.5、CART(Classification and Regression Trees)等。其中ID3和C4.5使用信息增益来进行特征选择,而CART使用基尼不纯度来进行特征选择。
- 信息增益:数据集划分前后信息发生的变化
- 基尼不纯度:简单讲就是从一个数据集随机选取子项,度量其被错误分类到其它组的概率
1.2 流程
决策树的基本流程是一个由根到叶的递归过程,在每一个中间结点寻找划分属性,递归重要的是设置停止条件:
- 当前结点包含的样本属于同一类别,无需划分;
- 当前属性集为空,或是所有样本在所有属性上取值相同无法划分,简单理解就是当分到这一节点时,所有的属性特征都用完了,没有特征可用了,就根据label数量多的给这一节点打标签使其变成叶节点(其实是在用样本出现的后验概率做先验概率);
- 当前结点包含的样本集合为空,不能划分。这种情况出现是因为该样本数据缺少这个属性取值,根据父结点的label情况为该结点打标记(其实是在用父结点出现次数最多的label做先验概率)
1.3 信息熵,信息增益和基尼不纯度
-
信息熵(entropy) 是对于样本集合的不确定性的度量,它的值越小,样本集合的纯度越高。
在决策树算法中,我们使用信息熵来计算样本集合的纯度。假设样本集合D中第k类样本所占比例为 p k ( k = 1 , 2 , … , y ) p_k(k=1,2,…,y) pk(k=1,2,…,y),则 D 的信息熵定义为:
E n t ( D ) = − ∑ k = 1 y p k l o g 2 p k Ent(D)=-\sum_{k=1}^{y}p_klog_2p_k Ent(D)=−k=1∑ypklog2pk
其中, y y y是类别个数。
-
信息增益(information gain),它是以信息熵为基础的,它表示得到信息带来的变化量,通常用于选择最优的分裂特征。信息增益的计算公式如下:
G a i n ( D , A ) = E n t ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D, A) = Ent(D) - \sum_{v=1}^{V}\frac{|D^v|}{|D|}Ent(D^v) Gain(D,A)=Ent(D)−v=1∑V∣D∣∣Dv∣Ent(Dv)
其中, D D D 表示当前节点的训练数据集, A A A 表示候选特征集合, V V V 表示候选特征集合中特征的个数, D v D^v Dv 表示当前节点按照特征 A A A 的第 v v v 个取值划分后的子集, E n t ( D ) Ent(D) Ent(D) 表示当前节点的熵, E n t ( D v ) Ent(D^v) Ent(Dv) 表示当前节点按照特征 A A A 的第 v v v 个取值划分后子集的熵。
信息增益越高表示该特征对分类能力的贡献度越大,即该特征可以更好地区分不同类别的样本。
-
基尼不纯度(Gini impurity) 是一种用于衡量数据集的纯度的指标,它表示从数据集中随机选取两个样本,其类别不一致的概率。
G i n i ( D ) = ∑ k = 1 y p k ( 1 − p k ) Gini(D) = \sum_{k=1}^y{p_k(1- p_k)} Gini(D)=k=1∑ypk(1−pk)基尼不纯度越低代表数据集的纯度越高,通常用于衡量一个节点的分裂效果,在节点代表该节点分裂后子节点的纯度越高,即子节点中包含的相同类别样本比例越大。
二、构建决策树
2.1 特征选择
特征选择是指从训练数据中众多的特征中选择一个特征作为当前节点的分裂标准,优缺点如下:
优点:
- 降低决策树复杂度,使模型更加简单,减少过拟合(指模型在训练集上表现良好,但在测试集上表现不佳)的风险,提高模型泛化能力。
- 减少决策树训练时间和存储空间。
缺点:
- 可能会丢失一些重要信息,导致模型精度下降。
- 可能会引入一些噪声,导致模型精度下降。
- 可能会使数据变得更加复杂,导致模型泛化能力下降。
常用的特征选择方法有信息增益、信息增益比、基尼指数等
2.2 决策树生成
决策树生成是指从训练数据中生成决策树的过程。根据上述特征选择方法,常用的决策树生成算法有ID3、C4.5、CART等。
决策树通过对训练数据进行递归分割,生成一棵树形结构,从而实现对新数据的分类。决策树的生成过程可以分为以下几个步骤:
- 特征选择:从训练数据的特征中选择一个特征作为当前节点的分裂标准。
- 节点分裂:将当前节点的训练数据按照分裂标准分成若干个子集,每个子集对应一个子节点。
- 递归生成子树:对每个子节点递归执行步骤1和步骤2,直到满足停止条件。
停止条件通常有以下几种:
- 当前节点的训练数据全部属于同一类别。
- 当前节点的训练数据为空。
- 当前节点的训练数据中所有特征都相同,无法进行进一步分割。
2.3 剪枝
决策树剪枝是一种用于减少决策树复杂度的技术,它的目的是通过删除一些不必要的节点和子树,从而提高模型泛化能力。常用的决策树剪枝算法有预剪枝和后剪枝两种。
-
预剪枝是指在生成决策树的过程中,对每个节点进行评估,如果当前节点的分裂不能提高模型泛化能力,则停止分裂,将当前节点标记为叶子节点。预剪枝的优点是简单、快速,但可能会导致欠拟合。
-
后剪枝是指在生成决策树之后,对决策树进行修剪,从而减少决策树复杂度。后剪枝的过程通常包括以下几个步骤:
- 对每个非叶子节点进行评估,计算修剪前后模型在验证集上的性能差异。
- 选择性能差异最小的节点进行修剪,将该节点及其子树删除,并将该节点标记为叶子节点。
- 重复步骤1和步骤2,直到无法继续修剪为止。
后剪枝的优点是可以避免欠拟合,但可能会导致过拟合。
三、经典算法
3.1 ID3
ID3算法的核心思想是以信息增益来度量特征选择,选择信息增益最大的特征进行分裂
- 计算数据集的信息熵;
- 对每个特征,计算其信息增益;
- 选择信息增益最大的特征作为划分属性;
- 根据该属性的取值将数据集划分为多个子集;
- 对每个子集递归调用步骤1-4,直到所有样本属于同一类别或无法继续划分。
缺点:
- ID3 没有剪枝策略,容易过拟合
- 信息增益准则对可取值数目较多的特征有所偏好,类似“编号”的特征其信息增益接近于1
- 只能用于处理离散分布的特征
- 没有考虑缺失值
代码如下:
%matplotlib inlineimport math
from collections import Counter,defaultdict
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
font_set = FontProperties(fname=r"c:\\windows\\fonts\\simsun.ttc", size=15)#导入宋体字体文件class Id3DecideTree:def __init__(self, data_set, labels_set):self.tree = self.create_tree(data_set,labels_set)def calc_entropy(self, data):"""计算数据集的信息熵"""label_counts = Counter(sample[-1] for sample in data)probs = [count / len(data) for count in label_counts.values()]return -sum(p * math.log(p, 2) for p in probs)def split_data(self, data, axis, value):"""根据特征划分数据集"""return [sample[:axis] + sample[axis+1:] for sample in data if sample[axis] == value]def choose_best_feature(self, dataSet):"""选择最好的数据集划分方式"""numFeatures = len(dataSet[0]) - 1 # 最后一列用于标签baseEntropy = self.calc_entropy(dataSet) # 计算数据集的熵bestFeature = -1for i in range(numFeatures): # 遍历所有特征featList = [example[i] for example in dataSet] # 创建该特征的所有样本列表uniqueVals = set(featList) # 获取唯一值的集合newEntropy = 0.0for value in uniqueVals:subDataSet = self.split_data(dataSet, i, value) # 划分数据集prob = len(subDataSet)/float(len(dataSet))newEntropy += prob * self.calc_entropy(subDataSet)infoGain = baseEntropy - newEntropy # 计算信息增益;即熵的减少量if (infoGain > bestInfoGain): # 比较目前为止最好的增益bestInfoGain = infoGain # 如果比当前最好的更好,则设置为最好的bestFeature = ireturn bestFeaturedef majority_count(labels):"""统计出现次数最多的类别"""label_counts = defaultdict(int)for label in labels:label_counts[label] += 1return max(label_counts, key=label_counts.get)def create_tree(self, data, labels):"""创建决策树"""class_list = [sample[-1] for sample in data]# 所有样本同一类别if class_list.count(class_list[0]) == len(class_list):return class_list[0]# 只有一个特征if len(data[0]) == 1:return majority_count(class_list)# 选择最优划分特征best_feature_index = self.choose_best_feature(data)best_feature_label = labels[best_feature_index]tree = {best_feature_label: {}}del(labels[best_feature_index])feature_values = [sample[best_feature_index] for sample in data]unique_values = set(feature_values)for value in unique_values:sub_labels = labels[:]tree[best_feature_label][value] = self.create_tree(self.split_data(data, best_feature_index, value), sub_labels)return treeclass DecisionTreePlotter:def __init__(self, tree):self.tree = treeself.decisionNode = dict(boxstyle="sawtooth", fc="0.8")self.leafNode = dict(boxstyle="round4", fc="0.8")self.arrow_args = dict(arrowstyle="<-")self.font_set = font_setdef getNumLeafs(self, node):firstStr = list(node.keys())[0]secondDict = node[firstStr]return sum([self.getNumLeafs(secondDict[key]) if isinstance(secondDict[key], dict) else 1 for key in secondDict.keys()])def getTreeDepth(self, node):firstStr = list(node.keys())[0]secondDict = node[firstStr]return max([1 + self.getTreeDepth(secondDict[key]) if isinstance(secondDict[key], dict) else 1 for key in secondDict.keys()])def plotNode(self, nodeTxt, centerPt, parentPt, nodeType):self.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=self.arrow_args, fontproperties=self.font_set )def plotMidText(self, cntrPt, parentPt, txtString):xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]self.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30, fontproperties=self.font_set)def plotTree(self):self.totalW = float(self.getNumLeafs(self.tree))self.totalD = float(self.getTreeDepth(self.tree))self.xOff = -0.5/self.totalWself.yOff = 1.0self.fig = plt.figure(1, facecolor='white')self.fig.clf()self.axprops = dict(xticks=[], yticks=[])self.ax1 = plt.subplot(111, frameon=False, **self.axprops)self.plotTreeHelper(self.tree, (0.5,1.0), '')plt.show()def plotTreeHelper(self, node, parentPt, nodeTxt):numLeafs = self.getNumLeafs(node) depth = self.getTreeDepth(node)firstStr = list(node.keys())[0] cntrPt = (self.xOff + (1.0 + float(numLeafs))/2.0/self.totalW, self.yOff)self.plotMidText(cntrPt, parentPt, nodeTxt)self.plotNode(firstStr, cntrPt, parentPt, self.decisionNode)secondDict = node[firstStr]self.yOff = self.yOff - 1.0/self.totalDfor key in secondDict.keys():if isinstance(secondDict[key], dict):self.plotTreeHelper(secondDict[key],cntrPt,str(key)) else: self.xOff = self.xOff + 1.0/self.totalWself.plotNode(secondDict[key], (self.xOff, self.yOff), cntrPt, self.leafNode)self.plotMidText((self.xOff, self.yOff), cntrPt, str(key))self.yOff = self.yOff + 1.0/self.totalDlabels_set = ['不浮出水面', '拥有鳍','有头']data_set = [['是', '是', '是', '是鱼类'],['是', '是', '否', '不是鱼类'],['是', '否', '是', '不是鱼类'],['否', '是', '否', '不是鱼类'],['否', '否', '是', '不是鱼类']
]dt = Id3DecideTree(data_set, labels_set)print(dt.tree)plotter = DecisionTreePlotter(dt.tree)
plotter.plotTree()
3.2 C4.5
C4.5算法则采用信息增益比来度量特征选择,选择信息增益比最大的特征进行分裂。
C4.5算法是ID3算法的改进版,其具体流程如下:
- 计算数据集的信息熵;
- 对每个特征,计算其信息增益比;
- 选择信息增益比最大的特征作为划分属性;
- 根据该属性的取值将数据集划分为多个子集;
- 对每个子集递归调用步骤1-4,直到所有样本属于同一类别或无法继续划分。
C4.5算法相对于ID3算法的优点在于:
-
使用信息增益比来选择最佳划分特征,避免了ID3算法中存在的偏向选择取值较多的特征的问题
-
同时处理连续的属性和离散的属性
-
处理缺少属性值的训练数据
-
创建后剪枝树
缺点有:
- C4.5 用的是多叉树,用二叉树效率更高
- C4.5 只能用于分类
- C4.5 使用的熵模型拥有大量耗时的对数运算,连续值还有排序运算
- C4.5 在构造树的过程中,对数值属性值需要按照其大小进行排序,从中选择一个分割点,所以只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时,程序无法运行
def calc_info_gain_ratio(data, feature_index):"""计算信息增益比"""base_entropy = calc_entropy(data)feature_values = [sample[feature_index] for sample in data]unique_values = set(feature_values)new_entropy = 0.0split_info = 0.0for value in unique_values:sub_data = [sample for sample in data if sample[feature_index] == value]prob = len(sub_data) / float(len(data))new_entropy += prob * calc_entropy(sub_data)split_info -= prob * math.log(prob, 2)info_gain = base_entropy - new_entropyif split_info == 0:return 0return info_gain / split_infodef choose_best_feature(data):"""选择最好的数据集划分方式"""num_features = len(data[0]) - 1base_entropy = calc_entropy(data)best_info_gain_ratio = 0.0best_feature_index = -1for i in range(num_features):info_gain_ratio = calc_info_gain_ratio(data, i)if info_gain_ratio > best_info_gain_ratio:best_info_gain_ratio = info_gain_ratiobest_feature_index = ireturn best_feature_index
3.3 CART
CART 选择基尼不纯度来度量特征选择,选择基尼不纯度最小的特征进行分裂,是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,因此CART算法生成的决策树是结构简洁的二叉树
CART算法是一种二叉决策树,其具体流程如下:
- 选择一个特征和一个阈值,将数据集划分为两个子集;
- 对每个子集递归调用步骤1,直到所有样本属于同一类别或无法继续划分。
CART算法相对于ID3算法和C4.5算法的改进在于,它使用基尼指数来选择最佳划分特征
import numpy as npclass CARTDecisionTree:def __init__(self):self.tree = {}def calc_gini(self, data):"""计算基尼指数"""label_counts = {}for sample in data:label = sample[-1]if label not in label_counts:label_counts[label] = 0label_counts[label] += 1gini = 1.0for count in label_counts.values():prob = float(count) / len(data)gini -= prob ** 2return ginidef split_data(self, data, feature_index, value):"""根据特征划分数据集"""new_data = []for sample in data:if sample[feature_index] == value:new_sample = sample[:feature_index]new_sample.extend(sample[feature_index+1:])new_data.append(new_sample)return new_datadef choose_best_feature(self, data):"""选择最佳划分特征"""num_features = len(data[0]) - 1best_gini_index = np.infbest_feature_index = -1best_split_value = Nonefor i in range(num_features):feature_values = [sample[i] for sample in data]unique_values = set(feature_values)for value in unique_values:sub_data = self.split_data(data, i, value)prob = len(sub_data) / float(len(data))gini_index = prob * self.calc_gini(sub_data)gini_index += (1 - prob) * self.calc_gini([sample for sample in data if sample[i] != value])if gini_index < best_gini_index:best_gini_index = gini_indexbest_feature_index = ibest_split_value = valuereturn best_feature_index, best_split_valuedef majority_count(self, labels):"""统计出现次数最多的类别"""label_counts = {}for label in labels:if label not in label_counts:label_counts[label] = 0label_counts[label] += 1sorted_label_counts = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)return sorted_label_counts[0][0]def create_tree(self, data, labels):"""创建决策树"""class_list = [sample[-1] for sample in data]if class_list.count(class_list[0]) == len(class_list):return class_list[0]if len(data[0]) == 1:return self.majority_count(class_list)best_feature_index, best_split_value = self.choose_best_feature(data)best_feature_label = labels[best_feature_index]tree = {best_feature_label: {}}del(labels[best_feature_index])feature_values = [sample[best_feature_index] for sample in data]unique_values = set(feature_values)for value in unique_values:sub_labels = labels[:]tree[best_feature_label][value] = self.create_tree(self.split_data(data, best_feature_index, value), sub_labels)return treedef fit(self, X_train, y_train):"""训练模型"""data_set = np.hstack((X_train, y_train.reshape(-1, 1)))labels_set=['feature_{}'.format(i) for i in range(X_train.shape[1])]labels_set.append('label')self.tree=self.create_tree(data_set.tolist(),labels_set)def predict(self,X_test):"""预测"""y_pred=[]for x_test in X_test:node=self.tree.copy()while isinstance(node,dict):feature=list(node.keys())[0]node=node[feature]feature_idx=int(feature.split('_')[-1])if x_test[feature_idx]==list(node.keys())[0]:node=node[node.keys()[0]]else:node=node[node.keys()[1]]y_pred.append(node)return np.array(y_pred)
四、案例
4.1 Iris 数据集 鸢尾花 分类
Iris数据集。这个数据集包含150个样本,每个样本有四个特征(萼片长度、萼片宽度、花瓣长度和花瓣宽度),并且每个样本都属于三个类别之一(山鸢尾、变色鸢尾或维吉尼亚鸢尾)。
直接调用 sklearn 库实现
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
import graphviziris = load_iris()X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=42)clf = DecisionTreeClassifier(criterion='entropy')
clf.fit(X_train, y_train)class_names = ['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾']
feature_names = ['萼片长度', '萼片宽度', '花瓣长度', '花瓣宽度']
dot_data = export_graphviz(clf, out_file=None, feature_names=feature_names, class_names=class_names, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('iris_decision_tree')
graph
- entropy 表示节点的信息熵
- samples表示节点拥有样本数
- value表示节点中每个类别的样本数量
- class表示节点被分类为哪个类别
上述决策树有点复杂,使用参数控制和剪枝实现优化
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
import graphviziris = load_iris()X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.3, random_state=42)# 定义参数范围
param_grid = {'max_depth': range(1, 10), 'min_samples_leaf': range(1, 10)}# 使用网格搜索找到最佳参数
grid_search = GridSearchCV(DecisionTreeClassifier(criterion='entropy'), param_grid, cv=5)
grid_search.fit(X_train, y_train)# 使用最佳参数训练模型
clf = DecisionTreeClassifier(criterion='entropy', **grid_search.best_params_)
clf.fit(X_train, y_train)# 交叉验证评估每个子树的性能
cv_scores = []
for i in range(1, clf.tree_.max_depth + 1):clf_pruned = DecisionTreeClassifier(criterion='entropy', max_depth=i)scores = cross_val_score(clf_pruned, X_train, y_train, cv=5)cv_scores.append((i, scores.mean()))# 选择最佳子树进行剪枝
best_depth = max(cv_scores, key=lambda x: x[1])[0]
clf_pruned = DecisionTreeClassifier(criterion='entropy', max_depth=best_depth)
clf_pruned.fit(X_train, y_train)class_names = ['山鸢尾', '变色鸢尾', '维吉尼亚鸢尾']
feature_names = ['萼片长度', '萼片宽度', '花瓣长度', '花瓣宽度']
dot_data = export_graphviz(clf_pruned, out_file=None, feature_names=feature_names, class_names=class_names, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('iris_decision_tree_pruned')
graph
4.2 基于决策树的英雄联盟游戏胜负预测
数据集来源:
- https://aistudio.baidu.com/aistudio/datasetdetail/168986
- https://www.kaggle.com/datasets/bobbyscience/league-of-legends-diamond-ranked-games-10-min
特征名 | 含义 |
---|---|
gameId | 游戏Id |
blueWins | 蓝色方是否胜利 |
blueWardsPlaced | 放眼数量 |
blueWardsDestroyed | 毁眼数量 |
blueFirstBlood | 是否拿到一血 |
blueKills | 击杀数 |
blueDeaths | 死亡数 |
blueAssists | 助攻数 |
blueEliteMonsters | 龙和先锋数 |
blueDragons | 小龙数 |
blueHeralds | 峡谷先锋数 |
blueTowersDestroyed | 推塔数 |
blueTotalGold | 总经济 |
blueAvgLevel | 平均等级 |
blueTotalExperience | 总经验 |
blueTotalMinionsKilled | 总补兵数 |
blueTotalJungleMinionsKilled | 野怪击杀数 |
blueGoldDiff | 经济差值 |
blueExperienceDiff | 经验差值 |
blueCSPerMin | 平均每分钟补兵数 |
blueGoldPerMin | 平均每分钟经济 |
代码参考:https://www.kaggle.com/code/xiyuewang/lol-how-to-win#Introduction
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as snsfrom sklearn import tree
from sklearn.model_selection import GridSearchCV
import graphviz# %matplotlib inline
# sns.set_style('darkgrid')df = pd.read_csv('high_diamond_ranked_10min.csv')df_clean = df.copy()# 删除冗余的列
cols = ['gameId', 'redFirstBlood', 'redKills', 'redEliteMonsters', 'redDragons','redTotalMinionsKilled','redTotalJungleMinionsKilled', 'redGoldDiff', 'redExperienceDiff', 'redCSPerMin', 'redGoldPerMin', 'redHeralds','blueGoldDiff', 'blueExperienceDiff', 'blueCSPerMin', 'blueGoldPerMin', 'blueTotalMinionsKilled']
df_clean = df_clean.drop(cols, axis = 1)# g = sns.PairGrid(data=df_clean, vars=['blueKills', 'blueAssists', 'blueWardsPlaced', 'blueTotalGold'], hue='blueWins', size=3, palette='Set1')
# g.map_diag(plt.hist)
# g.map_offdiag(plt.scatter)
# g.add_legend();# plt.figure(figsize=(16, 12))
# sns.heatmap(df_clean.drop('blueWins', axis=1).corr(), cmap='YlGnBu', annot=True, fmt='.2f', vmin=0);# 进一步抉择
cols = ['blueAvgLevel', 'redWardsPlaced', 'redWardsDestroyed', 'redDeaths', 'redAssists', 'redTowersDestroyed','redTotalExperience', 'redTotalGold', 'redAvgLevel']
df_clean = df_clean.drop(cols, axis=1)print(df_clean)# 计算与第一列的相关性,原理为计算皮尔逊相关系数,取值范围为[-1,1],可以用来衡量两个变量之间的线性相关程度。
corr_list = df_clean[df_clean.columns[1:]].apply(lambda x: x.corr(df_clean['blueWins']))cols = []
for col in corr_list.index:if (corr_list[col]>0.2 or corr_list[col]<-0.2):cols.append(col)df_clean = df_clean[cols]
# df_clean.hist(alpha = 0.7, figsize=(12,10), bins=5);from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
X = df_clean
y = df['blueWins']# scaler = MinMaxScaler()
# scaler.fit(X)
# X = scaler.transform(X)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)tree = tree.DecisionTreeClassifier(max_depth=3)# search the best params
grid = {'min_samples_split': [5, 10, 20, 50, 100]},clf_tree = GridSearchCV(tree, grid, cv=5)
clf_tree.fit(X_train, y_train)pred_tree = clf_tree.predict(X_test)# get the accuracy score
acc_tree = accuracy_score(pred_tree, y_test)
print(acc_tree)# 0,1
class_names = ['红色方胜', '蓝色方胜']
feature_names = cols
dot_data = export_graphviz(clf_tree.best_estimator_, out_file=None, feature_names=feature_names, class_names=class_names, filled=True, rounded=True, special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('lol_decision_tree')
graph
参考
- 机器学习实战
- 决策树算法中,CART与ID3、C4.5特征选择之间的区别
- Python代码:递归实现C4.5决策树生成、剪枝、分类
- https://github.com/43254022km/C4.5-Algorithm
- https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier