ID3决策树算法及其Python实现

目录

  • 一、决策树算法
  • 基础理论
  • 决策树的学习过程
  • ID3算法
  • 二、实现针对西瓜数据集的ID3算法
  • 实现代码
  • 三、C4.5和CART的算法代码实现
  • C4.5算法
  • CART算法
  • 总结
  • 参考文章
  • 一、决策树算法

    决策树是一种基于树结构来进行决策的分类算法,我们希望从给定的训练数据集学得一个模型(即决策树),用该模型对新样本分类。决策树可以非常直观展现分类的过程和结果,一旦模型构建成功,对新样本的分类效率也相当高。
    最经典的决策树算法有ID3、C4.5、CART,其中ID3算法是最早被提出的,它可以处理离散属性样本的分类,C4.5和CART算法则可以处理更加复杂的分类问题,本文重点介绍ID3算法。

    基础理论

    1. 纯度(purity)
      对于一个分支结点,如果该结点所包含的样本都属于同一类,那么它的纯度为1,而我们总是希望纯度越高越好,也就是尽可能多的样本属于同一类别。那么如何衡量“纯度”呢?由此引入“信息熵”的概念。

    2. 信息熵(information entropy)
      假定当前样本集合D中第k类样本所占的比例为pk(k=1,2,…,|y|),则D的信息熵定义为:
      Ent(D) = -∑k=1 pk·log2 pk(约定若p=0,则log2 p=0)
      显然,Ent(D)值越小,D的纯度越高。因为0<=pk<= 1,故log2 pk<=0,Ent(D)>=0. 极限情况下,考虑D中样本同属于同一类,则此时的Ent(D)值为0(取到最小值)。当D中样本都分别属于不同类别时,Ent(D)取到最大值log2 |y|.

    3. 信息增益(information gain)
      假定离散属性a有V个可能的取值{a1,a2,…,aV}. 若使用a对样本集D进行分类,则会产生V个分支结点,记Dv为第v个分支结点包含的D中所有在属性a上取值为av的样本。不同分支结点样本数不同,我们给予分支结点不同的权重:|Dv|/|D|, 该权重赋予样本数较多的分支结点更大的影响、由此,用属性a对样本集D进行划分所获得的信息增益定义为:

      Gain(D,a) = Ent(D)-∑v=1 |Dv|/|D|·Ent(Dv)

    其中,Ent(D)是数据集D划分前的信息熵,∑v=1 |Dv|/|D|·Ent(Dv)可以表示为划分后的信息熵。“前-后”的结果表明了本次划分所获得的信息熵减少量,也就是纯度的提升度。显然,Gain(D,a) 越大,获得的纯度提升越大,此次划分的效果越好。

    1. 增益率(gain ratio)
      基于信息增益的最优属性划分原则——信息增益准则,对可取值数据较多的属性有所偏好。C4.5算法使用增益率替代信息增益来选择最优划分属性,增益率定义为:

      Gain_ratio(D,a) = Gain(D,a)/IV(a)

    其中

              IV(a) = -∑v=1 |Dv|/|D|·log2 |Dv|/|D|
    

    称为属性a的固有值。属性a的可能取值数目越多(即V越大),则IV(a)的值通常会越大。这在一定程度上消除了对可取值数据较多的属性的偏好。

    事实上,增益率准则对可取值数目较少的属性有所偏好,C4.5算法并不是直接使用增益率准则,而是先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。

    1. 基尼指数(Gini index)
      CART决策树算法使用基尼指数来选择划分属性,基尼指数定义为:

             Gini(D) = ∑k=1 ∑k'≠1 pk·pk' = 1- ∑k=1  pk·pk
      

      可以这样理解基尼指数:从数据集D中随机抽取两个样本,其类别标记不一致的概率。Gini(D)越小,纯度越高。

      属性a的基尼指数定义:

            Gain_index(D,a) = ∑v=1 |Dv|/|D|·Gini(Dv)
      

      使用基尼指数选择最优划分属性,即选择使得划分后基尼指数最小的属性作为最优划分属性。

    决策树的学习过程

    一棵决策树的生成过程主要分为以下3个部分:

    特征选择:特征选择是指从训练数据中众多的特征中选择一个特征作为当前节点的分裂标准,如何选择特征有着很多不同量化评估标准标准,从而衍生出不同的决策树算法。

    决策树生成: 根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止生长。 树结构来说,递归结构是最容易理解的方式。

    剪枝:决策树容易过拟合,一般来需要剪枝,缩小树结构规模、缓解过拟合。剪枝技术有预剪枝和后剪枝两种。

    在讲解特征选择前,我们先了解一些概念。

    决策树节点的不纯度(impurity)

    ID3算法

    ID3算法是最早提出的一种决策树算法,ID3算法的核心是在决策树各个节点上应用信息增益准则来选择特征,递归的构建决策树。具体方法是:从根节点开始,对节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征,由该特征的不同取值建立子节点:再对子节点递归的调用以上方法,构建决策树:直到所有的特征信息增益均很小或没有特征可以选择为止。
    决策树是根据信息增益来进行特征选择的,信息增益定义为

    其中D为总的样本,a为属性,v为在属性a中的v类样本,信息增益越大,表明该属性对分类的相关性越大。Ent()表示信息熵(entropy),公式如下:

    k表示在样本D中的第k类样本,pk表示第k类样本所占样本总体的概率。类比于现实中的熵,可以理解为,信息熵越小,表明纯度越高。

    二、实现针对西瓜数据集的ID3算法

    watermalon.txt文件

    实现代码

    import numpy as np
    import pandas as pd
    import math
    data = pd.read_csv('./watermalon.txt')
    data
    
    

    def info(x,y):
        if x != y and x != 0:
            # 计算当前情况的熵
            return -(x/y)*math.log2(x/y) - ((y-x)/y)*math.log2((y-x)/y)
        if x == y or x == 0:
            # 纯度最大,熵值为0
            return 0
            info_D = info(8,17)
            info_D
    
    # 计算每种情况的熵
    seze_black_entropy = -(4/6)*math.log2(4/6)-(2/6)*math.log2(2/6)
    seze_green_entropy = -(3/6)*math.log2(3/6)*2
    seze_white_entropy = -(1/5)*math.log2(1/5)-(4/5)*math.log2(4/5)
    
    # 计算色泽特征色信息熵
    seze_entropy = (6/17)*seze_black_entropy+(6/17)*seze_green_entropy+(5/17)*seze_white_entropy
    print(seze_entropy)
    # 计算信息增益
    info_D - seze_entropy
    
    data.根蒂.value_counts()
    # 查看每种根蒂中好坏瓜情况的分布情况
    print(data[data.根蒂=='蜷缩'])
    print(data[data.根蒂=='稍蜷'])
    print(data[data.根蒂=='硬挺'])
    
    gendi_entropy = (8/17)*info(5,8)+(7/17)*info(3,7)+(2/17)*info(0,2)
    gain_col = info_D - gendi_entropy
    gain_col
    

    data.敲声.value_counts()
    # 查看每种敲声中好坏瓜情况的分布情况
    print(data[data.敲声=='浊响'])
    print(data[data.敲声=='沉闷'])
    print(data[data.敲声=='清脆'])
    qiaosheng_entropy = (10/17)*info(6,10)+(5/17)*info(2,5)+(2/17)*info(0,2)
    info_gain = info_D - qiaosheng_entropy
    info_gain
    

    data.纹理.value_counts()
    # 查看每种纹理中好坏瓜情况的分布情况
    print(data[data.纹理=="清晰"])
    print(data[data.纹理=="稍糊"])
    print(data[data.纹理=="模糊"])
    wenli_entropy = (9/17)*info(7,9)+(5/17)*info(1,5)+(3/17)*info(0,3)
    info_gain = info_D - wenli_entropy
    info_gain
    
    

    data.脐部.value_counts()
    # 查看每种脐部中好坏瓜情况的分布情况
    print(data[data.脐部=="凹陷"])
    print(data[data.脐部=="稍凹"])
    print(data[data.脐部=="平坦"])
    qidai_entropy = (7/17)*info(5,7)+(6/17)*info(3,6)+(4/17)*info(0,4)
    info_gain = info_D - qidai_entropy
    info_gain
    
    

    # 查看触感的值得情况
    data.触感.value_counts()
    # 查看每种脐部中好坏瓜情况的分布情况
    print(data[data.触感=="硬滑"])
    print(data[data.触感=="软粘"])
    chugna_entropy = (12/17)*info(6,12)+(5/17)*info(2,5)
    info_D - chugna_entropy
    

  • 绘制可视化树
  • import matplotlib.pylab as plt
    import matplotlib
    
    # 能够显示中文
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    matplotlib.rcParams['font.serif'] = ['SimHei']
    
    # 分叉节点,也就是决策节点
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    
    # 叶子节点
    leafNode = dict(boxstyle="round4", fc="0.8")
    
    # 箭头样式
    arrow_args = dict(arrowstyle="<-")
    
    
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        """
        绘制一个节点
        :param nodeTxt: 描述该节点的文本信息
        :param centerPt: 文本的坐标
        :param parentPt: 点的坐标,这里也是指父节点的坐标
        :param nodeType: 节点类型,分为叶子节点和决策节点
        :return:
        """
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    
    
    def getNumLeafs(myTree):
        """
        获取叶节点的数目
        :param myTree:
        :return:
        """
        # 统计叶子节点的总数
        numLeafs = 0
    
        # 得到当前第一个key,也就是根节点
        firstStr = list(myTree.keys())[0]
    
        # 得到第一个key对应的内容
        secondDict = myTree[firstStr]
    
        # 递归遍历叶子节点
        for key in secondDict.keys():
            # 如果key对应的是一个字典,就递归调用
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])
            # 不是的话,说明此时是一个叶子节点
            else:
                numLeafs += 1
        return numLeafs
    
    
    def getTreeDepth(myTree):
        """
        得到数的深度层数
        :param myTree:
        :return:
        """
        # 用来保存最大层数
        maxDepth = 0
    
        # 得到根节点
        firstStr = list(myTree.keys())[0]
    
        # 得到key对应的内容
        secondDic = myTree[firstStr]
    
        # 遍历所有子节点
        for key in secondDic.keys():
            # 如果该节点是字典,就递归调用
            if type(secondDic[key]).__name__ == 'dict':
                # 子节点的深度加1
                thisDepth = 1 + getTreeDepth(secondDic[key])
    
            # 说明此时是叶子节点
            else:
                thisDepth = 1
    
            # 替换最大层数
            if thisDepth > maxDepth:
                maxDepth = thisDepth
    
        return maxDepth
    
    
    def plotMidText(cntrPt, parentPt, txtString):
        """
        计算出父节点和子节点的中间位置,填充信息
        :param cntrPt: 子节点坐标
        :param parentPt: 父节点坐标
        :param txtString: 填充的文本信息
        :return:
        """
        # 计算x轴的中间位置
        xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
        # 计算y轴的中间位置
        yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
        # 进行绘制
        createPlot.ax1.text(xMid, yMid, txtString)
    
    
    def plotTree(myTree, parentPt, nodeTxt):
        """
        绘制出树的所有节点,递归绘制
        :param myTree: 树
        :param parentPt: 父节点的坐标
        :param nodeTxt: 节点的文本信息
        :return:
        """
        # 计算叶子节点数
        numLeafs = getNumLeafs(myTree=myTree)
    
        # 计算树的深度
        depth = getTreeDepth(myTree=myTree)
    
        # 得到根节点的信息内容
        firstStr = list(myTree.keys())[0]
    
        # 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    
        # 绘制该节点与父节点的联系
        plotMidText(cntrPt, parentPt, nodeTxt)
    
        # 绘制该节点
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
    
        # 得到当前根节点对应的子树
        secondDict = myTree[firstStr]
    
        # 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    
        # 循环遍历所有的key
        for key in secondDict.keys():
            # 如果当前的key是字典的话,代表还有子树,则递归遍历
            if isinstance(secondDict[key], dict):
                plotTree(secondDict[key], cntrPt, str(key))
            else:
                # 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                # 打开注释可以观察叶子节点的坐标变化
                # print((plotTree.xOff, plotTree.yOff), secondDict[key])
                # 绘制叶子节点
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                # 绘制叶子节点和父节点的中间连线内容
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    
        # 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    
    def createPlot(inTree):
        """
        需要绘制的决策树
        :param inTree: 决策树字典
        :return:
        """
        # 创建一个图像
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        # 计算出决策树的总宽度
        plotTree.totalW = float(getNumLeafs(inTree))
        # 计算出决策树的总深度
        plotTree.totalD = float(getTreeDepth(inTree))
        # 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
        plotTree.xOff = -0.5/plotTree.totalW
        # 初始的y轴偏移量,每次向下或者向上移动1/D
        plotTree.yOff = 1.0
        # 调用函数进行绘制节点图像
        plotTree(inTree, (0.5, 1.0), '')
        # 绘制
        plt.show()
    
    
    if __name__ == '__main__':
        createPlot(mytree)
    
    

    三、C4.5和CART的算法代码实现

    C4.5算法

  • C4.5是决策树算法的一种。决策树算法作为一种分类算法,目标就是将具有p维特征的n个样本分到c个类别中去。相当于做一个投影,c=f(n),将样本经过一种变换赋予一种类别标签。决策树为了达到这一目的,可以把分类的过程表示成一棵树,每次通过选择一个特征pi来进行分叉。
    信息增益比
  • 训练数据关于特征的信息增益为,关于的熵为,,则关于的信息增益比为信息增益为与熵的比值:

    代码:

    def calcGainRatio(dataSet, labelIndex, labelPropertyi):
        """
        type: (list, int, int) -> float, int
        计算信息增益率,返回信息增益率和连续属性的划分点
        dataSet: 数据集
        labelIndex: 特征值索引
        labelPropertyi: 特征值类型,0为离散,1为连续
        """
        baseEntropy = calcShannonEnt(dataSet, labelIndex)  # 计算根节点的信息熵
        featList = [example[labelIndex] for example in dataSet]  # 特征值列表
        uniqueVals = set(featList)  # 该特征包含的所有值
        newEntropy = 0.0
        bestPartValuei = None
        IV = 0.0
        totalWeight = 0.0
        totalWeightV = 0.0
        totalWeight = calcTotalWeight(dataSet, labelIndex, True)  # 总样本权重
        totalWeightV = calcTotalWeight(dataSet, labelIndex, False)  # 非空样本权重
        if labelPropertyi == 0:  # 对离散的特征
            for value in uniqueVals:  # 对每个特征值,划分数据集, 计算各子集的信息熵
                subDataSet = splitDataSet(dataSet, labelIndex, value)
                totalWeightSub = 0.0
                totalWeightSub = calcTotalWeight(subDataSet, labelIndex, True)
                if value != 'N':
                    prob = totalWeightSub / totalWeightV
                    newEntropy += prob * calcShannonEnt(subDataSet, labelIndex)
                prob1 = totalWeightSub / totalWeight
                IV -= prob1 * log(prob1, 2)
        else:  # 对连续的特征
            uniqueValsList = list(uniqueVals)
            if 'N' in uniqueValsList:
                uniqueValsList.remove('N')
                # 计算空值样本的总权重,用于计算IV
                totalWeightN = 0.0
                dataSetNull = splitDataSet(dataSet, labelIndex, 'N')
                totalWeightN = calcTotalWeight(dataSetNull, labelIndex, True)
                probNull = totalWeightN / totalWeight
                if probNull > 0.0:
                    IV += -1 * probNull * log(probNull, 2)
     
            sortedUniqueVals = sorted(uniqueValsList)  # 对特征值排序
            listPartition = []
            minEntropy = inf
     
            if len(sortedUniqueVals) == 1:  # 如果只有一个值,可以看作只有左子集,没有右子集
                totalWeightLeft = calcTotalWeight(dataSet, labelIndex, True)
                probLeft = totalWeightLeft / totalWeightV
                minEntropy = probLeft * calcShannonEnt(dataSet, labelIndex)
                IV = -1 * probLeft * log(probLeft, 2)
            else:
                for j in range(len(sortedUniqueVals) - 1):  # 计算划分点
                    partValue = (float(sortedUniqueVals[j]) + float(
                        sortedUniqueVals[j + 1])) / 2
                    # 对每个划分点,计算信息熵
                    dataSetLeft = splitDataSet(dataSet, labelIndex, partValue, 'L')
                    dataSetRight = splitDataSet(dataSet, labelIndex, partValue, 'R')
                    totalWeightLeft = 0.0
                    totalWeightLeft = calcTotalWeight(dataSetLeft, labelIndex, True)
                    totalWeightRight = 0.0
                    totalWeightRight = calcTotalWeight(dataSetRight, labelIndex, True)
                    probLeft = totalWeightLeft / totalWeightV
                    probRight = totalWeightRight / totalWeightV
                    Entropy = probLeft * calcShannonEnt(
                        dataSetLeft, labelIndex) + probRight * calcShannonEnt(dataSetRight, labelIndex)
                    if Entropy < minEntropy:  # 取最小的信息熵
                        minEntropy = Entropy
                        bestPartValuei = partValue
                        probLeft1 = totalWeightLeft / totalWeight
                        probRight1 = totalWeightRight / totalWeight
                        IV = -1 * (probLeft * log(probLeft, 2) + probRight * log(probRight, 2))
     
            newEntropy = minEntropy
        gain = totalWeightV / totalWeight * (baseEntropy - newEntropy)
        if IV == 0.0:  # 如果属性只有一个值,IV为0,为避免除数为0,给个很小的值
            IV = 0.0000000001
        gainRatio = gain / IV
        return gainRatio, bestPartValuei
    

    CART算法

    CART算法构造的是二叉决策树,决策树构造出来后同样需要剪枝,才能更好的应用于未知数据的分类。CART算法在构造决策树时通过基尼系数来进行特征选择。

    def entropy(data):
        length = data.size
        ent = 0
        for i in data.value_counts(): #查看表格某列中有多少个不同值
            prob = i / length
            ent += - prob * (np.log2(prob))
        return ent
    print('--------')
    entD = entropy(data['好瓜'])
    print(entD) # 0.9975025463691153
    
    #计算Gini指数
    
    
    def gini_discrete(data, input_column, output_column): #离散
       
        ret = 0
        lens = data[output_column].size
        all_attribute = data[input_column].value_counts()  # 保存特征全部属性的取值个数
        for name in data[input_column].unique(): # 特征的不同属性名
            print(name)
            temp = 1
            for i in range(len(data[output_column].unique())):  # 好瓜 or 坏瓜
                attribute_num = data[input_column].where(data[output_column] == data[output_column].unique()[i]).value_counts()
                try:
                    prob = int(attribute_num[name]) / int(all_attribute[name])
                except:
                    prob = 0
                if prob == 0:
                    temp += 0
                else:
                    temp -= prob * prob
                # 还需要乘以该属性出现的概率
            ret += temp * (all_attribute[name] / lens)
        return ret
    
    def gini_continuous(data, input_column, output_column): #连续
      
        lens = data[output_column].size
        gini = 0
        T = []
        Gini = [] #用来寻找最小的gini_index
        values = sorted(data[input_column].values)
        for i in range(lens - 1):
            good_n = 0
            good_p = 0
            bad_n = 0
            bad_p = 0
            t = round(((values[i] + values[i+1]) / 2), 3)
            T.append(t)
            for index in data.index:
                if data[input_column].values[index-1] < t:
                    if data['好瓜'].values[index-1] == '是' :
                        good_n += 1
                    else:
                        bad_n += 1
                else:
                    if data['好瓜'].values[index-1] == '是' :
                        good_p += 1
                    else:
                        bad_p += 1
            
            dn_sum = i + 1 #小于候选划分总和
            dp_sum = lens - i - 1 #大于候选划分总和
            prob = dn_sum / lens
            gini_n = 1 - (np.square(good_n / dn_sum) + np.square(bad_n / dn_sum))
            gini_p = 1 - (np.square(good_p / dp_sum) + np.square(bad_p / dp_sum))
            gini = prob * gini_n + (1 - prob) * gini_p
            Gini.append(gini)
            
        print("对应划分点为:",T[Gini.index(min(Gini))])
        return T[Gini.index(min(Gini))], min(Gini)
    
    
    max_n_features = 4 #控制树的深度
    
    decisionTree = {}
    
    def transLabel(label):
        if label == '是':
            return '好瓜'
        else:
            return '坏瓜'
    
    def createTree(data, features):
        """
        data: input the name of DataFrame
        features: input the list of features
        """
        bestFeaIndex, bestFeatureName, at = chooseBestFeature(data) 
        bestFeatureValue = data[bestFeatureName].values
        attrCount = 0
        attr = []
        
        sameLvTree = {}
        tempTree = {}
        
        DataGood = pd.DataFrame()
        DataBad = pd.DataFrame()
        
        features.remove(bestFeatureName)
        
        if bestFeatureName != '密度' and '含糖率':      
            for name in data[bestFeatureName].unique():
                attrCount += 1
                attr.append(name)
            for i in range(attrCount):
                dataName='DataSubSet'+str(i) #根据属性的取值个数动态生成子集
                locals()['DataSubSet'+str(i)] = pd.DataFrame()
            for index in data.index:
                for i in range(attrCount):
                    if data[bestFeatureName].values[index-1] == attr[i]:
                        locals()['DataSubSet'+str(i)] = locals()['DataSubSet'+str(i)].append(data[index-1:index], sort=False)
                        break
            
            outputCount = 0
            for i in range(attrCount):
                print()
                print(attr[i])
                locals()['DataSubSet'+str(i)] = locals()['DataSubSet'+str(i)].drop(columns = [bestFeatureName])
                locals()['DataSubSet'+str(i)] = locals()['DataSubSet'+str(i)].reset_index(drop=True)
                locals()['DataSubSet'+str(i)].index += 1
                print(locals()['DataSubSet'+str(i)])
                print()
                for name in locals()['DataSubSet'+str(i)]['好瓜'].unique():
                    outputCount += 1
                if outputCount == 1:
                    print("*******",bestFeatureName, attr[i], locals()['DataSubSet'+str(i)]['好瓜'].values[0])
                    sameLvTree[attr[i]] = transLabel(locals()['DataSubSet'+str(i)]['好瓜'].values[0])
                    tempTree[bestFeatureName] = sameLvTree
                    outputCount = 0
                else:
                    print("*******",bestFeatureName, attr[i], '?')
                    outputCount = 0
                    if len(features) > max_n_features:
                        sameLvTree[attr[i]] = createTree(locals()['DataSubSet'+str(i)], features)
                        print(sameLvTree[attr[i]])
                        tempTree[bestFeatureName] = sameLvTree
        
        else:
            
            DataN = pd.DataFrame()
            DataP = pd.DataFrame()
            for index in data.index:
                if data[bestFeatureName].values[index-1] < at:
                    DataN = DataN.append(data[index-1:index], sort=False)
                else:
                    DataP = DataP.append(data[index-1:index], sort=False)
            
            outputCount = 0
            
            print()
            print('<=', at)
            DataN = DataN.drop(columns = [bestFeatureName])
            DataN = DataN.reset_index(drop=True)
            DataN.index += 1
            print(DataN)
            print()
            for name in DataN['好瓜'].unique():
                outputCount += 1
            if outputCount == 1:
                print("*******",bestFeatureName, '<={}'.format(at), DataN['好瓜'].values[0])
                sameLvTree['<={}'.format(at)] = transLabel(DataN['好瓜'].values[0])
                tempTree[bestFeatureName] = sameLvTree
                print(tempTree)
                outputCount = 0
            else:
                print("*******",bestFeatureName, '<={}'.format(at), '?')
                outputCount = 0
                if len(features) > max_n_features:
                    sameLvTree['<={}'.format(at)] = createTree(DataN, features)
                    tempTree[bestFeatureName] = sameLvTree
                    print(tempTree)
            
            print()
            print('>', at)
            DataP = DataP.drop(columns = [bestFeatureName])
            DataP = DataP.reset_index(drop=True)
            DataP.index += 1
            print(DataP)
            print()
            for name in DataP['好瓜'].unique():
                outputCount += 1
            if outputCount == 1:
                print("*******",bestFeatureName, '>{}'.format(at), DataP['好瓜'].values[0])
                print(bestFeatureName)
                sameLvTree['>{}'.format(at)] = transLabel(DataP['好瓜'].values[0])
                tempTree[bestFeatureName] = sameLvTree
                outputCount = 0
            else:
                print("*******",bestFeatureName, '>{}'.format(at), '?')
                outputCount = 0
                if len(features) > max_n_features:
                    sameLvTree['>{}'.format(at)] = createTree(DataP, features)
                    tempTree[bestFeatureName] = sameLvTree
            
        return tempTree
    
    
    features = list(data.columns[0:-1]) # x的表头
    decisionTree = createTree(data, features)
    print(decisionTree)
    
    

    总结

    决策树中ID3算法中的信息增益准则对取值数目较多的属性有所偏好,而C4.5算法的基本流程与ID3类似,但C4.5算法进行特征选择时不是通过计算信息增益完成的,而是通过信息增益比来进行特征选择。CART算法构造的是二叉决策树,决策树构造出来后同样需要剪枝,才能更好的应用于未知数据的分类。CART算法在构造决策树时通过基尼系数来进行特征选择。

    参考文章

    机器学习算法(3)之决策树算法
    决策树算法(ID3算法详解)

    【机器学习】CART决策树原理及python实现

    物联沃分享整理
    物联沃-IOTWORD物联网 » ID3决策树算法及其Python实现

    发表评论