I've recently been busy with a competition and wanted to tackle it with machine learning. Since we are working in Java, we used Weka, a machine learning library for Java; the Weka jar can be downloaded from the official website.
1. Preparing the data
Create a txt file in the project (the code below reads it as weather.txt; the content is in Weka's ARFF format) and put the following data into it:
@relation weather
@attribute outlook {sunny, overcast, rainy}
@attribute temperature numeric
@attribute humidity numeric
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}
@data
sunny,85,85,FALSE,no
sunny,80,90,TRUE,no
overcast,83,86,FALSE,yes
rainy,70,96,FALSE,yes
rainy,68,80,FALSE,yes
rainy,65,70,TRUE,no
overcast,64,65,TRUE,yes
sunny,72,95,FALSE,no
sunny,69,70,FALSE,yes
rainy,75,80,FALSE,yes
sunny,75,70,TRUE,yes
overcast,72,90,TRUE,yes
overcast,81,75,FALSE,yes
rainy,71,91,TRUE,no
This is the sample data; part of it is used as the training set and part as the test set.
Training set
My understanding: after you pick a suitable model, you train it on part of the data, and the model ends up with an initial decision logic; naturally this all depends on your data.
Test set
Once the model has its decision logic, you feed it some held-out data to test it and see how accurate the model is.
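To make the train/test idea concrete, here is a minimal sketch (assuming the weather.txt above sits in the project root and a Weka jar is on the classpath; the class name HoldoutSplitExample and the 70/30 ratio are just illustrative choices) that loads the file and carves it into a simple hold-out training/test split:

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;
import weka.core.Instances;

public class HoldoutSplitExample {
    public static void main(String[] args) throws Exception {
        // Load the ARFF-formatted weather.txt prepared above
        Instances data = new Instances(new BufferedReader(new FileReader("weather.txt")));
        data.setClassIndex(data.numAttributes() - 1); // the last attribute ("play") is the class

        // Shuffle so the split is not biased by the file order
        data.randomize(new Random(42));

        // Simple hold-out split: roughly 70% for training, the rest for testing
        int trainSize = (int) Math.round(data.numInstances() * 0.7);
        int testSize = data.numInstances() - trainSize;
        Instances train = new Instances(data, 0, trainSize);
        Instances test = new Instances(data, trainSize, testSize);

        System.out.println("Training instances: " + train.numInstances());
        System.out.println("Testing instances:  " + test.numInstances());
    }
}

The code in the next section does not use a single hold-out split like this; it uses cross-validation, which repeats the same train/test idea over several different splits of the same data.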
2. Choosing a suitable model
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.rules.DecisionTable;
import weka.classifiers.rules.PART;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.trees.J48;
import weka.core.FastVector;
import weka.core.Instances;
public class WekaTest {

    // Open the data file; returns null if the file cannot be found
    public static BufferedReader readDataFile(String filename) {
        BufferedReader inputReader = null;
        try {
            inputReader = new BufferedReader(new FileReader(filename));
        } catch (FileNotFoundException ex) {
            System.err.println("File not found: " + filename);
        }
        return inputReader;
    }

    // Train the given model on the training set and evaluate it on the testing set
    public static Evaluation classify(Classifier model,
            Instances trainingSet, Instances testingSet) throws Exception {
        Evaluation evaluation = new Evaluation(trainingSet);
        model.buildClassifier(trainingSet);
        evaluation.evaluateModel(model, testingSet);
        return evaluation;
    }

    // Percentage of predictions whose predicted class matches the actual class
    public static double calculateAccuracy(FastVector predictions) {
        double correct = 0;
        for (int i = 0; i < predictions.size(); i++) {
            NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
            if (np.predicted() == np.actual()) {
                correct++;
            }
        }
        return 100 * correct / predictions.size();
    }

    // Split the data into numberOfFolds training/testing pairs for cross-validation
    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
        Instances[][] split = new Instances[2][numberOfFolds];
        for (int i = 0; i < numberOfFolds; i++) {
            split[0][i] = data.trainCV(numberOfFolds, i);
            split[1][i] = data.testCV(numberOfFolds, i);
        }
        return split;
    }

    public static void main(String[] args) throws Exception {
        BufferedReader datafile = readDataFile("weather.txt");
        Instances data = new Instances(datafile);
        data.setClassIndex(data.numAttributes() - 1);

        // Do 10-fold cross-validation
        Instances[][] split = crossValidationSplit(data, 10);

        // Separate split into training and testing arrays
        Instances[] trainingSplits = split[0];
        Instances[] testingSplits = split[1];

        // Use a set of classifiers
        Classifier[] models = {
                new J48(),           // a decision tree
                new PART(),          // a rule learner
                new DecisionTable(), // decision table majority classifier
                new DecisionStump()  // one-level decision tree
        };

        // Run for each model
        for (int j = 0; j < models.length; j++) {

            // Collect every group of predictions for the current model in a FastVector
            FastVector predictions = new FastVector();

            // For each training-testing split pair, train and test the classifier
            for (int i = 0; i < trainingSplits.length; i++) {
                Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);
                predictions.appendElements(validation.predictions());

                // Uncomment to see the summary for each training-testing pair.
                //System.out.println(models[j].toString());
            }

            // Calculate the overall accuracy of the current classifier over all splits
            double accuracy = calculateAccuracy(predictions);

            // Print the current classifier's name and accuracy
            System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "
                    + String.format("%.2f%%", accuracy)
                    + "\n---------------------------------");
        }
    }
}
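One caveat: this code is written against the older Weka 3.6 API, where Evaluation.predictions() returns a FastVector. In Weka 3.7/3.8, FastVector is deprecated and predictions() returns an ArrayList<Prediction>, so with a newer jar the accuracy helper would look roughly like the sketch below (and in main the FastVector would become an ArrayList<Prediction>, with appendElements replaced by addAll); treat this as an adaptation for the newer API rather than part of the original program:

import java.util.ArrayList;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.Prediction;

// Replacement for calculateAccuracy when using Weka 3.7+,
// where Evaluation.predictions() returns ArrayList<Prediction> instead of FastVector
public class AccuracyUtil {
    public static double calculateAccuracy(ArrayList<Prediction> predictions) {
        double correct = 0;
        for (Prediction p : predictions) {
            NominalPrediction np = (NominalPrediction) p;
            if (np.predicted() == np.actual()) {
                correct++;
            }
        }
        return 100 * correct / predictions.size();
    }
}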
3. Results
Accuracy of J48: 50.00%
---------------------------------
Accuracy of PART: 50.00%
---------------------------------
Accuracy of DecisionTable: 64.29%
---------------------------------
Accuracy of DecisionStump: 21.43%
---------------------------------
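Once a model has been chosen (on this run DecisionTable scored highest), a minimal sketch of how you might train it on the full dataset and classify a row looks like this; the class name PredictExample is illustrative, and for simplicity it classifies the first row of the training data rather than genuinely new data:

import java.io.BufferedReader;
import java.io.FileReader;
import weka.classifiers.rules.DecisionTable;
import weka.core.Instances;

public class PredictExample {
    public static void main(String[] args) throws Exception {
        // Load the full weather.txt and train the chosen model on all of it
        Instances data = new Instances(new BufferedReader(new FileReader("weather.txt")));
        data.setClassIndex(data.numAttributes() - 1);

        DecisionTable model = new DecisionTable();
        model.buildClassifier(data);

        // Classify one row and map the numeric class index back to its label ("yes"/"no")
        double classIndex = model.classifyInstance(data.instance(0));
        String label = data.classAttribute().value((int) classIndex);
        System.out.println("Predicted play = " + label);
    }
}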