决策树实验

实验四 决策树实验

一、实验目的

1.认识决策树的构建过程;

2.分析决策树可视化的运算;

3.掌握决策树的应用。

二、实验内容

​ 决策树是什么?决策树(decision tree)是一种基本的分类与回归方法。举个通俗易懂的例子,如下图所示的流程图就是一个决策树,长方形代表判断模块,椭圆形成代表终止模块,表示已经得出结论,可以终止运行。从判断模块引出的左右箭头称作为分支,它可以达到另一个判断模块或者终止模块。我们还可以这样理解,分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点和有向边组成。结点有两种类型:内部结点和叶结点。内部结点表示一个特征或属性,叶结点表示一个类。

image-20241209145103826

​ 使用决策树做预测需要以下过程:

  1. 收集数据:可以使用任何方法。比如想构建一个系统,我们可以从亲朋好友那里获取数据。根据他们考虑的因素和最终的选择结果,就可以得到一些供我们利用的数据了。

  2. 准备数据:收集完的数据,我们要进行整理,将这些所有收集的信息按照一定规则整理出来,并排版,方便我们进行后续处理。

  3. 分析数据:可以使用任何方法,决策树构造完成之后,我们可以检查决策树图形是否符合预期。

  4. 训练算法:这个过程也就是构造决策树,同样也可以说是决策树学习,就是构造一个决策树的数据结构。

  5. 测试算法:使用经验树计算错误率。当错误率达到了可接收范围,这个决策树就可以投放使用了。

  6. 使用算法:此步骤可以使用适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

三、实验材料与工具

1.电脑一台;

2.Pycharm。

四、实验步骤

1、分析以下数据

image-20241209145712953

​ 特征选择

​ 特征选择在于选取对训练数据具有分类能力的特征。这样可以提高决策树学习的效率,如果利用一个特征进行分类的结果与随机分类的结果没有很大差别,则称这个特征没有分类能力。扔掉这样的特征对决策树学习的精度影响不大。通常特征选择的标准是信息增益(information gain)或信息增益比,本次实验使用信息增益作为选择特征的标准。那么,什么是信息增益?在讲解信息增益之前,让我们看一组实例,贷款申请样本数据表。

​ 希望通过所给的训练数据学习一个贷款申请的决策树,用于对未来的贷款申请进行分类,即当新的客户提出贷款申请时,根据申请人的特征利用决策树决定是否批准贷款申请。

​ 特征选择就是决定用哪个特征来划分特征空间。比如,我们通过上述数据表得到两个可能的决策树,分别由两个不同特征的根结点构成。

image-20241209145748807

​ 左图所示的根结点的特征是年龄,有3个取值,对应于不同的取值有不同的子结点。右图所示的根节点的特征是工作,有2个取值,对应于不同的取值有不同的子结点。两个决策树都可以从此延续下去。问题是:究竟选择哪个特征更好些?这就要求确定选择特征的准则。直观上,如果一个特征具有更好的分类能力,或者说,按照这一特征将训练数据集分割成子集,使得各个子集在当前条件下有最好的分类,那么就更应该选择这个特征。信息增益就能够很好地表示这一直观的准则。

​ 什么是信息增益呢?在划分数据集之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

2、计算香农熵

​ 在可以评测哪个数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵(entropy),这个名字来源于信息论之父克劳德·香农。

​ 熵定义为信息的期望值。在信息论与概率统计中,熵是表示随机变量不确定性的度量。如果待分类的事物可能划分在多个分类之中,则符号xi的信息定义为 :

image-20241209145827471

​ 其中p(xi)是选择该分类的概率。上述式中的对数以2为底,也可以e为底(自然对数)。

​ 通过上式,我们可以得到所有类别的信息。为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值(数学期望),通过下面的公式得到:

image-20241209145840117

​ 其中n是分类的数目。熵越大,随机变量的不确定性就越大。

​ 当熵中的概率由数据估计(特别是最大似然估计)得到时,所对应的熵称为经验熵(empirical entropy)。比如有10个数据,一共有两个类别,A类和B类。其中有7个数据属于A类,则该A类的概率即为十分之七。其中有3个数据属于B类,则该B类的概率即为十分之三。这概率是我们根据数据数出来的。我们定义贷款申请样本数据表中的数据为训练数据集D,则训练数据集D的经验熵为H(D),|D|表示其样本容量,及样本个数。设有K个类Ck, = 1,2,3,…,K,|Ck|为属于类Ck的样本个数,因此经验熵公式就可以写为 :

