美文网首页
仓库:C4.5的python实现

仓库:C4.5的python实现

作者: ylylhl | 来源:发表于2018-03-29 21:54 被阅读0次

开始后的第三周,尝试着写了一点。
edit:于4月8日增加连续值处理,具体参见https://www.jianshu.com/p/f7219841916a
edit:4月10日消减部分代码量
数据集如下:

OUTLOOK TEMPERATURE HUMIDITY WINDY ACTIVATE
sunny 85 85 weak no
sunny 80 90 strong no
overcast 83 78 weak yes
rainy 70 96 weak yes
rainy 68 80 weak yes
rainy 65 70 strong no
overcast 64 65 strong yes
sunny 72 95 weak no
sunny 69 70 weak yes
rainy 75 80 weak yes
sunny 75 70 strong yes
overcast 72 90 strong yes
overcast 81 75 weak yes
rainy 71 80 strong no

思路很迷,如果有谁不幸看到了这篇请不要当做参考会被带到沟里的……
有1点问题,关于max默认取第一个值……比如两个属性的信息增益率相等且都最大,取不同的属性作为节点会生成不同的树(。
所以该怎么办哦www

#coding = utf-8
'''
●C4.5 v1.1
作者:Kadoya
创建日期:18.3.24
最近修改时间:18.4.10
程序目的:一个树:D
主要算法说明:
        1.说白了(大概)就是权重排序,或者看下面这一堆废话也行。
            读取文档→拿到属性→调stardust()拿到具体数据(列表)→
            调continuous()处理连续值→调entropy()拿到结果熵(playgolf?)→
            将数据集,属性和结果熵传入gotcha(),即用于找到下一个节点和子集的函数
            ①如果结果是纯的或满足某种条件则return(对应书上第2~4行)
            ②调gainratio()得到每个属性对应的信息增益率并排序,取最大的作为节点
            ③生成基于该属性不同取值的不同子集,计算子集的结果熵并再次调用gotcha()(第11~14行)
            返回tree并输出(第15行)
        2.凑不齐三点了不管了x
程序备注:
        1.蜃楼啥时候上steam啊……
        2.教授!!!来我迦了!!!毒针地狱.jpg 为了刷毒针正式移民乌鲁克……
        3.一人血书医生早日实装(
更新历史:
        3.28 v1.0
            ①基本功能完成,小bug不计其数,智障操作不计其数
        4.8 update v1.05
            ①增加连续值处理
            ②可以引random达到增益率相同时的选择伪随机而非直接[0]选择第一个但是没改
        4.10 update v1.1
            ①连续值那里原本的int改了float,免得有小数
            ②精简部分代码,纯代码部分压缩到80行

说实话上次好像说了有个bug要改,想不起来了(……
'''
from math import *

def stardust(data):#具体数据,参数:读到的源文档
    attributes=[]#属性:outlook,temperature,humidity,windy之类
    for i in data[0].split(' ')[0:-1]:#最后一列是结果(play golf?)所以不算
        attributes.append(i)          #拿到所有属性
    lost_star=[]
    for i in data[1:]:#第一行是说明所以不算,下一个.jpg
        lost_star.append(i.strip('\n').split(' ')[0:])#删除换行符,以空格分割
    lost_star=continuous(lost_star,1)
    lost_star=continuous(lost_star,2)
    en = entropy(lost_star,len(lost_star[0])-1)
    return attributes,lost_star,en

def get_subset(lost_star,value,j):#生成子集     参数:数据集,属性具体取值,所在列数
    subset=[lost_star[i] for i in range(len(lost_star)) if lost_star[i][j]==value]
    return subset

def continuous(lost_star,l):#连续值的阈值     参数:数据集,连续值所在列数
    #不转float会出现15,7,89这种情况,大概是因为str先比第一位,大概。
    gear=sorted(lost_star,key=lambda st: float(st[l]))
    #可能的分割阈值点,!=是因为把取值相同的类分成不同的类没有意义
##    possible=[]
##    for i in range(len(lost_star)):
##        if lost_star[i][-1]!=lost_star[i+1][-1]:
##            possible+=[str((int(lost_star[i][l])+int(lost_star[i+1][l]))/2)]
##            print(possible)
    possible=[str((float(gear[i][l])+float(gear[i+1][l]))/2) for i in range(len(gear)-1) if gear[i][-1]!=gear[i+1][-1]]
    scrap={}
    for j in possible:
        subset1=[gear[i] for i in range(len(gear)) if float(gear[i][l])<float(j)]#小于该分割点的子集
        subset2=[gear[i] for i in range(len(gear)) if float(gear[i][l])>=float(j)]#大于ry
        scrap[j]=len(subset1)/len(lost_star)*entropy(subset1,-1)+len(subset2)/len(lost_star)*entropy(subset2,-1)#计算不同分割点的熵
    threshold=[i for i,j in scrap.items() if j == min(scrap.values())][0]#阈值,增益和熵成反比所以min
    for i in lost_star:
        #解释一下这个学来的操作,大概就相当于['>=','<'][0 or 1]
        #a cool trick, right? :D
        i[l]=['>=','<'][float(i[l])>=float(threshold)]+threshold
    return lost_star
    
def entropy(lost_star,l):#计算信息熵  参数:数据集,列数
    star={}
    entropy=0
    for i in range(len(lost_star)):
        if lost_star[i][l] not in star.keys(): #一个字典
            star[lost_star[i][l]]=0 #key:value = 结果取值:个数
        star[lost_star[i][l]]+=1
    for i in star.values():
        entropy += -(i/sum(star.values()))*log(i/sum(star.values()),2)#熵,∑-p*log2(p)
    return entropy

def gainratio(en,e,j,lost_star):#信息增益率  参数:结果熵,属性熵,列数,数据集
    if e == 0:
        return 0
    entropy_info=0 #属性∑,即∑|Dv|/D * Entropy(playgolf? in Dv)
    a_value=set([lost_star[i][j] for i in range(len(lost_star))])#属性可能的取值
    for i in a_value:
        subset=get_subset(lost_star,i,j)#得到属性取值对应子集
        entropy_info += len(subset)/len(lost_star)*entropy(subset,len(subset[0])-1)#属性∑ = 取值权重*子集熵
    return (en - entropy_info)/e    #(结果熵-属性∑)/属性熵 = gainratio

def gotcha(lost_star,attributes,en):#生成树  参数:数据集,属性,结果熵
    if en == 0:
        return lost_star[0][-1]
    if len(lost_star[0])==2:
        for i in range(len(lost_star)):
            if lost_star[i][-1] not in star:
                star[lost_star[i][-1]]=0
            star[lost_star[i][-1]]+=1
        for i,j in lost_star.items():
            if j == max(lost_star.values()):
                return i
    #↑如果结果熵是0(结果唯一)或属性只剩一个
    dust={}
    for i in range(len(attributes)):
        dust[attributes[i]]=gainratio(en,entropy(lost_star,i),i,lost_star)#属性对应信息增益率
    #↓试图压缩行数,写了个智熄操作搞到节点
    branch=[i for i,j in dust.items() if j == max(dust.values())][0]
    num=attributes.index(branch)#得到被选作节点的属性所在列数
    del(attributes[num])        #把该属性从属性列表中删除
    b_value=set([lost_star[i][num] for i in range(len(lost_star))])#被选作节点的属性可能的取值
    tree={branch:{}}
    #写个循环调get_subset拿到属性不同取值对应的子集然后递归
    for i in b_value:
        subset=get_subset(lost_star,i,num)
        for a in range(len(subset)):
            subset[a].remove(subset[a][num])#去除被选作节点的属性
        entro=entropy(subset,len(subset[0])-1)#子集的结果熵
        subattributes = attributes[:]#化腐朽为神奇.jpg 如果不备份直接传会炸
        tree[branch][i]=gotcha(subset,subattributes,entro)#拿到下一个节点ry
    return tree

if __name__ == '__main__':
    path = r'data.txt'
    with open(path) as f:
        data = f.readlines()

    attributes,lost_star,en = stardust(data)#拿到属性,具体数据集和结果熵
    tree=gotcha(lost_star,attributes,en)
    print(tree)

图像代码(复制来的)

# -*- coding: cp936 -*-  
import matplotlib.pyplot as plt  
from kadoya import *  #c4.5算法的py文件名
decisionNode = dict(boxstyle = 'sawtooth', fc = '0.8')  
leafNode = dict(boxstyle = 'round4', fc = '0.8')  
arrow_args = dict(arrowstyle = '<-')  
  
def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',xytext = centerPt, textcoords = 'axes fraction',va = 'center', ha = 'center', bbox = nodeType, arrowprops = arrow_args)  
  
# 使用文本注解绘制树节点  
def createPlot():  
    fig = plt.figure(1, facecolor = 'white')  
    fig.clf()  
    createPlot.ax1 = plt.subplot(111, frameon = False)  
    plotNode('a decision node', (0.5,0.1), (0.1,0.5), decisionNode)  
    plotNode('a leaf node', (0.8, 0.1), (0.3,0.8), leafNode)  
    plt.show()  
  
#获取叶子节点数目和树的层数  
def getNumLeafs(myTree):  
    numLeafs = 0  
    #firstStr = myTree.keys()[0]
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]  
    for key in secondDict.keys():  
        if(type(secondDict[key]).__name__ == 'dict'):  
            numLeafs += getNumLeafs(secondDict[key])  
        else: numLeafs += 1  
    return numLeafs  
  
