Implementing Bayesian Inference in Python: Spam Email Classification


Author: AlanLau | Published 2017-05-11 20:47 | 780 views

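    Theory

    For the theory, I strongly recommend Ruan Yifeng's personal website:
    1. Bayesian Inference and Its Internet Applications (1): Introduction to the Theorem (贝叶斯推断及其互联网应用(一):定理简介)
    2. Bayesian Inference and Its Internet Applications (2): Filtering Spam (贝叶斯推断及其互联网应用(二):过滤垃圾邮件)
    Both are concise and easy to follow. The code below implements the spam-filtering algorithm described in the second article.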
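
For reference, the two formulas that the implementation later in this post computes can be written as follows. The notation is an assumption chosen for illustration (S for spam, H for ham, W or w_i for a word, p_i for the per-word spam probability); the combined formula assumes equal priors P(S) = P(H) = 0.5 and treats the words as independent, which is also what the code does.

```latex
% Probability that an email is spam, given that it contains the word W
P(S \mid W) = \frac{P(W \mid S)\,P(S)}{P(W \mid S)\,P(S) + P(W \mid H)\,P(H)}

% Combined probability over the n words of one email, with p_i = P(S \mid w_i)
P(S \mid w_1, \dots, w_n) =
  \frac{\prod_{i=1}^{n} p_i}{\prod_{i=1}^{n} p_i + \prod_{i=1}^{n} (1 - p_i)}
```

As a made-up numerical illustration: a word that appears in 5% of spam and 0.05% of ham gives P(S|W) = 0.025 / (0.025 + 0.00025) ≈ 0.990.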

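    Preparation

    Data source

    The data comes from the experimental data for Chapter 4 (naive Bayes classifiers) of Machine Learning in Action. The book provides only 50 emails (25 ham, 25 spam), which feels rather small; I plan to use the iris data provided by scikit-learn later.

    Note that Bayesian inference and naive Bayes are not the same concept.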

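    Data preparation

    As in most machine learning work, the data has to be split into a training set and a test set.
    The split works as follows:
    1. Walk the email folder containing the 50 emails and collect the list of file paths.
    2. Shuffle the list with random.shuffle().
    3. Take the first 10 paths of the shuffled list and move those files into the test folder as the test set.
    Code: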

    # -*- coding: utf-8 -*-
    # @Date     : 2017-05-09 13:06:56
    # @Author   : Alan Lau (rlalan@outlook.com)
    # @Language : Python3.5

    # from fwalker import fun
    import random
    # from reader import writetxt, readtxt
    import shutil
    import os


    def fileWalker(path):
        """Return the paths of all files under path, recursively."""
        fileArray = []
        for root, dirs, files in os.walk(path):
            for fn in files:
                fileArray.append(os.path.join(root, fn))
        return fileArray


    def main():
        filepath = os.path.join('..', 'email')
        testpath = os.path.join('..', 'test')
        # Make sure the target folder exists before moving files into it
        os.makedirs(testpath, exist_ok=True)
        files = fileWalker(filepath)
        random.shuffle(files)
        top10 = files[:10]
        for ech in top10:
            # Prefix the file name with its class folder: ham/3.txt -> ham_3.txt
            label = os.path.basename(os.path.dirname(ech))
            ech_name = os.path.join(testpath, label + '_' + os.path.basename(ech))
            shutil.move(ech, ech_name)
            print('%s moved' % ech_name)


    if __name__ == '__main__':
        main()
    

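    If you have questions about the fwalker and reader modules referenced in the code, see my posts "Importing your own .py files in Python" and "Common code for reading and writing text in Python 3".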

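    The resulting file lists:

    copy is a backup of the data, in case an operation goes wrong.

    ham file list:

    spam file list:

    test file list:

    As you can see, after data preparation the test set contains 7 spam emails and 3 ham emails.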

    Implementation

    # -*- coding: utf-8 -*-
    # @Date     : 2017-05-09 09:29:13
    # @Author   : Alan Lau (rlalan@outlook.com)
    # @Language : Python3.5
    
    # from fwalker import fun
    # from reader import readtxt
    import os
    
    
    def readtxt(path, encoding):
        """Read a text file and return its lines."""
        with open(path, 'r', encoding=encoding) as f:
            lines = f.readlines()
        return lines


    def fileWalker(path):
        """Return the paths of all files under path, recursively."""
        fileArray = []
        for root, dirs, files in os.walk(path):
            for fn in files:
                fileArray.append(os.path.join(root, fn))
        return fileArray
    
    def email_parser(email_path):
        """Return the lowercased words (longer than 2 characters) of one email."""
        punctuations = """,.<>()*&^%$#@!'";~`[]{}|、\\/~+_-=?"""
        content_list = readtxt(email_path, 'utf8')
        content = ' '.join(content_list)
        # Turn every punctuation mark into a space
        for punctuation in punctuations:
            content = content.replace(punctuation, ' ')
        # split() with no argument splits on any run of whitespace, so
        # newlines, tabs and repeated spaces never end up inside a word
        clean_word = [word.lower() for word in content.split() if len(word) > 2]
        return clean_word
    
    
    def get_word(email_file):
        word_list = []
        word_set = []
        email_paths = fileWalker(email_file)
        for email_path in email_paths:
            clean_word = email_parser(email_path)
            word_list.append(clean_word)
            word_set.extend(clean_word)
        return word_list, set(word_set)
    
    
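    def count_word_prob(email_list, union_set):
        """For each word, return the fraction of emails in email_list containing it."""
        # Sets make the membership test below fast
        email_sets = [set(email) for email in email_list]
        word_prob = {}
        for word in union_set:
            counter = sum(1 for email in email_sets if word in email)
            # A word never seen in this class gets 0.01 instead of 0, so the
            # Bayes formula applied later never has a zero factor
            word_prob[word] = counter/len(email_list) if counter else 0.01
        return word_prob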
    
    
    # Note: this function name shadows Python's built-in filter()
    def filter(ham_word_pro, spam_word_pro, test_file):
        """Print a spam/ham verdict for every email under test_file."""
        spam_prob = 0.5  # prior P(spam)
        ham_prob = 0.5   # prior P(ham)
        test_paths = fileWalker(test_file)
        for test_path in test_paths:
            file_name = os.path.basename(test_path)
            prob_dict = {}
            words = set(email_parser(test_path))
            for word in words:
                if word not in spam_word_pro:
                    # Word never seen in training: assume P(spam|word) = 0.4
                    Psw = 0.4
                else:
                    Pws = spam_word_pro[word]
                    Pwh = ham_word_pro[word]
                    # Bayes' theorem: P(spam|word)
                    Psw = spam_prob*(Pws/(Pwh*ham_prob+Pws*spam_prob))
                prob_dict[word] = Psw
            # Combine the per-word probabilities into one score for the email
            numerator = 1
            denominator_h = 1
            for v in prob_dict.values():
                numerator *= v
                denominator_h *= (1-v)
            email_spam_prob = round(numerator/(numerator+denominator_h), 4)
            if email_spam_prob > 0.5:
                print(file_name, 'spam', email_spam_prob)
            else:
                print(file_name, 'ham', email_spam_prob)
    
    
    def main():
        ham_file = os.path.join('..', 'email', 'ham')
        spam_file = os.path.join('..', 'email', 'spam')
        test_file = os.path.join('..', 'email', 'test')
        ham_list, ham_set = get_word(ham_file)
        spam_list, spam_set = get_word(spam_file)
        union_set = ham_set | spam_set
        ham_word_pro = count_word_prob(ham_list, union_set)
        spam_word_pro = count_word_prob(spam_list, union_set)
        filter(ham_word_pro, spam_word_pro, test_file)


    if __name__ == '__main__':
        main()
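
The last loop in filter() multiplies one probability per word, and for a long email both products can underflow to 0.0, at which point the division fails. The following is a minimal sketch of the same combination done in log space; it is not part of the original code, and the name combine_probs and the cutoff of 700 are assumptions made for illustration.

```python
import math


def combine_probs(probs):
    """Combine per-word spam probabilities into one score, in log space.

    Computes prod(p) / (prod(p) + prod(1 - p)) without forming the products.
    Every p must lie strictly between 0 and 1.
    """
    log_spam = sum(math.log(p) for p in probs)
    log_ham = sum(math.log(1.0 - p) for p in probs)
    # prod_p / (prod_p + prod_q) == 1 / (1 + exp(log_q - log_p))
    diff = log_ham - log_spam
    if diff > 700:  # math.exp would overflow; the score is effectively 0
        return 0.0
    return 1.0 / (1.0 + math.exp(diff))
```

For short inputs this agrees with the direct formula, for example combine_probs([0.9, 0.9]) equals 0.81 / (0.81 + 0.01) up to rounding, while an input such as [0.4] * 2000, where both direct products underflow, still returns a usable score.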
    
    

    Results

    ham_24.txt ham 0.0005
    ham_3.txt ham 0.0
    ham_4.txt ham 0.0
    spam_11.txt spam 1.0
    spam_14.txt spam 0.9999
    spam_17.txt ham 0.0
    spam_18.txt spam 0.9992
    spam_19.txt spam 1.0
    spam_22.txt spam 1.0
    spam_8.txt spam 1.0
    

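    The accuracy is therefore 90% (9 of 10 test emails classified correctly; spam_17.txt was labeled ham). Strictly speaking, the data should be randomly split into ten equal groups, with each group serving once as the test set and the other nine as the training set, and the ten results averaged; only then is the measured classification performance reliable. It also cannot be ruled out that the small amount of data is holding the accuracy down.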
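
The ten-group procedure described above can be sketched as a small generator. This is an illustrative sketch rather than code from the original project; the function name k_fold_splits and its seed parameter are assumptions.

```python
import random


def k_fold_splits(items, k=10, seed=None):
    """Yield (train, test) pairs; each item lands in exactly one test fold."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    # Slicing with step k deals the items into k folds of near-equal size
    folds = [shuffled[i::k] for i in range(k)]
    for i, test in enumerate(folds):
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        yield train, test
```

With the 50 email paths as items and k=10, each round would train on 45 emails and test on 5; the accuracy of each round can then be averaged over the ten rounds.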

    All code and data: GitHub