image-20241209145900014

​ 根据此公式计算经验熵H(D),分析贷款申请样本数据表中的数据。最终分类结果只有两类,即放贷和不放贷。根据表中的数据统计可知,在15个数据中,9个数据的结果为放贷,6个数据的结果为不放贷。所以数据集D的经验熵H(D)为:

image-20241209145916351

​ 经过计算可知,数据集D的经验熵H(D)的值为0.971。

3、编写代码计算经验熵

​ 首先对数据集进行属性标注。

  • 年龄:0代表青年,1代表中年,2代表老年;

  • 有工作:0代表否,1代表是;

  • 有自己的房子:0代表否,1代表是;

  • 信贷情况:0代表一般,1代表好,2代表非常好;

  • 类别(是否给贷款):no代表否,yes代表是。

    ​ 然后创建数据集,并计算经验熵。

    ​ 代码编写如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from matplotlib.font_manager import FontProperties  
import matplotlib.pyplot as plt
from math import log
import operator
import pickle

def createDataSet():
dataSet = [[0,0,0,0,'no'],
[0,0,0,1,'no'],
[0,1,0,1,'yes'],
[0,1,1,0,'yes'],
[0,0,0,0,'no'],
[1,0,0,0,'no'],
[1,0,0,1,'no'],
[1,1,1,1,'yes'],
[1,0,1,2,'yes'],
[1,0,1,2,'yes'],
[2,0,1,2,'yes'],
[2,0,1,1,'yes'],
[2,1,0,1,'yes'],
[2,1,0,2,'yes'],
[2,0,0,0,'no']
]
labels = ['年龄','有工作','有自己的房子','信贷情况']
return dataSet,labels

def calcShannonEnt(dataSet):
numEntires = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntires
shannonEnt -= prob * log(prob,2)
return shannonEnt
if __name__ == '__main__':
dataSet,labels = createDataSet()
print(dataSet)
print(calcShannonEnt(dataSet))

​ 代码运行结果如下图所示,代码是先打印训练数据集,然后打印计算的经验熵H(D),程序计算的结果与我们统计计算的结果是一致的,程序没有问题。

image-20241209150622268

4、计算信息增益

​ 在上面,我们已经说过,如何选择特征,需要看信息增益。也就是说,信息增益是相对于特征而言的,信息增益越大,特征对最终的分类结果影响也就越大,我们就应该选择对最终分类结果影响最大的那个特征作为我们的分类特征。

​ 条件熵H(Y|X)表示在已知随机变量X的条件下随机变量Y的不确定性,随机变量X给定的条件下随机变量Y的条件熵(conditional entropy)H(Y|X),定义为X给定条件下Y的条件概率分布的熵对X的数学期望:

image-20241209150645183

image-20241209150648200

​ 当条件熵中的概率由数据估计(特别是极大似然估计)得到时,所对应的条件熵称为条件经验熵(empirical conditional entropy)。

​ 明确了条件熵和经验条件熵的概念。接下来,说说信息增益。前面也提到了,信息增益是相对于特征而言的。所以,特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差,即:

image-20241209150704905

​ 熵H(D)与条件熵H(D|A)之差称为互信息(mutual information)。决策树学习中的信息增益等价于训练数据集中类与特征的互信息。

​ 设特征A有n个不同的取值{a1,a2,···,an},根据特征A的取值将D划分为n个子集{D1,D2,···,Dn},|Di|为Di的样本个数。记子集Di中属于Ck的样本的集合为Dik,即Dik = Di ∩ Ck,|Dik|为Dik的样本个数。于是经验条件熵的公式可以写为:

image-20241209150723981

​ 以贷款申请样本数据表为例进行说明。看下年龄这一列的数据,也就是特征A1,一共有三个类别,分别是:青年、中年和老年。我们只看年龄是青年的数据,年龄是青年的数据一共有5个,所以年龄是青年的数据在训练数据集出现的概率是十五分之五,也就是三分之一。同理,年龄是中年和老年的数据在训练数据集出现的概率也都是三分之一。现在我们只看年龄是青年的数据的最终得到贷款的概率为五分之二,因为在五个数据中,只有两个数据显示拿到了最终的贷款,同理,年龄是中年和老年的数据最终得到贷款的概率分别为五分之三、五分之四。所以计算年龄的信息增益,过程如下:

image-20241209150737040

