A simple machine learning demo in Java

Author: freelands | Published 2017-01-08 15:11

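    I have recently been busy with a competition that I wanted to tackle with machine learning. Since our team works in Java, I used Weka, a machine learning library for Java. The Weka jar can be downloaded from its official website.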

    1. Prepare the data

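    Create a text file named weather.txt in the project root (the program opens it as "weather.txt", relative to the working directory) and paste the following data into it. The content is in ARFF, the text format Weka reads.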

    @relation weather  
      
    @attribute outlook {sunny, overcast, rainy}  
    @attribute temperature numeric  
    @attribute humidity numeric  
    @attribute windy {TRUE, FALSE}  
    @attribute play {yes, no}  
      
    @data  
    sunny,85,85,FALSE,no  
    sunny,80,90,TRUE,no  
    overcast,83,86,FALSE,yes  
    rainy,70,96,FALSE,yes  
    rainy,68,80,FALSE,yes  
    rainy,65,70,TRUE,no  
    overcast,64,65,TRUE,yes  
    sunny,72,95,FALSE,no  
    sunny,69,70,FALSE,yes  
    rainy,75,80,FALSE,yes  
    sunny,75,70,TRUE,yes  
    overcast,72,90,TRUE,yes  
    overcast,81,75,FALSE,yes  
    rainy,71,91,TRUE,no  
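    In this file, @relation names the dataset, each @attribute line declares one column (a {...} list makes it nominal, numeric makes it a number), and every line after @data is one instance; the last column, play, is later used as the class. As an illustration only, here is a minimal plain-Java sketch that reads this structure. The class ArffSketch is hypothetical, is not how Weka parses the file, and ignores ARFF features such as quoted values, % comments, sparse rows and ? for missing values. In the real program Weka does the parsing with new Instances(reader).

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, minimal ARFF reader for illustration only (not Weka's parser).
// Handles @relation, @attribute and @data lines; ignores quoting, comments,
// sparse rows and missing values.
final class ArffSketch {
    final List<String> attributeNames = new ArrayList<>();
    final List<String> attributeTypes = new ArrayList<>(); // "{...}" or "numeric"
    final List<String[]> rows = new ArrayList<>();

    static ArffSketch parse(String text) {
        ArffSketch arff = new ArffSketch();
        boolean inData = false;
        for (String raw : text.split("\n")) {
            String line = raw.trim(); // also drops trailing spaces and '\r'
            if (line.isEmpty()) continue;
            String lower = line.toLowerCase();
            if (lower.startsWith("@relation")) continue;
            if (lower.startsWith("@attribute")) {
                // "@attribute name type": at most 3 parts, so "{a, b, c}" stays whole
                String[] parts = line.split("\\s+", 3);
                arff.attributeNames.add(parts[1]);
                arff.attributeTypes.add(parts[2]);
            } else if (lower.startsWith("@data")) {
                inData = true;
            } else if (inData) {
                // One instance per line, values separated by commas
                arff.rows.add(line.split("\\s*,\\s*"));
            }
        }
        return arff;
    }

    // How often each value of the last attribute (the class) occurs, in first-seen order
    Map<String, Integer> classCounts() {
        Map<String, Integer> counts = new LinkedHashMap<>();
        int last = attributeNames.size() - 1;
        for (String[] row : rows) {
            counts.merge(row[last], 1, Integer::sum);
        }
        return counts;
    }
}
```

    Run on the full weather.txt content above, it should report 5 attributes, 14 rows and class counts {no=5, yes=9}.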
    

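    These rows are the sample data: part of them serves as the training set and part as the test set. The program in step 2 uses 10-fold cross-validation, so every instance is in the test set exactly once and in the training set for the other nine folds.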
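    The simplest way to get the two sets is a holdout split: shuffle the rows, keep a fraction for training and use the rest for testing. The sketch below does this on row indices in plain Java; the class HoldoutSplit, the 70% ratio and the seed are illustrative assumptions, not something from the original program. With only 14 rows a 70/30 split leaves 4 test rows, so the accuracy from a single split is very noisy.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper: shuffle row indices with a fixed seed, then put the first
// trainFraction of them into the training set and the rest into the test set.
final class HoldoutSplit {
    final List<Integer> train = new ArrayList<>();
    final List<Integer> test = new ArrayList<>();

    static HoldoutSplit of(int numRows, double trainFraction, long seed) {
        if (trainFraction <= 0 || trainFraction >= 1) {
            throw new IllegalArgumentException("trainFraction must be in (0, 1)");
        }
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < numRows; i++) indices.add(i);
        // A fixed seed makes the split reproducible between runs
        Collections.shuffle(indices, new Random(seed));
        int trainSize = (int) Math.round(numRows * trainFraction);
        HoldoutSplit split = new HoldoutSplit();
        split.train.addAll(indices.subList(0, trainSize));
        split.test.addAll(indices.subList(trainSize, numRows));
        return split;
    }
}
```

    HoldoutSplit.of(14, 0.7, 42) yields 10 training and 4 test indices, and the same seed always produces the same split.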

    Training set

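    My understanding: after choosing a suitable model, you train it on part of the data, and the model acquires its initial decision logic. What that logic looks like depends entirely on your data.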
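    To make "training produces decision logic" concrete, here is a deliberately tiny learner, a hypothetical MajorityRule class in the spirit of Weka's OneR but not taken from Weka: for one nominal attribute it records the most frequent class for each attribute value. Rows are assumed to be String arrays with the class in the last column, like the @data lines above.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical one-attribute rule learner: for each value of a nominal
// attribute, remember the most frequent class among the training rows.
final class MajorityRule {
    private final int attributeIndex;
    private final Map<String, String> ruleByValue = new HashMap<>();

    MajorityRule(int attributeIndex) {
        this.attributeIndex = attributeIndex;
    }

    // "Training": count (attribute value, class) pairs, keep the majority class per value
    void train(String[][] rows) {
        Map<String, Map<String, Integer>> counts = new HashMap<>();
        for (String[] row : rows) {
            String value = row[attributeIndex];
            String label = row[row.length - 1];
            counts.computeIfAbsent(value, v -> new HashMap<>()).merge(label, 1, Integer::sum);
        }
        ruleByValue.clear();
        for (Map.Entry<String, Map<String, Integer>> e : counts.entrySet()) {
            String best = null;
            int bestCount = -1;
            for (Map.Entry<String, Integer> c : e.getValue().entrySet()) {
                int count = c.getValue();
                // Ties are broken alphabetically so the learned rule is deterministic
                if (count > bestCount || (count == bestCount && c.getKey().compareTo(best) < 0)) {
                    best = c.getKey();
                    bestCount = count;
                }
            }
            ruleByValue.put(e.getKey(), best);
        }
    }

    // The learned decision logic: look up the class for this row's attribute value
    String predict(String[] row) {
        return ruleByValue.get(row[attributeIndex]); // null if the value was never seen
    }
}
```

    Trained on the outlook and play columns of the 14 rows with new MajorityRule(0), it learns sunny -> no, overcast -> yes, rainy -> yes.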

    Test set

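    Once the model has its decision logic, you run it on data it did not see during training and check how often its predictions are correct, which gives the model's accuracy.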
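    Measuring this boils down to comparing each predicted class with the true class of the same test row. A minimal sketch, assuming predictions and true labels are already available as String arrays; the class TestSetEvaluation is hypothetical and not part of Weka. The accuracy formula is the same 100 * correct / total that calculateAccuracy uses in step 2, and the confusion counts show which kind of mistake the model makes.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical evaluation helpers (not Weka API): compare predictions on
// test rows with the true classes of those rows.
final class TestSetEvaluation {
    // Percentage of correct predictions
    static double accuracyPercent(String[] predicted, String[] actual) {
        if (predicted.length != actual.length || actual.length == 0) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i].equals(actual[i])) correct++;
        }
        return 100.0 * correct / actual.length;
    }

    // Counts keyed "actual -> predicted"; e.g. "no -> yes" counts rows whose
    // true class is no but which the model predicted as yes
    static Map<String, Integer> confusion(String[] predicted, String[] actual) {
        Map<String, Integer> counts = new TreeMap<>();
        for (int i = 0; i < actual.length; i++) {
            counts.merge(actual[i] + " -> " + predicted[i], 1, Integer::sum);
        }
        return counts;
    }
}
```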

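    2. Choose a suitable model

    The program below loads the data, runs 10-fold cross-validation for four Weka classifiers (J48, PART, DecisionTable and DecisionStump) and prints the accuracy of each, so that the most suitable one can be picked.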

    
    import java.io.BufferedReader;  
    import java.io.FileNotFoundException;  
    import java.io.FileReader;  
    import weka.classifiers.Classifier;  
    import weka.classifiers.Evaluation;  
    import weka.classifiers.evaluation.NominalPrediction;  
    import weka.classifiers.rules.DecisionTable;  
    import weka.classifiers.rules.PART;  
    import weka.classifiers.trees.DecisionStump;  
    import weka.classifiers.trees.J48;  
    import weka.core.FastVector;  
    import weka.core.Instances;  
       
    public class WekaTest {  
        // Opens the data file. Rethrowing instead of returning null makes a missing
        // file fail here, rather than later when the null reader is used.
        public static BufferedReader readDataFile(String filename) throws FileNotFoundException {
            try {
                return new BufferedReader(new FileReader(filename));
            } catch (FileNotFoundException ex) {
                System.err.println("File not found: " + filename);
                throw ex;
            }
        }
       
        public static Evaluation classify(Classifier model,  
                Instances trainingSet, Instances testingSet) throws Exception {  
            Evaluation evaluation = new Evaluation(trainingSet);  
       
            model.buildClassifier(trainingSet);  
            evaluation.evaluateModel(model, testingSet);  
       
            return evaluation;  
        }  
       
        public static double calculateAccuracy(FastVector predictions) {  
            double correct = 0;  
       
            for (int i = 0; i < predictions.size(); i++) {  
                NominalPrediction np = (NominalPrediction) predictions.elementAt(i);  
                if (np.predicted() == np.actual()) {  
                    correct++;  
                }  
            }  
       
            return 100 * correct / predictions.size();  
        }  
       
        public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {  
            Instances[][] split = new Instances[2][numberOfFolds];  
       
            for (int i = 0; i < numberOfFolds; i++) {  
                split[0][i] = data.trainCV(numberOfFolds, i);  
                split[1][i] = data.testCV(numberOfFolds, i);  
            }  
       
            return split;  
        }  
       
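        public static void main(String[] args) throws Exception {
            // Load the ARFF data; try-with-resources closes the file afterwards
            Instances data;
            try (BufferedReader datafile = readDataFile("weather.txt")) {
                data = new Instances(datafile);
            }
            // The last attribute ("play") is the class to predict
            data.setClassIndex(data.numAttributes() - 1);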
       
            // 10-fold cross-validation. trainCV/testCV take contiguous folds in
            // file order; the data is not shuffled or stratified first.
            Instances[][] split = crossValidationSplit(data, 10);  
       
            // Separate split into training and testing arrays  
            Instances[] trainingSplits = split[0];  
            Instances[] testingSplits = split[1];  
       
            // Use a set of classifiers  
            Classifier[] models = {   
                    new J48(), // a decision tree  
                    new PART(),   
                    new DecisionTable(),//decision table majority classifier  
                    new DecisionStump() //one-level decision tree  
            };  
       
            // Run for each model  
            for (int j = 0; j < models.length; j++) {  
       
                // Collect every group of predictions for current model in a FastVector  
                FastVector predictions = new FastVector();  
       
                // For each training-testing split pair, train and test the classifier  
                for (int i = 0; i < trainingSplits.length; i++) {  
                    Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);  
       
                    predictions.appendElements(validation.predictions());  
       
                    // Uncomment to print the model built on each training split.
                    //System.out.println(models[j].toString());  
                }  
       
                // Calculate overall accuracy of current classifier on all splits  
                double accuracy = calculateAccuracy(predictions);  
       
                // Print current classifier's name and accuracy in a complicated,  
                // but nice-looking way.  
                System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "  
                        + String.format("%.2f%%", accuracy)  
                        + "\n---------------------------------");  
            }  
       
        }  
    }  
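    Without a Random argument, data.trainCV(10, i) and data.testCV(10, i) cut the 14 instances into contiguous folds in file order. The plain-Java sketch below reproduces the fold boundaries as I understand Weka's implementation (n / k rows per fold, with the first n % k folds getting one extra row); FoldBounds is a hypothetical helper, not Weka code. For 14 rows and 10 folds, folds 0 to 3 test two rows each (rows 0-1, 2-3, 4-5, 6-7) and folds 4 to 9 test one row each (rows 8 to 13). The usual Weka recipe shuffles and stratifies first, e.g. data.randomize(new Random(1)) followed by data.stratify(10) for a nominal class, then calls trainCV(10, i, random); doing so here would change the accuracies reported in step 3.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of contiguous k-fold boundaries, written to mirror
// (as I understand it) how Weka's Instances.testCV/trainCV split the data
// when no Random is supplied.
final class FoldBounds {
    // Returns {start, endExclusive} of test fold `fold` for n rows and k folds
    static int[] testRange(int n, int k, int fold) {
        if (k < 2 || k > n || fold < 0 || fold >= k) {
            throw new IllegalArgumentException("need 2 <= k <= n and 0 <= fold < k");
        }
        int size = n / k;
        int offset;
        if (fold < n % k) {
            size++;          // the first n % k folds get one extra row
            offset = fold;
        } else {
            offset = n % k;
        }
        int start = fold * (n / k) + offset;
        return new int[] {start, start + size};
    }

    // Training rows for that fold: every row outside the test range
    static List<Integer> trainRows(int n, int k, int fold) {
        int[] test = testRange(n, k, fold);
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (i < test[0] || i >= test[1]) rows.add(i);
        }
        return rows;
    }
}
```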
    

    3. Results

    Accuracy of J48: 50.00%
    ---------------------------------
    Accuracy of PART: 50.00%
    ---------------------------------
    Accuracy of DecisionTable: 64.29%
    ---------------------------------
    Accuracy of DecisionStump: 21.43%
    ---------------------------------
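    With only 14 instances, each test fold holds one or two rows, so these percentages move in steps of 1/14 (about 7.14 points) and are quite noisy. A useful reference is a classifier that ignores every attribute and always predicts the majority class of its training folds, which is what Weka's ZeroR does. The sketch below is my own plain-Java version of that baseline (MajorityBaselineCV is hypothetical, not Weka code), assuming contiguous folds in file order, as trainCV/testCV produce when the data is not randomized, and pooling predictions over all folds like calculateAccuracy.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical majority-class baseline (the idea behind Weka's ZeroR), evaluated with
// contiguous, unshuffled k-fold cross-validation. Predictions from all folds are
// pooled into a single accuracy.
final class MajorityBaselineCV {
    static double accuracy(String[] labels, int k) {
        int n = labels.length;
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        int correct = 0;
        int start = 0;
        for (int fold = 0; fold < k; fold++) {
            // Test fold is [start, end); the first n % k folds get one extra row
            int end = start + n / k + (fold < n % k ? 1 : 0);

            // "Training": count the classes of all rows outside the test fold
            Map<String, Integer> counts = new HashMap<>();
            for (int i = 0; i < n; i++) {
                if (i < start || i >= end) counts.merge(labels[i], 1, Integer::sum);
            }

            // Majority class; ties broken alphabetically so the result is deterministic
            String majority = null;
            for (Map.Entry<String, Integer> e : counts.entrySet()) {
                int count = e.getValue();
                if (majority == null || count > counts.get(majority)
                        || (count == counts.get(majority) && e.getKey().compareTo(majority) < 0)) {
                    majority = e.getKey();
                }
            }

            // "Testing": predict the majority class for every row in the fold
            for (int i = start; i < end; i++) {
                if (labels[i].equals(majority)) correct++;
            }
            start = end;
        }
        return 100.0 * correct / n;
    }
}
```

    For the play labels in file order (no, no, yes, yes, yes, no, yes, no, yes, yes, yes, yes, yes, no) and k = 10 it gets 9 of 14 rows right, i.e. 64.29%. Under these folds J48 and PART (50.00%) and DecisionStump (21.43%) do worse than always answering the majority class, which suggests the dataset is too small to tell these models apart reliably.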
    

    Article link: https://www.haomeiwen.com/subject/biiqbttx.html