def getTreeDepth(myTree):  
    maxDepth = 0  
    #firstStr = myTree.keys()[0]
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    secondDict = myTree[firstStr]  
    for key in secondDict.keys():  
        if(type(secondDict[key]).__name__ == 'dict'):  
            thisDepth = 1+ getTreeDepth(secondDict[key])  
        else: thisDepth = 1  
        if thisDepth > maxDepth: maxDepth = thisDepth  
    return maxDepth  
  
#更新createPlot代码以得到整棵树  
def plotMidText(cntrPt, parentPt, txtString):  
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)  
  
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on  
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree  
    depth = getTreeDepth(myTree)  
    #firstStr = myTree.keys()[0]     #the text label for this node should be this
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
    plotMidText(cntrPt, parentPt, nodeTxt)  
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  
    secondDict = myTree[firstStr]  
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
    for key in secondDict.keys():  
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes     
            plotTree(secondDict[key],cntrPt,str(key))        #recursion  
        else:   #it's a leaf node print the leaf node  
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
#if you do get a dictonary you know it's a tree, and the first element will be another dict  
  
def createPlot(inTree):  
    fig = plt.figure(1, facecolor='white')  
    fig.clf()  
    axprops = dict(xticks=[], yticks=[])  
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks  
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses   
    plotTree.totalW = float(getNumLeafs(inTree))  
    plotTree.totalD = float(getTreeDepth(inTree))  
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
    plotTree(inTree, (0.5,1.0), '')  
    plt.show() 

path = r'data.txt'
with open(path) as f:
    data = f.readlines()
attributes,lost_star,en = stardust(data)#拿到具体数据集和结果熵
tree=gotcha(lost_star,attributes,en)
print(tree)
createPlot(tree)

相关文章

网友评论

      本文标题:仓库:C4.5的python实现

      本文链接:https://www.haomeiwen.com/subject/gobvcftx.html