机器学习实战(二)-决策树的构造和绘制

一、前言

你是否玩过二十个问题的游戏?游戏规则很简单:参与游戏的一方在脑海中想某个事物,其他参与者向其提问,只允许提二十个问题,而且问题的答案只能用对和错来回答。问问题的人通过答案进行推断,逐步缩小猜测事物的范围。决策树原理与二十个问题类似,用户给出一次列输入数据,然后分类器给出答案。

决策树流程图

图 3-1 的流程图就是一个决策树,其中正方形表示判断模块,椭圆形表示终止模块。

二、决策树构造原理

在构造决策树时,我们首先需要考虑的问题就是哪个特征在划分数据集时起决定性作用。为了找到决定性特征,划分出最佳结果,我们需要对每个特征进行评估。完成测试后,原始数据集就会被分成几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支上的数据已经属于某一类别,那么该数据子集已经被正确分类,无需进一步划分,反之则需要进一步划分,直到所有相同类型的数据在同一个数据集中为止。

一些决策树采用二分法划分数据,但本文将使用 ID3 算法划分数据集。每次划分数据集只需选择一个特征,但如果某个数据集中有 20 个特征,应该如何选取?为了解决这个问题,我们必须采用量化的方法判断如何划分数据,这需要用到信息论的知识。

三、决策树先导知识

1. 信息增益

划分数据集的大原则是:将无序的数据变得有序。组织杂乱无章的数据的一种方法就是使用信息论度量信息,我们可以在划分数据前使用信息论度量信息的内容。

在划分数据集前后,信息发生的变化叫做信息增益,知道如何计算信息增益,我们就可以计算每个特征划分数据集后的信息增益,然后选择信息增益最大的特征。因此,为了解决这个问题,我们必须学习如何计算信息增益。

香农熵(熵)

集合信息的度量方式。熵是对混乱的度量。

熵定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事物可能划分在多个分类中,则符号 xi 的信息定义为:


xi 的信息定义

其中
pxi

是选择该分类的概率。

为了计算熵,我们需要所有类别所有可能值包含的信息期望值,可以通过如下公式得到:

信息熵

其中 n 是分类的数目。

计算给定数据集的香农熵

from math import log
import operator
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import LabelEncoder

