大家好,这篇文章我们探讨下,决策树算法的相关的知识,决策树是一种分类算法,现在也可以应用与回归,决策树算法的实现有很多种,你可以写Python 代码,也可以调用现成的sklearn包实现!
本文,主要包括三个部分:
1.第一部分,通过图文的方式介绍决策树算法的基本原理
2.第二部分,通过简单的例子,用Python 代码实现一个分类的问题
3.第三部分,调用sklearn包实现决策树算法
Part1: 决策树算法
决策数据算法是通过一系列,精心构建的问题来,通过是或和否的方式,直到找到记录所属的类。现在假设我们发现一个新物种,我们如何如何确定它的科属了,可以通过下面的例子进行说明。
决策构建的基本思路是,根据属性不断的对集合进行划分,直到集合中所有的元素都属于同一中类型或者达到指定的样本数量
决策树构建算法1-Hunt:
基本思路:1.如果集合Ct中的的元素所属类型一致,则节点t为叶节点,用yt进行标记
2.如果Dt中包含多个类的记录,则选择一个属性测试条件(体温,胎生),
把记录划分为较小的集合,然后对集合递归的调用这个算法(见上图)。
需要解决的问题:
1.如何确定合适的属性进行划分(体温or胎生)?
2.什么时候停止分裂(一直分裂会导致树变得很复杂)?
1.哪个属性更好
确定属性的常用方法有三种,详见下图
2.什么时候停止分裂
(待续……)
决策树计算的伪代码
检测数据集中的每个子项是否属于同一分类:
If so return 类标签
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch并增加返回结果到分支节点中
return 分支节点
决策树构建算法2: ID3,C4.5,CART
Part2:python 代码的实现
from math import log
import operator
#预定义数据集合
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
return dataSet, labels
#定义函数返回集合的熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
#❶ (以下五行)为所有可能分类创建字典
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #❷ 以2为底求对数
return shannonEnt
#数据集划分函数
def splitDataSet(dataSet, axis, value): #待划分的数据集合,划分数据集合的特征,需要返回的特征值
#创建新的list对象
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
#(以下三行)抽取
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
#选取数据集合的最佳划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):
#❶ (以下两行)创建唯一的分类标签列表
featList = [example[i] for example in dataSet]
# print(1,featList)
uniqueVals = set(featList) #对取值去重
# print(2,uniqueVals)
newEntropy = 0.0
#❷ (以下五行)计算每种划分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
#❸ 计算最好的信息增益
bestInfoGain = infoGain
bestFeature = i
return bestFeature
#返回出现次数最多的分类标签
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): #如果不存就添加
classCount[vote] = 0
else:
classCount[vote] += 1 # 存在就+1
sortedClassCount=sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0] #返回出现次数最多的类标签
print("majorityCnt",majorityCnt([1,2,2,3,1,1,3]))
#生成决策树
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
#类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历完所有特征时返回出现次数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #返回最佳特征
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}} #生成根节点
# 得到列表包含的所有属性值
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet
(dataSet, bestFeat, value),subLabels)
return myTree
3
#读取数据
myDat,labels2=createDataSet()
print(1,myDat,labels2)
#计算集合的熵
# print("计算集合的熵:\n",calcShannonEnt(myDat))
#返回划分后的数据集合
# print(splitDataSet(myDat,0,0)) #集合,特征,特征取值
# print(splitDataSet(myDat,0,1))
#选择最佳划分特征
# print(chooseBestFeatureToSplit(myDat))
#生成决策树
myTree = createTree(myDat,labels2)
print(myTree)
#利用决策树,对新的数据进行预测
def classify(inputTree,featLabels,testVec): #树,可划分的属性列表
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
#将标签字符串转换为索引
print(featLabels)
featIndex = featLabels.index(str(firstStr))
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__=='dict':
classLabel = classify(secondDict[key],featLabels,testVec)
else: classLabel = secondDict[key]
return classLabel
#对测试数据进行划分
print(2,myDat,labels2)
print("划分结果\n",classify(myTree,['no surfacing', 'flippers'],[1,1])) #有问题