I ran a simple naive Bayes classification example in a notebook. Engineering-grade code works on the same principles; only a few details need to change (a sketch of one such change follows the listing). The main code:
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IDF, StopWordsRemover, Tokenizer}
object bayes {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("bayes")
      .getOrCreate()
    import spark.implicits._

    // Simple sample data: label 0 = fruit (水果), label 1 = grain (粮食)
    val sentenceDataFrame = spark.createDataFrame(Seq(
      (0, "水果", "苹果 橘子 香蕉"),
      (1, "粮食", "大米 小米 土豆")
    )).toDF("label", "category", "text")
    // Split each text field into whitespace-separated tokens
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val wordData = tokenizer.transform(sentenceDataFrame)

    // Load custom stop words from a file
    val stopwordFile: String = "/applications/stopWords"
    val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
      Array.empty[String]
    } else {
      val stopWordText = spark.read.text(stopwordFile).as[String].collect()
      stopWordText.flatMap(_.stripMargin.split("\\s+"))
    }
    // StopWordsRemover's built-in list is English, so the Chinese stop words
    // must come from the custom file loaded above
    val stopWordsRemover = new StopWordsRemover()
      .setInputCol("words")
      .setOutputCol("token")
    stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
    var wordDataWithOutStopWord = stopWordsRemover.transform(wordData)

    // Term frequencies via the hashing trick
    val hashingTF = new HashingTF()
      .setInputCol("token")
      .setOutputCol("tf")
    val tf = hashingTF.transform(wordDataWithOutStopWord)
    tf.cache()
    tf.show(false)

    // Fit an IDF model on the data above; in practice it should be trained
    // on a large corpus
    val idf = new IDF().setInputCol("tf").setOutputCol("features").fit(tf)
    val tfidf = idf.transform(tf)
    tfidf.show(false)
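    // For reference: Spark's IDF computes idf(t) = log((m + 1) / (df(t) + 1)),
    // where m is the number of documents and df(t) is the document frequency
    // of hashed term t; the "features" column stores tf(t) * idf(t).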
    // Build a naive Bayes model and train it on the TF-IDF features
    val naiveBayesModel = new NaiveBayes()
      .setSmoothing(1.0)
      .fit(tfidf)

    // Test data to be predicted
    val test = spark.createDataFrame(List(
      (0, "大米")
    )).toDF("id", "text")
    // Run the test data through the same feature pipeline
    val tokenFeature = tokenizer.transform(test)
    wordDataWithOutStopWord = stopWordsRemover.transform(tokenFeature)
    val testRescaledData = hashingTF.transform(wordDataWithOutStopWord)
    val tfidf1 = idf.transform(testRescaledData)

    // transform adds rawPrediction, probability, and prediction columns
    val predictions = naiveBayesModel.transform(tfidf1)
    predictions.printSchema()

    // Predicted result: label 1.0, the grain (粮食) category
    val predict = predictions.first().getAs[Double]("prediction")
    println(s"prediction: $predict")
    spark.stop()
  }
}
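As for the engineering details mentioned above: one typical change is chaining the stages into a spark.ml Pipeline and persisting the fitted model, so the scoring job reloads it instead of retraining. A minimal sketch, reusing the tokenizer, stopWordsRemover, and hashingTF defined above (it would sit inside main in place of the manual transform steps; the save path is just an example):

import org.apache.spark.ml.{Pipeline, PipelineModel}

// Chain the same stages; fresh IDF and NaiveBayes estimators are created
// because the idf / naiveBayesModel values above are already fitted models.
val pipeline = new Pipeline().setStages(Array(
  tokenizer, stopWordsRemover, hashingTF,
  new IDF().setInputCol("tf").setOutputCol("features"),
  new NaiveBayes().setSmoothing(1.0)
))
val pipelineModel = pipeline.fit(sentenceDataFrame)

// Persist the fitted pipeline (example path); feature hashing and IDF
// weights then stay consistent between training and scoring.
pipelineModel.write.overwrite().save("/applications/models/bayesPipeline")

// A separate scoring job reloads it and transforms raw text directly.
val reloaded = PipelineModel.load("/applications/models/bayesPipeline")
reloaded.transform(test).select("text", "prediction").show(false)

Saving the whole PipelineModel rather than the individual stages is what keeps the hashed feature space and IDF weights identical between the training and scoring jobs.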