​ 同理,计算其余特征的信息增益g(D,A2)、g(D,A3)和g(D,A4)。分别为:

image-20241209150748169

image-20241209150752446

image-20241209150755655

​ 最后,比较特征的信息增益,由于特征A3(有自己的房子)的信息增益值最大,所以选择A3作为最优特征。

5、编写代码计算信息增益

​ 我们已经学会了通过公式计算信息增益,接下来编写代码,计算信息增益。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def calcShannonEnt(dataSet):  
numEntires = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntires
shannonEnt -= prob * log(prob,2)
return shannonEnt

def createDataSet():
dataSet = [[0,0,0,0,'no'],
[0,0,0,1,'no'],
[0,1,0,1,'yes'],
[0,1,1,0,'yes'],
[0,0,0,0,'no'],
[1,0,0,0,'no'],
[1,0,0,1,'no'],
[1,1,1,1,'yes'],
[1,0,1,2,'yes'],
[1,0,1,2,'yes'],
[2,0,1,2,'yes'],
[2,0,1,1,'yes'],
[2,1,0,1,'yes'],
[2,1,0,2,'yes'],
[2,0,0,0,'no']
]
labels = ['年龄','有工作','有自己的房子','信贷情况']
return dataSet,labels

def splitDataSet(dataSet,axis,value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet

def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
beatFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
print("第%d个特征的增益为%.3f" % (i,infoGain))
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature


if __name__ == '__main__':
dataSet,labels = createDataSet()
print(dataSet)
print(labels)
print(calcShannonEnt(dataSet))
print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))

​ splitDataSet函数是用来选择各个特征的子集的,比如选择年龄(第0个特征)的青年(用0代表)的自己,我们可以调用splitDataSet(dataSet,0,0)这样返回的子集就是年龄为青年的5个数据集。chooseBestFeatureToSplit是选择选择最优特征的函数。运行代码结果如下:

image-20241209151321875

​ 对比我们自己计算的结果,最优特征的索引值为2,也就是特征A3(有自己的房子)。

​ 这样就生成了一个决策树,该决策树只用了两个特征(有两个内部结点),生成的决策树如下图所示。

image-20241209151335494

​ 我们已经学习了从数据集构造决策树算法所需要的子功能模块,包括经验熵的计算和最优特征的选择,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据集被向下传递到树的分支的下一个结点。在这个结点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。

​ 决策树生成算法递归地产生决策树,直到不能继续下去未为止。这样产生的树往往对训练数据的分类很准确,但对未知的测试数据的分类却没有那么准确,即出现过拟合现象。过拟合的原因在于学习时过多地考虑如何提高对训练数据的正确分类,从而构建出过于复杂的决策树。解决这个问题的办法是考虑决策树的复杂度,对已生成的决策树进行简化。

6、构建决策树

​ 利用求得的结果,由于特征A3(有自己的房子)的信息增益值最大,所以选择特征A3作为根结点的特征。它将训练集D划分为两个子集D1(A3取值为”是”)和D2(A3取值为”否”)。由于D1只有同一类的样本点,所以它成为一个叶结点,结点的类标记为“是”。

​ 对D2则需要从特征A1(年龄),A2(有工作)和A4(信贷情况)中选择新的特征,计算各个特征的信息增益:

image-20241209151412616

​ 根据计算,选择信息增益最大的特征A2(有工作)作为结点的特征。由于A2有两个可能取值,从这一结点引出两个子结点:一个对应”是”(有工作)的子结点,包含3个样本,它们属于同一类,所以这是一个叶结点,类标记为”是”;另一个是对应”否”(无工作)的子结点,包含6个样本,它们也属于同一类,所以这也是一个叶结点,类标记为”否”。

​ 这样就生成了一个决策树,该决策树只用了两个特征(有两个内部结点),生成的决策树如下图所示。

image-20241209151431267

​ 这样我们就使用ID3算法构建出来了决策树,我们使用字典存储决策树的结构:

image-20241209151442098

​ 创建函数majorityCnt统计classList中出现此处最多的元素(类标签),创建函数createTree用来递归构建决策树。编写代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def majorityCnt(classList):  
classCount = {}
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]

def createTree(dataSet,labels,featLabels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1 or len(labels) == 0:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels,featLabels)
return myTree
if __name__ == '__main__':
dataSet,labels = createDataSet()
print(dataSet)
print(labels)
print(calcShannonEnt(dataSet))
print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
featLabels = []
myTree = createTree(dataSet,labels,featLabels)
print(myTree)

