#导入包
from math import log
from operator import itemgetter
from graphviz import Digraph
def createDataset():
    """Build the hard-coded melon sample set used by this script.

    Returns:
        A ``(dataset, label)`` tuple.  ``dataset`` is a list of rows, each a
        list of six feature values followed by the class label ('好瓜' or
        '坏瓜') in the last position.  ``label`` holds the six feature names
        in column order.  Fresh lists are created on every call.
    """
    # Rows whose class label is '好瓜'.
    good_melons = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
    ]
    # Rows whose class label is '坏瓜'.
    bad_melons = [
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
    ]
    dataset = good_melons + bad_melons
    label = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return dataset, label
def calcShannonEnt(dataset, feat):
    """Compute the base-2 Shannon entropy of one column of ``dataset``.

    Args:
        dataset: Sequence of rows (indexable sequences).
        feat: Index of the column whose value distribution is measured;
            ``-1`` selects the last column.

    Returns:
        ``H = -sum(p * log2(p))`` over the distinct values in the column.
        An empty ``dataset`` yields the integer ``0``.
    """
    total = len(dataset)

    # Tally how often each distinct value occurs in the chosen column.
    tally = {}
    for row in dataset:
        value = row[feat]
        tally[value] = tally.get(value, 0) + 1

    entropy = 0
    for count in tally.values():
        p = count / total
        entropy -= p * log(p, 2)
    return entropy
def splitDtatset(dataset, index, value):
    """Select the rows whose column ``index`` equals ``value``.

    Args:
        dataset: Sequence of rows; each row must support slicing and ``+``.
        index: Column to test.
        value: Value the column must equal for a row to be kept.

    Returns:
        A new list holding, for every matching row, a copy of that row with
        column ``index`` removed.  The input rows are not modified.
    """
    return [
        row[:index] + row[index + 1:]
        for row in dataset
        if row[index] == value
    ]
def chooseBestFeature(dataset):
    """Pick the feature with the largest information gain (ID3 criterion).

    The gain of feature ``i`` is the entropy of the class column minus the
    conditional entropy of the class column given the feature:
    ``H(D) - sum(|D_v| / |D| * H(D_v))`` over every value ``v`` of the
    feature.

    Args:
        dataset: Non-empty list of rows; every column except the last is a
            feature and the last column is the class label.

    Returns:
        The column index of the best feature, or ``-1`` when no feature has
        a strictly positive gain (including when there are no feature
        columns at all).  Callers must treat ``-1`` as "do not split".

    Note:
        A C4.5 variant would divide each gain by the feature's own entropy,
        ``calcShannonEnt(dataset, i)``, and skip features where that is 0.
    """
    numFeature = len(dataset[0]) - 1  # last column is the class label
    total = len(dataset)
    baseEntropy = calcShannonEnt(dataset, -1)  # entropy of the class labels

    bestFeat = -1
    bestinfoGain = 0
    for i in range(numFeature):
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # floating-point summation order (and therefore tie-breaking between
        # features) does not depend on string hash randomisation.
        uniquevals = dict.fromkeys(rowdata[i] for rowdata in dataset)

        newEntropy = 0
        for value in uniquevals:
            subDataset = splitDtatset(dataset, i, value)
            prob = len(subDataset) / total
            # Accumulate the weighted entropy of the CLASS column (-1) of
            # each subset.  Column i of the subset is a different feature,
            # because splitDtatset has already removed the split column.
            newEntropy += prob * calcShannonEnt(subDataset, -1)

        infoGain = baseEntropy - newEntropy
        if infoGain > bestinfoGain:
            bestinfoGain = infoGain
            bestFeat = i
    return bestFeat
def majorityEnt(classlist):
    """Return the most frequent class label in ``classlist`` (majority vote).

    When several labels share the highest count, the one that appears first
    in ``classlist`` wins.  The chosen label is also printed to stdout.

    Args:
        classlist: Non-empty sequence of hashable class labels.

    Returns:
        The winning class label.
    """
    votes = {}
    for cls in classlist:
        votes[cls] = votes.get(cls, 0) + 1

    # Each item is (label, count); sort by count, largest first.  The sort is
    # stable, so equal counts keep their first-seen order.
    ranked = sorted(votes.items(), key=itemgetter(1), reverse=True)
    winner = ranked[0][0]
    print(winner)
    return winner