'''
函数说明:构造数据集
参数:无
返回值:新构造的数据集
'''
def createDataSet():
    dataSet = [[1,1,'yes'],
                [1,1,'yes'],
                [1,0,'no'],
                [0,1,'no'],
                [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels

dataSet,labels = createDataSet()
# print(labels)

'''
函数说明:计算给定数据集的香农熵
参数:dataSet -- 数据集
返回值:shannonEnt -- 数据集的熵
'''
def calcShannonEnt(dataset):
    totalNum = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():          
            labelCounts[currentLabel] = 1
        else:
            labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / totalNum
        shannonEnt -= prob * log(prob,2)
    return shannonEnt
dataSet,labels = createDataSet()
# print(labels)
shannonEnt = calcShannonEnt(dataSet)
# print(dataSet)
# print(shannonEnt)

执行结果:

数据集的信息熵

熵越高,则混合的数据越多。我们可以在数据集中增加分类,观察熵的变化。

dataSet[0][-1] = 'maybe'

执行结果为:


分类越多熵越高

得到熵之后,我们就可以根据最大信息增益来划分数据集。

2. 根据给定特征划分数据集

上一节我们学习如何度量数据的无序程度,也即是熵的计算,接下来我们将对每个特征划分得到的数据集进行信息熵计算,然后找出信息增益最大的划分方式所对应的特征。

按照给定的特征划分数据集

'''
函数说明:按照给定特征划分数据集
参数:dataSet -- 待划分的数据集
      axis -- 划分数据集的特征
      value -- 特征的返回值
返回值:retDataSet -- 划分好的数据集
'''
def splitDataSet(dataset,axis,value):
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == value:
            #把划分数据集的特征从数据集中剔除,
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# print(splitDataSet(dataSet,0,1))
# print(splitDataSet(dataSet,0,0))

执行结果:


按照给定特征划分

目前我们只是根据给定的特征实现了数据集的划分,接下来我们需要遍历整个数据集,对每个特征进行划分并计算信息增益,找到最佳分类特征。

3. 寻找最佳划分特征

'''
函数说明:选择最佳划分特征
参数:dataset -- 待划分的数据集
返回值:最佳划分特征
'''
def chooseBestFeatureToSplit(dataset):
    # 数据集特征数
    numFeatures = len(dataset[0]) - 1
    # 原始数据的信息熵
    bestEntropy = calcShannonEnt(dataSet)
    #最大信息增益
    bestInfoGain = 0.0
    # 最佳划分特征
    bestFeature = -1
    for i in range(numFeatures):
        featureList = [example[i] for example in dataset]
        # 特征对应的取值
        uniqueVals = set(featureList)
        newEntropy = 0.0
        for value in uniqueVals:
            # 遍历划分每个特征的每个取值
            subDataSet = splitDataSet(dataset,i,value)
            prob = len(subDataSet) / float(len(dataSet))
            # 新数据集的信息熵
            newEntropy += prob * calcShannonEnt(subDataSet)
        # 信息增益
        infoGain = bestEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
# print(dataSet)
# print('最佳划分特征:特征',chooseBestFeatureToSplit(dataSet))

执行结果:


最佳划分特征

代码运行结果显示,第 0 个特征是划分数据集的最佳特征。也即是第一个特征取值为 1 的是一组,取值为 0 的是一组。可以看到,取值为 1 的分组中有两个是鱼类,有一个是非鱼类;取值为 0 的分组中,两个都是非鱼类。如果按照第二个特征来分类,则取值为 1 的分组中,有两个是鱼类,两个是非鱼类;取值为 0 的分组中,只有一个非鱼类。

4. 递归构造决策树

目前我们已经得到从数据集构造决策树算法所需要的子功能模块,其工作原理为:得到原始数据集,然后基于最佳分类特征对数据集进行划分,由于特征值可能不止两个,所以可能存在大于两个分支的数据集划分。第一次划分之后,数据被向下传递到数分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则来处理数据集。

递归结束的条件:当程序遍历完所有划分数据集的属性或所有分支下的实例都属于同一个分类。如果所有实例具有相同的分类,则得到一个叶子节点或者说终止块。任何到达叶子节点的数据必定属于叶子节点的分类。

由于特征数目并不一定每次划分都会减少,因此这些算法在实际使用时可能会出现一些问题。不过目前我们并不需要考虑这个问题,我们只需要在算法开始运行前计算出所有列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但分类标签并不唯一的情况下,我们通常会采用多数表决的方法决定该叶子节点的分类。
在树顶(根节点)处熵最高,逐层降低,直到数据被划分为各自的类型。

'''
函数说明:在叶子节点类别标签不唯一的情况下,采用多数表决的方法决定该叶子节点的分类
参数:classList -- 数据集所有分类标签
返回值:叶子节点中出现次数最多的分类标签
'''
def majorityCnt(classList):
    #初始化字典用于记录各分类标签及其出现的次数
    classCount = {}
    # 遍历数据集中的所有分类标签,并统计其出现的次数
    for vote in classList:
        if vote not in classList.keys():
            classCount[vote] = 0
        else:
            classCount[vote] += 1
    # 对分类标签出现的次数进行排序
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

'''
函数说明:创建树的函数代码,构造决策树
参数:dataSet -- 数据集
      labels -- 数据集所有分类标签
返回值:根据数据集构造得到的决策树
'''
def createTree(dataSet,labels):
    # 得到数据集的所有类别标签
    classList = [example[-1] for example in dataSet]
    # 如果数据集中的所有数据都属于同一类别,则停止划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #当遍历完所有特征时(即数据集的列数为1时),仍然不能将数据集划分成分类标签唯一的分组,则返回出现次数最多的分类标签
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # 由当前最佳分类特征组成的决策树
    myTree = {bestFeatLabel:{}}
    # 当特征被添加到决策树中后,”就将其从分类标签中删除
    del(labels[bestFeat])
    # 最佳分类特征的取值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    # 删除最佳分类特征之后的所有分类标签
    for value  in uniqueVals:
        subLabels = labels[:]
        # 递归得到当前数据集的最佳分类特征,并构造决策树
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
mytree = createTree(dataSet,labels)
# print(mytree)

执行结果为:

决策树

变量 mytree 包含了很多代表树结构信息的嵌套字典,第一个关键字 no surfacing 是第一个最佳数据集划分特征,该关键字的值又是另一个数据字典。关键字的值可能是分类标签,也有可能是另一个数据字典。如果值时分类标签,则该子节点就是叶子节点;否则该节点就是另一个判断节点,这种格式不断重复就构成了整棵树,也就是我们刚刚看到的执行结果。但是以字典的方式呈现决策树非常不便于理解,因此,我们接下来将绘制决策树,以更加直观的方式来观察决策树。

5. 在 Python 中使用 matplotlib 注解绘制决策树

'''
函数说明:使用matplotlib绘制树节点
参数:nodeText -- 节点注解
      centerPt -- 子节点
      parentPt -- 父节点
      nodeType -- 节点类型
返回值: 无

'''
# 定义判断节点、子节点以及箭头的格式
decisionNode = dict(boxstyle = "sawtooth",fc = "0.8")
leafNode = dict(boxstyle = "round4",fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

def plotNode(nodeText,centerPt,parentPt,nodeType):
    # 绘制带箭头的注解
    createPlot.ax1.annotate(nodeText,xy=parentPt,xycoords = 'axes fraction',xytext = centerPt,textcoords = 'axes fraction',va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args)

def createPlot():
    fig = plt.figure(1,facecolor = 'white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon = False)
    # 绘制判断节点
    plotNode('a decison node',(0.5,0.1),(0.1,0.5),decisionNode)
    # 绘制子节点
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

# createPlot()

绘制树节点:

绘制树节点

现在我们已经学会了树节点的绘制,接下来我们将要对整棵树进行绘制。

绘制一颗完整的数需要一些技巧。虽然我们有 x , y 坐标,但如何放置所有的树节点是个问题。我们必须知道有多少个叶节点,以便我们能确定 x 轴的长度;我们还需要知道数有多少层,以便确定 y 轴的长度。因此这里我们定义了两个函数 getNumLeafs() 和 getTreeDepth() 来获取叶子节点的数目以及数的高度。

'''
函数说明:获取叶子节点的数目以及数的高度
参数:myTree -- 决策树
返回值: numLeafs -- 叶子节点数
        maxDepth -- 数的高度
'''
def getNumLeafs(myTree):
    #初始化叶子节点数
    numLeafs = 0
    #得到第一个关键字
    firstStr = list(myTree.keys())[0]
    # 得到第一个关键字的值,该值又是一个字典
    secondDict = myTree[firstStr]
    # 遍历第二个字典的值,判断是字典还是叶子节点
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 递归遍历
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

numLeafs = getNumLeafs(mytree)
# print('width of tree:',numLeafs)

def getTreeDepth(myTree):
    # 初始化数的高度
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

treeDepth = getTreeDepth(mytree)
# print('depth of tree:',treeDepth)

计算得到数的宽度和高度分别为:

计算数的宽度和高度
'''
函数说明:在父子节点之间填充文本信息
参数:cntrPt -- 子节点
      parentPt -- 父节点
      txtString -- 文本信息
返回值:无
'''
def plotMidText(cntrPt,parentPt,txtString):
    # 文本信息的横坐标
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

'''
函数说明:绘制的具体步骤
参数:mytree -- 决策树
      parentPt -- 父节点
      nodeTxt -- 节点信息
返回值:无
'''
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # 决策树的第一个关键字
    firstStr = list(myTree.keys())[0]
    # 计算子节点的位置
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,plotTree.yOff)
    # 在父子节点之间填充文本信息
    plotMidText(cntrPt,parentPt,nodeTxt)
    # 绘制节点
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    # 递归绘制决策树
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff =  plotTree.xOff + 1 / plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


'''
函数说明:将决策树以图形的形式绘制出来
参数:inTree -- 决策树
返回值:无
'''
def createPlot(inTree):
    fig = plt.figure(1,facecolor = 'white')
    fig.clf()
    #将xy坐标存放于一个字典内,不过此时的xy坐标值为空
    axprops = dict(xticks=[],ytick=[])
    createPlot.ax1 = plt.subplot(111,frameon=False)
    # 数的宽度,用于计算判断节点的位置(应该在水平方向和垂直方向的中心位置)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    #例如有三个叶子节点,那么它们将x轴平分为三等分,坐标依次为1/3,2/3,3/3,但此时整个图像靠右,并不在画布的中心,因此将其向左移
    # plotTree.xOff、plotTree.yOff用于追踪已绘制节点的位置,以及放置下一个节点的位置
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()


# createPlot(mytree)

绘制得到的决策树如下:

绘制得到的决策树

到目前为止,我们已经学习而了如何构造决策树以及绘制树形图的方法,下一节我们将实际使用这些方法,并从算法和数据中得到新知识。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,362评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,330评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,247评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,560评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,580评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,569评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,929评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,587评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,840评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,596评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,678评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,366评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,945评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,929评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,165评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,271评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,403评论 2 342

推荐阅读更多精彩内容