​ 递归创建决策树时,递归有两个终止条件:第一个停止条件是所有的类标签完全相同,则直接返回该类标签;第二个停止条件是使用完了所有特征,仍然不能将数据划分仅包含唯一类别的分组,即决策树构建失败,特征不够用。此时说明数据纬度不够,由于第二个停止条件无法简单地返回唯一的类标签,这里挑选出现数量最多的类别作为返回值。

五、实验代码汇总

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from matplotlib.font_manager import FontProperties  
import matplotlib.pyplot as plt
from math import log
import operator
import pickle

def createDataSet():
dataSet = [[0,0,0,0,'no'],
[0,0,0,1,'no'],
[0,1,0,1,'yes'],
[0,1,1,0,'yes'],
[0,0,0,0,'no'],
[1,0,0,0,'no'],
[1,0,0,1,'no'],
[1,1,1,1,'yes'],
[1,0,1,2,'yes'],
[1,0,1,2,'yes'],
[2,0,1,2,'yes'],
[2,0,1,1,'yes'],
[2,1,0,1,'yes'],
[2,1,0,2,'yes'],
[2,0,0,0,'no']
]
labels = ['年龄','有工作','有自己的房子','信贷情况']
return dataSet,labels

def calcShannonEnt(dataSet):
numEntires = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntires
shannonEnt -= prob * log(prob,2)
return shannonEnt

def splitDataSet(dataSet,axis,value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet

def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
beatFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
print("第%d个特征的增益为%.3f" % (i,infoGain))
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature

def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]

def createTree(dataSet,labels,featLabels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1 or len(labels) == 0:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
featLabels.append(bestFeatLabel)
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels,featLabels)
return myTree

#决策树可视化
def getNumLeafs(myTree):
numLeafs = 0
firstStr = next(iter(myTree))
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs

def getTreeDepth(myTree):
maxDepth = 0
firstStr = next(iter(myTree))
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:maxDepth = thisDepth
return maxDepth

def plotNode(nodeTxt,centerPt,parentPt,nodeType):
arrow_args = dict(arrowstyle="<-")
font = FontProperties(fname=r"c:\windows\fonts\simkai.ttf",size=14)
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
xytext=centerPt,textcoords='axes fraction',
va="center",ha="center",bbox=nodeType,arrowprops=arrow_args,fontproperties=font)

def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString,va="center",ha="center",rotation=30)

def plotTree(myTree,parentPt,nodeTxt):
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = next(iter(myTree))
cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.y0ff = plotTree.y0ff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.x0ff = plotTree.x0ff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
plotTree.y0ff = plotTree.y0ff + 1.0/plotTree.totalD

def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.x0ff = -0.5/plotTree.totalW;plotTree.y0ff = 1.0;
plotTree(inTree,(0.5,1.0), '')
plt.show()

#分类
def classify(inputTree,featLabels,testVec):
firstStr = next(iter(inputTree))
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key],featLabels,testVec)
else:
classLabel = secondDict[key]
return classLabel

#存储
def storeTree(inputTree,filename):
with open(filename,'wb') as fw:
pickle.dump(inputTree,fw)

#载入
def grabTree(filename):
fr = open(filename,'rb')
return pickle.load(fr)

if __name__ == '__main__':
dataSet,labels = createDataSet()
print(dataSet)
print(labels)
print(calcShannonEnt(dataSet))
print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
featLabels = []
myTree = createTree(dataSet,labels,featLabels)
print(myTree)
createPlot(myTree)

#测试
test = [0,0]
result = classify(myTree,featLabels,test)
print("\n测试0,0的决策树结果:")
if result == 'yes':
print("发放贷款")
else:
print("不予贷款")

#存储决策树
storeTree(myTree,'classifierStore.txt')

#载入决策树
print("\n载入决策树:")
loadMyTree = grabTree('classifierStore.txt')
print(loadMyTree)

运行代码截图

image-20241209152040685

六、实验结论

​ ID3算法使用信息增益进行特征选择,集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差。特征所对应的信息增益值最大,该特征就为最优特征,也就是说信息增益越大,越应该放在决策树的上层。

​ 利用递归方法建树,每次对象为当前最优特征。

​ 递归函数的第一个停止条件是所有的类标签都相同,递归函数第二个停止条件是使用完数据集中所有的特征,即数据集不能继续划分;字典变量myTree储存了树的所有信息,bestFeature则是当前最优特征。