def createTree(dataset, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataset: Non-empty list of rows; the last column of each row is the
            class label and the others are feature values.
        labels: Feature names, one per feature column, in column order.
            This sequence is NOT modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.  Branches appear
        in the order their values are first seen in ``dataset``.
    """
    classlist = [rowdata[-1] for rowdata in dataset]

    # Every sample has the same class: this node is a pure leaf.
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataset[0]) == 1:
        return majorityEnt(classlist)

    bestFeat = chooseBestFeature(dataset)
    # chooseBestFeature returns -1 when no feature is worth splitting on.
    # Using -1 as an index would split on the class column itself, so stop
    # here with a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityEnt(classlist)

    bestLab = labels[bestFeat]
    # Build the reduced label list as a new object rather than deleting from
    # `labels`, so the caller's list stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    mytree = {bestLab: {}}
    # dict.fromkeys de-duplicates while keeping first-seen order, making the
    # branch order reproducible (set order varies between runs for str).
    for value in dict.fromkeys(rowdata[bestFeat] for rowdata in dataset):
        mytree[bestLab][value] = createTree(
            splitDtatset(dataset, bestFeat, value), subLabels
        )
    return mytree
def plot_model(tree, name):
    """Render a decision tree with Graphviz and open the result.

    Args:
        tree: Output of ``createTree``: either a nested dict of the form
            ``{feature_name: {feature_value: subtree}}`` or a bare class
            label when the whole data set collapsed to a single leaf.
        name: File name passed to Graphviz as the source file name; the
            graph is rendered in PNG format.

    Returns:
        None.  The graph is written and shown via ``Digraph.view()``.
    """
    # NOTE(review): "Kaiti" is presumably chosen so the Chinese labels render;
    # confirm the font is installed on the target machine.
    g = Digraph("G", filename=name, format='png', strict=False, encoding="utf-8")
    if isinstance(tree, dict):
        first_label = list(tree.keys())[0]
        g.node("0", first_label, fontname="Kaiti")
        _sub_plot(g, tree, "0")
    else:
        # A pure data set makes createTree return a bare label, not a dict;
        # draw it as a single node instead of failing on tree.keys().
        g.node("0", str(tree), fontname="Kaiti")
    g.view()
# Id of the most recently created graph node, kept as a decimal string.
# _sub_plot increments it before adding each node so every node id is unique.
root = "0"


def _sub_plot(g, tree, inc):
    """Add the children of one tree node to the graph, recursively.

    Args:
        g: Graph object providing ``node(id, label, fontname=...)`` and
            ``edge(tail, head, label, fontname=...)``.
        tree: Single-key dict ``{feature_name: {feature_value: subtree}}``
            where each subtree is either another such dict or a leaf label.
        inc: Node id of the parent, i.e. the node that represents ``tree``.

    Side effects:
        Advances the module-level ``root`` counter once per node added.
    """
    global root
    feature_name = list(tree.keys())[0]
    branches = tree[feature_name]
    for branch_value in branches.keys():
        child = branches[branch_value]
        root = str(int(root) + 1)
        # An inner node is labelled with the feature it splits on; a leaf is
        # labelled with its class.
        if isinstance(child, dict):
            g.node(root, list(child.keys())[0], fontname="Kaiti")
            g.edge(inc, root, str(branch_value), fontname="Kaiti")
            # Recurse after wiring the edge; `root` is this child's id here.
            _sub_plot(g, child, root)
        else:
            g.node(root, child, fontname="Kaiti")
            g.edge(inc, root, str(branch_value), fontname="Kaiti")
# Script entry point: build the sample data, grow the tree, draw it, and
# finally print the nested-dict form of the tree.
if __name__ == "__main__":
    samples, feature_names = createDataset()
    decision_tree = createTree(samples, feature_names)
    plot_model(decision_tree, "decision_tree.gv")
    print(decision_tree)
# 网友评论 — "reader comments" heading carried over from the source web page; not part of the program.