美文网首页深度学习实战演练
第六章(1.6)机器学习实战——打造属于自己的贝叶斯分类器

第六章(1.6)机器学习实战——打造属于自己的贝叶斯分类器

作者: _两只橙_ | 来源:发表于2019-01-13 18:27 被阅读3次

    github项目地址:https://github.com/liangzhicheng120/bayes

    一、简介

    • 项目使用SpringBoot做了一层web封装

    • 项目使用的分词工具hanlp

    • 项目使用JDK8

    • 贝叶斯法则
      事件A在事件B(发生)的条件下的概率,与事件B在事件A的条件下的概率是不一样的;然而,这两者是有确定的关系,贝叶斯法则就是这种关系的陈述。

    • 贝叶斯术语
      [图片上传失败...(image-d286a7-1547375244426)]
      其中L(A|B)是在B发生的情况下A发生的可能性。
      在贝叶斯法则中,每个名词都有约定俗成的名称:
      Pr(A)A的先验概率或边缘概率。之所以称为"先验"是因为它不考虑任何B方面的因素。
      Pr(A|B)是已知B发生后A的条件概率,也由于得自B的取值而被称作A的后验概率。
      Pr(B|A)是已知A发生后B的条件概率,也由于得自A的取值而被称作B的后验概率。
      Pr(B)B的先验概率或边缘概率,也作标准化常量(normalized constant)。
      后验概率 = (似然度 * 先验概率)/标准化常量 也就是说,后验概率与先验概率和似然度的乘积成正比。

    • 贝叶斯推断的含义
      对条件概率公式进行变形,可以得到如下形式:

    [图片上传失败...(image-3fbd35-1547375244427)]

    • 我们把P(A)称为"先验概率"(Prior probability),即在B事件发生之前,我们对A事件概率的一个判断。P(A|B)称为"后验概率"(Posterior probability),即在B事件发生之后,我们对A事件概率的重新评估。P(B|A)/P(B)称为"可能性函数"(Likelyhood),这是一个调整因子,使得预估概率更接近真实概率。
      后验概率 = 先验概率 x 调整因子

    • 这就是贝叶斯推断的含义。我们先预估一个"先验概率",然后加入实验结果,看这个实验到底是增强还是削弱了"先验概率",由此得到更接近事实的"后验概率"。
      在这里,如果"可能性函数"P(B|A)/P(B)>1,意味着"先验概率"被增强,事件A的发生的可能性变大;如果"可能性函数"=1,意味着B事件无助于判断事件A的可能性;如果"可能性函数"<1,意味着"先验概率"被削弱,事件A的可能性变小。

    二、例子

    • 别墅和狗
      一座别墅在过去的 20 年里一共发生过 2 次被盗,别墅的主人有一条狗,狗平均每周晚上叫 3 次,在盗贼入侵时狗叫的概率被估计为 0.9,问题是:在狗叫的时候发生入侵的概率是多少?
      我们假设 A 事件为狗在晚上叫,B 为盗贼入侵,则P(A) = 3 / 7,P(B)=2/(20·365)=2/7300,P(A | B) = 0.9,按照公式很容易得出结果:P(B|A)=0.9*(2/7300)/(3/7)=0.00058

    三、实战代码

    • 模型文件(classify.txt
    火影忍者 火影
    火影忍者 秘传
    火影忍者 大蛇丸
    火影忍者 剧场版
    火影忍者 动作
    火影忍者 激斗
    火影忍者 战斗
    火影忍者 转生
    火影忍者 佐助
    火影忍者 村子
    火影忍者 第六代火影
    火影忍者 克拉
    火影忍者 卡卡
    火影忍者 带土
    火影忍者 疾风
    火影忍者 自来
    火影忍者 火影忍者
    火影忍者 仙人
    火影忍者 六道
    火影忍者 大战
    火影忍者 九尾
    火影忍者 忍者
    火影忍者 究极
    火影忍者 纲手
    火影忍者 鸣人
    火影忍者 木叶
    火影忍者 忍术
    火影忍者 秽土
    火影忍者 宇智波
    火影忍者 九尾妖狐
    火影忍者 阿飞
    海贼王 正文
    海贼王 尾田
    海贼王 海贼王
    海贼王 弗兰奇
    海贼王 草帽
    海贼王 海贼
    海贼王 武海
    海贼王 事件
    海贼王 悬赏
    海贼王 第话
    海贼王 梦想
    海贼王 血型
    海贼王 王下
    海贼王 航路
    海贼王 历史
    海贼王 德雷斯
    海贼王 船长
    海贼王 恶魔
    海贼王 路飞
    海贼王 漫画
    海贼王 超新星
    海贼王 罗萨篇
    海贼王 世界
    海贼王 果实
    海贼王 冥王
    海贼王 荣一郎
    海贼王 海贼团
    海贼王 司法
    海贼王 超人
    海贼王 成为
    海贼王 寻找
    海贼王 传说
    海贼王 海贼王
    海贼王 中海
    海贼王 罗杰
    海贼王 秘宝
    海贼王 留下
    海贼王 伙伴
    海贼王 ONE
    海贼王 PIECE
    海贼王 海贼
    海贼王 志同道合
    海贼王 扬起
    海贼王 实现
    龙珠 复活
    龙珠 仙人
    龙珠 武道
    龙珠 得到
    龙珠 军团
    龙珠 找寻
    龙珠 魔王
    龙珠 饺子
    龙珠 特典
    龙珠 打败
    龙珠 花梨
    龙珠 缎带
    龙珠 发售日期
    龙珠 龙珠
    龙珠 天津
    龙珠 七龙珠
    龙珠 比克
    龙珠 天神
    龙珠 修练
    龙珠 悟空
    龙珠 封入
    龙珠 次郎
    龙珠 拉夫
    龙珠 封印
    龙珠 许愿
    龙珠 兵卫
    龙珠 一武道
    龙珠 动画
    
    • TestBayes.java
    package com.xinrui.util;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    import org.apache.commons.io.Charsets;
    import org.apache.commons.io.FileUtils;
    import org.apache.commons.lang.StringUtils;
    import org.apache.log4j.Logger;
    
    import com.hankcs.hanlp.HanLP;
    
    /**
     * 贝叶斯计算器主体类
     */
    public class Bayes {
    
        private static Logger logger = Logger.getLogger(Bayes.class);
    
        /**
         * 将原训练元组按类别划分
         * 
         * @param datas
         *            训练元组
         * @return Map<类别,属于该类别的训练元组>
         */
        public static Map<String, ArrayList<ArrayList<String>>> classifyByCategory(ArrayList<ArrayList<String>> datas) {
            if (datas == null) {
                return null;
            }
    
            Map<String, ArrayList<ArrayList<String>>> map = new HashMap<String, ArrayList<ArrayList<String>>>();
            ArrayList<String> singleTrainning = null;
            String classificaion = "";
            for (int i = 0; i < datas.size(); i++) {
                singleTrainning = datas.get(i);
                classificaion = singleTrainning.get(0);
                singleTrainning.remove(0);
                if (map.containsKey(classificaion)) {
                    map.get(classificaion).add(singleTrainning);
                } else {
                    ArrayList<ArrayList<String>> list = new ArrayList<ArrayList<String>>();
                    list.add(singleTrainning);
                    map.put(classificaion, list);
                }
            }
    
            return map;
        }
    
        /**
         * 在训练数据的基础上预测测试元组的类别
         * 
         * @param datas
         *            训练元组
         * @param testData
         *            测试元组
         * @return 测试元组的类别
         */
        public static String predictClassify(ArrayList<ArrayList<String>> datas, ArrayList<String> testData) {
    
            if (datas == null || testData == null) {
                return null;
            }
    
            int maxPIndex = -1;
            Map<String, ArrayList<ArrayList<String>>> map = classifyByCategory(datas);
            Object[] classes = map.keySet().toArray();
            double maxProbability = 0.0;
            for (int i = 0; i < map.size(); i++) {
                double p = 0.0;
                for (int j = 0; j < testData.size(); j++) {
                    p += calProbabilityClassificationInKey(map, classes[i].toString(), testData.get(j));
                }
                if (p > maxProbability) {
                    maxProbability = p;
                    maxPIndex = i;
                }
            }
    
            return maxPIndex == -1 ? "其他" : classes[maxPIndex].toString();
        }
    
        /**
         * 在训练数据的基础上预测测试元组的类别
         * 
         * @param testData
         *            测试元组
         * @return 测试元组的类别
         * @throws Exception
         */
        public String predictClassify(ArrayList<String> testData, String mId) throws Exception {
            return predictClassify(read(mId), testData);
        }
    
        /**
         * 某一特征值在某一分类上的概率分布[ P(key|Classify) ]
         * 
         * @param classify
         *            某一分类特征向量集
         * @param value
         *            某一特征值
         * @return 概率分布
         */
        private static double calProbabilityKeyInClassification(ArrayList<ArrayList<String>> classify, String value) {
            if (classify == null || StringUtils.isEmpty(value)) {
                return 0.0;
            }
            int totleKeyCount = 0;
            int foundKeyCount = 0;
            ArrayList<String> featureVector = null; // 分类中的某一特征向量
            for (int i = 0; i < classify.size(); i++) {
                featureVector = classify.get(i);
                for (int j = 0; j < featureVector.size(); j++) {
                    totleKeyCount++;
                    if (featureVector.get(j).equalsIgnoreCase(value)) {
                        foundKeyCount++;
                    }
                }
            }
            return totleKeyCount == 0 ? 0.0 : 1.0 * foundKeyCount / totleKeyCount;
        }
    
        /**
         * 获得某一分类的概率 [ P(Classify) ]
         * 
         * @param classes
         *            分类集合
         * @param classify
         *            某一特定分类
         * @return 某一分类的概率
         */
        private static double calProbabilityClassification(Map<String, ArrayList<ArrayList<String>>> map, String classify) {
            if (map == null | StringUtils.isEmpty(classify)) {
                return 0;
            }
            Object[] classes = map.keySet().toArray();
            int totleClassifyCount = 0;
            for (int i = 0; i < classes.length; i++) {
                totleClassifyCount += map.get(classes[i].toString()).size();
            }
            return 1.0 * map.get(classify).size() / totleClassifyCount;
        }
    
        /**
         * 获得关键词的总概率
         * 
         * @param map
         *            所有分类的数据集
         * @param key
         *            某一特征值
         * @return 某一特征值在所有分类数据集中的比率
         */
        private static double calProbabilityKey(Map<String, ArrayList<ArrayList<String>>> map, String key) {
            if (map == null || StringUtils.isEmpty(key)) {
                return 0;
            }
            int foundKeyCount = 0;
            int totleKeyCount = 0;
            Object[] classes = map.keySet().toArray();
            for (int i = 0; i < map.size(); i++) {
                ArrayList<ArrayList<String>> classify = map.get(classes[i]);
                ArrayList<String> featureVector = null; // 分类中的某一特征向量
                for (int j = 0; j < classify.size(); j++) {
                    featureVector = classify.get(j);
                    for (int k = 0; k < featureVector.size(); k++) {
                        totleKeyCount++;
                        if (featureVector.get(k).equalsIgnoreCase(key)) {
                            foundKeyCount++;
                        }
                    }
                }
            }
            return totleKeyCount == 0 ? 0.0 : 1.0 * foundKeyCount / totleKeyCount;
        }
    
        /**
         * 计算在出现key的情况下,是分类classify的概率 [ P(Classify | key) ]
         * 
         * @param map
         *            所有分类的数据集
         * @param classify
         *            某一特定分类
         * @param key
         *            某一特定特征
         * @return P(Classify | key)
         */
        private static double calProbabilityClassificationInKey(Map<String, ArrayList<ArrayList<String>>> map, String classify, String key) {
            ArrayList<ArrayList<String>> classifyList = map.get(classify);
            double pkc = calProbabilityKeyInClassification(classifyList, key); // p(key|classify)
            double pc = calProbabilityClassification(map, classify); // p(classify)
            double pk = calProbabilityKey(map, key); // p(key)
            return pk == 0 ? 0 : pkc * pc / pk; // p(classify | key)
        }
    
        /**
         * 读取训练文档中的训练数据 并进行封装
         * 
         * @param filePath
         *            训练文档的路径
         * @return 训练数据集
         * @throws Exception
         */
        public static ArrayList<ArrayList<String>> read(String clzss) throws Exception {
            ArrayList<String> singleTrainning = null;
            ArrayList<ArrayList<String>> trainningSet = new ArrayList<ArrayList<String>>();
            List<String> datas = new ArrayList<String>(FileUtils.readLines(new File(clzss), Charsets.UTF_8));
            if (datas.size() == 0) {
                logger.error("[" + "模型文件加载错误" + "]" + clzss);
                throw new Exception("模型文件加载错误!");
            }
            for (int i = 0; i < datas.size(); i++) {
                String[] characteristicValues = datas.get(i).split(" ");
                singleTrainning = new ArrayList<String>();
                for (int j = 0; j < characteristicValues.length; j++) {
                    if (StringUtils.isNotEmpty(characteristicValues[j])) {
                        singleTrainning.add(characteristicValues[j]);
                    }
                }
                trainningSet.add(singleTrainning);
            }
            return trainningSet;
        }
    
        /**
         * 
         * @param fileName
         *            训练文件
         * @param size
         *            关键词个数
         */
        public static void trainBayes(String fileName, String mId, int size) {
            try {
                Bayes bayes = new Bayes();
                BufferedReader reader = new BufferedReader(new FileReader(fileName));
                String line = null;
                int total = 0;
                int right = 0;
                long start = System.currentTimeMillis();
                while ((line = reader.readLine()) != null) {
                    ArrayList<String> testData = (ArrayList<String>) HanLP.extractKeyword(line, size);
                    String classification = bayes.predictClassify(testData, mId);
                    if (classification.equals(fileName.split("\\.")[0])) {
                        right += 1;
                    }
                    System.out.print("\n分类:" + classification);
                    total++;
                }
                reader.close();
                long end = System.currentTimeMillis();
                System.out.println("正确分类:" + right);
                System.out.println("总行数:" + total);
                System.out.println("正确率:" + MathUtil.div(right, total, 4) * 100 + "%");
                System.out.println("程序运行时间: " + (end - start) / 1000 + "s");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    
    }
    
    
    • TestBayes.java
    package com.xinrui.test;
    
    import java.util.ArrayList;
    
    import com.hankcs.hanlp.HanLP;
    import com.xinrui.util.Bayes;
    
    public class TestBayes {
        public static void main(String[] args) throws Exception {
            // 获取当前工程存放位置
            String path = TestBayes.class.getResource("").getPath();
            String classPath = path.substring(0, path.indexOf("/com/xinrui"));
            // 模型文件存放位置
            String modelName = classPath + "/model/classify_model.txt";
            ArrayList<ArrayList<String>> model = Bayes.read(modelName);
            // 抽取10个关键词组成一个元祖
            ArrayList<String> testData = (ArrayList<String>) HanLP
                    .extractKeyword(
                            "时值“大海贼时代”,为了寻找传说中海贼王罗杰所留下的大秘宝“ONE PIECE”,无数海贼扬起旗帜,互相争斗。有一个梦想成为海盗的少年叫路飞,他因误食“恶魔果实”而成为了橡皮人,在获得超人能力的同时付出了一辈子无法游泳的代价。十年后,路飞为实现与因救他而断臂的香克斯的约定而出海,他在旅途中不断寻找志同道合的伙伴,开始了以成为海贼王为目标的伟大的冒险旅程[9]  ",
                            15);
            // 输出预测结果
            System.out.println(Bayes.predictClassify(model, testData));
        }
    }
    
    
    • 结果


      image

    关注我的技术公众号,每天推送优质文章
    关注我的音乐公众号,工作之余放松自己
    微信扫一扫下方二维码即可关注:


    音乐公众号
    技术公众号

    相关文章

      网友评论

        本文标题:第六章(1.6)机器学习实战——打造属于自己的贝叶斯分类器

        本文链接:https://www.haomeiwen.com/subject/sosldqtx.html