
Using Sparkts._, the Spark-based time series forecasting package

Author: e辉 | Published 2017-11-12 22:16

    I recently looked into time series forecasting. Most of the resources I found online are implemented in Python, which covers the majority of use cases, but a plain Python implementation has one drawback: it can only use the computing resources of a single machine, so it may not cope once the data volume gets large. That situation is rare, but it does happen, so I checked whether Spark had anything in this area, and it turns out it does; computing on a Spark cluster gives a clear speed improvement.

    Project repository: https://github.com/sryza/spark-timeseries

    First of all, many thanks to the blogger below; it was only by studying his code that I came to better understand how to use spark-timeseries.

    Blog post: http://blog.csdn.net/qq_30232405/article/details/70622400

    Below are my improvements to that code. The main changes are making the time-format handling generic, so that most common time formats are supported, and allowing the p, d, q parameters of the ARIMA model to be specified by the user.

    TimeFormatUtils.java

    import java.text.ParseException;
    import java.text.SimpleDateFormat;
    import java.util.HashMap;
    import java.util.regex.Pattern;
    
    public class TimeFormatUtils {
    
    
        /**
         * Infer the date/time format pattern that matches the given time string
         *
         * @param timeStr
         * @return
         */
        public static String getDateType(String timeStr) {
            HashMap<String, String> dateRegFormat = new HashMap<String, String>();
            dateRegFormat.put("^\\d{4}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D*$", "yyyy-MM-dd HH:mm:ss");//2014年3月12日 13时5分34秒,2014-03-12 12:05:34,2014/3/12 12:5:34
            dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH:mm");//2014-03-12 12:05
            dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH");//2014-03-12 12
            dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd");//2014-03-12
            dateRegFormat.put("^\\d{4}\\D+\\d{2}$", "yyyy-MM");//2014-03
            dateRegFormat.put("^\\d{4}$", "yyyy");//2014
            dateRegFormat.put("^\\d{14}$", "yyyyMMddHHmmss");//20140312120534
            dateRegFormat.put("^\\d{12}$", "yyyyMMddHHmm");//201403121205
            dateRegFormat.put("^\\d{10}$", "yyyyMMddHH");//2014031212
            dateRegFormat.put("^\\d{8}$", "yyyyMMdd");//20140312
            dateRegFormat.put("^\\d{6}$", "yyyyMM");//201403
    
            try {
                for (String key : dateRegFormat.keySet()) {
                    if (Pattern.compile(key).matcher(timeStr).matches()) {
                        if (timeStr.contains("/"))
                            return dateRegFormat.get(key).replaceAll("-", "/");
                        else
                            return dateRegFormat.get(key);
    
                    }
                }
            } catch (Exception e) {
                System.err.println("-----------------Invalid date format: " + timeStr);
                e.printStackTrace();
            }
            return null;
        }
    
        public static String fromatData(String time, SimpleDateFormat format) {
            try {
                SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
                return formatter.format(format.parse(time));
            } catch (ParseException e) {
                e.printStackTrace();
            }
            return null;
        }
    }
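
    To sanity-check the format detection, here is a minimal sketch (assuming the compiled TimeFormatUtils class is on the classpath of a Scala REPL or spark-shell; the sample inputs are made up, and the expected results follow the regex table above):

    println(TimeFormatUtils.getDateType("2014-03-12"))        // yyyy-MM-dd
    println(TimeFormatUtils.getDateType("2014/3/12 12:5:34")) // yyyy/MM/dd HH:mm:ss
    println(TimeFormatUtils.getDateType("201403121205"))      // yyyyMMddHHmm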
    

    TimeSeriesTrain.scala

    
    import java.sql.Timestamp
    import java.text.SimpleDateFormat
    import java.time.{ZoneId, ZonedDateTime}
    
    import com.cloudera.sparkts._
    import com.sendi.TimeSeries.Util.TimeFormatUtils
    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    
    /**
      * Builds the time-series model
      */
    object TimeSeriesTrain {
    
      /**
        * Main entry point
        */
      def timeSeries(args: Array[String]) {
        args.foreach(println)
    
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
        Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    
        /**
          * 1. Initialize the Spark environment
          */
        val sparkSession = SparkSession.builder
          .master("local[4]").appName("SparkTest")
          .enableHiveSupport() //enable Hive support
          .getOrCreate()
    
        /**
          * 2. Initialize the parameters
          */
        //database.table name in Hive
        val databaseTableName = args(0)
        //input column names; they must be "time,data"
        val hiveColumnName = List(args(1).toString.split(","): _*)
        //start and end time
        val startTime = args(2)
        val endTime = args(3)
        //infer the time format
        val sdf = new SimpleDateFormat(TimeFormatUtils.getDateType(startTime))
        //time span (step size of the series)
        val timeSpanType = args(4)
        val timeSpan = args(5).toInt
    
        //number of future values to predict
        val predictedN = args(6).toInt
        //name of the output table
        val outputTableName = args(7)
    
        var listPDQ: List[String] = List("")
        var period = 0
        var holtWintersModelType = ""
    
        //choose the model (holtwinters or arima)
        val modelName = args(8)
    
        //assign different parameters depending on the model
        if (modelName.equals("arima")) {
          listPDQ = List(args(9).toString.split(","): _*)
        } else {
          //seasonal period (12 or 4)
          period = args(9).toInt
          //HoltWinters model type: additive or multiplicative
          holtWintersModelType = args(10)
        }
    
        /**
          * 3. Read the data source and convert it into an RDD of the form {time key data}
          */
        val timeDataKeyDf = readHiveData(sparkSession, databaseTableName, hiveColumnName)
        val zonedDateDataDf = timeChangeToDate(sparkSession, timeDataKeyDf, hiveColumnName, startTime, sdf)
    
        /**
          * 4. Create the DateTimeIndex spanning the data: start date + end date + increment
          * The date format must match the format of the time column in the source table
          */
        val dtIndex = getTimeSpan(startTime, endTime, timeSpanType, timeSpan, sdf)
    
        /**
          * 5. Create the training data
          */
        val trainTsrdd = TimeSeriesRDD.timeSeriesRDDFromObservations(dtIndex, zonedDateDataDf,
          hiveColumnName(0), hiveColumnName(0) + "Key", hiveColumnName(1))
        trainTsrdd.cache()
        //fill missing values
        val filledTrainTsrdd = trainTsrdd.fill("linear")
    
        /**
          * 6. Build the model object and train it on the training data
          */
        val timeSeriesKeyModel = new TimeSeriesKeyModel(predictedN, outputTableName)
        var forecastValue: RDD[(String, Vector)] = sparkSession.sparkContext.parallelize(Seq(("", Vectors.dense(1))))
        //choose the model
        modelName match {
          case "arima" => {
            //build and train the ARIMA model
            val (forecast, coefficients) = timeSeriesKeyModel.arimaModelTrainKey(filledTrainTsrdd, listPDQ)
            //save the ARIMA model evaluation metrics
            forecastValue = forecast
            timeSeriesKeyModel.arimaModelKeyEvaluationSave(sparkSession, coefficients, forecast)
          }
          case "holtwinters" => {
            //build and train the HoltWinters (seasonal) model
            val (forecast, sse) = timeSeriesKeyModel.holtWintersModelTrainKey(filledTrainTsrdd, period, holtWintersModelType)
            //save the HoltWinters model evaluation metrics
            forecastValue = forecast
            timeSeriesKeyModel.holtWintersModelKeyEvaluationSave(sparkSession, sse, forecast)
          }
          case _ => throw new UnsupportedOperationException("Currently only supports 'arima' and 'holtwinters'")
        }
    
        /**
          * 7. Merge actual and predicted values, attach the dates to form a (Date, Data) dataframe, and save it
          */
        timeSeriesKeyModel.actualForcastDateKeySaveInHive(sparkSession, filledTrainTsrdd, forecastValue, predictedN, startTime,
          endTime, timeSpanType, timeSpan, sdf, hiveColumnName)
      }
    
      /**
        * Read data from Hive, process it, and return columns: time data key
        *
        * @param sparkSession
        * @param databaseTableName
        * @param hiveColumnName
        */
      def readHiveData(sparkSession: SparkSession, databaseTableName: String, hiveColumnName: List[String]): DataFrame = {
        //read the data from Hive; the where clause removes rows whose time column equals the column name itself (header rows)
        var hiveDataDf = sparkSession.sql("select * from " + databaseTableName + " where " + hiveColumnName(0) + " !='" + hiveColumnName(0) + "'")
          .select(hiveColumnName.head, hiveColumnName.tail: _*)
    
        //drop rows with empty data values
        hiveDataDf = hiveDataDf.filter(hiveColumnName(1) + " != ''")
    
        //add a new column to hiveDataDf named hiveColumnName(0)+"Key" whose value is the constant 0 (one shared key for all rows)
        //timeDataKeyDf columns: time, data, timeKey
        val timeDataKeyDf = hiveDataDf.withColumn(hiveColumnName(0) + "Key", hiveDataDf(hiveColumnName(1)) * 0.toString)
          .select(hiveColumnName(0), hiveColumnName(1), hiveColumnName(0) + "Key")
        timeDataKeyDf
      }
    
    
      /**
        * Convert the "time" column into a fixed timestamp format, e.g. ZonedDateTime such as 2007-12-03T10:15:30+01:00 Europe/Paris
        *
        * @param sparkSession
        * @param timeDataKeyDf
        * @param hiveColumnName
        * @param startTime
        * @param sdf
        * @return
        */
      def timeChangeToDate(sparkSession: SparkSession, timeDataKeyDf: DataFrame, hiveColumnName: List[String], startTime: String,
                           sdf: SimpleDateFormat): DataFrame = {
        val rowRDD: RDD[Row] = timeDataKeyDf.rdd.map { row =>
          row match {
            case Row(time, data, key) => {
              val date = sdf.parse(time.toString)
              val timestamp = new Timestamp(date.getTime)
              Row(timestamp, key.toString, data.toString.toDouble)
            }
          }
        }
    
        //build the schema and convert the RDD into a dataframe
        var field = Seq(
          StructField(hiveColumnName(0), TimestampType, true),
          StructField(hiveColumnName(0) + "Key", StringType, true),
          StructField(hiveColumnName(1), DoubleType, true))
        val schema = StructType(field)
        val zonedDateDataDf = sparkSession.createDataFrame(rowRDD, schema)
        return zonedDateDataDf
      }
    
      /**
        * Build a uniform DateTimeIndex from the time interval and time span
        *
        * @param timeSpanType
        * @param timeSpan
        * @param sdf
        * @param startTime
        * @param endTime
        */
      def getTimeSpan(startTime: String, endTime: String, timeSpanType: String, timeSpan: Int, sdf: SimpleDateFormat): UniformDateTimeIndex = {
        val start = TimeFormatUtils.fromatData(startTime, sdf)
        val end = TimeFormatUtils.fromatData(endTime, sdf)
    
        val zone = ZoneId.systemDefault()
        val frequency = timeSpanType match {
          case "year" => new YearFrequency(timeSpan);
          case "month" => new MonthFrequency(timeSpan);
          case "day" => new DayFrequency(timeSpan);
          case "hour" => new HourFrequency(timeSpan);
          case "minute" => new MinuteFrequency(timeSpan);
        }
    
        val dtIndex: UniformDateTimeIndex = DateTimeIndex.uniformFromInterval(
          ZonedDateTime.of(start.substring(0, 4).toInt, start.substring(5, 7).toInt, start.substring(8, 10).toInt,
            start.substring(11, 13).toInt, start.substring(14, 16).toInt, 0, 0, zone),
          ZonedDateTime.of(end.substring(0, 4).toInt, end.substring(5, 7).toInt, end.substring(8, 10).toInt,
            end.substring(11, 13).toInt, end.substring(14, 16).toInt, 0, 0, zone),
          frequency)
        return dtIndex
      }
    }
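
    For reference, the argument order expected by timeSeries above is: Hive database.table, column names ("time,data"), start time, end time, time span type, time span, number of values to predict, output table, model name, and then either "p,d,q" (arima) or the seasonal period plus the model type (holtwinters). A hypothetical invocation could look like this (the table names are made up for illustration):

    //ARIMA: monthly data, predict the next 6 values with p=1, d=0, q=1
    TimeSeriesTrain.timeSeries(Array(
      "time_series.sales_monthly",  //args(0): Hive database.table
      "time,data",                  //args(1): column names
      "201401", "201612",           //args(2), args(3): start and end time
      "month", "1",                 //args(4), args(5): time span type and step
      "6",                          //args(6): predict the next 6 values
      "time_series.sales_output",   //args(7): output table
      "arima",                      //args(8): model name
      "1,0,1"))                     //args(9): p,d,q

    //HoltWinters: seasonal period 12, additive model
    TimeSeriesTrain.timeSeries(Array(
      "time_series.sales_monthly", "time,data", "201401", "201612",
      "month", "1", "6", "time_series.sales_output",
      "holtwinters", "12", "additive"))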
    

    TimeSeriesKeyModel.scala

    import java.text.SimpleDateFormat
    import java.util.Calendar
    
    import com.cloudera.sparkts.TimeSeriesRDD
    import com.cloudera.sparkts.models.{ARIMA, HoltWinters}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.stat.Statistics
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SaveMode, SparkSession}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    
    import scala.collection.mutable.ArrayBuffer
    
    /**
      * Time series models (the data carries an extra key column)
      * Created by llq on 2017/5/3.
      */
    class TimeSeriesKeyModel {
      //number of future values to predict
      private var predictedN = 1
      //name of the output table
      private var outputTableName = "time_series.timeseries_output"
    
      def this(predictedN: Int, outputTableName: String) {
        this()
        this.predictedN = predictedN
        this.outputTableName = outputTableName
      }
    
      /**
        * Train an ARIMA model per key (the data carries an extra key column)
        *
        * @param trainTsrdd
        * @return
        */
      def arimaModelTrainKey(trainTsrdd: TimeSeriesRDD[String], listPDQ: List[String]): (RDD[(String, Vector)], RDD[(String, (String, (String, String, String), String, String))]) = {
        /** parameter setup **/
        val predictedN = this.predictedN
    
        /** build the ARIMA model **/
        //build and train the ARIMA model; the resulting RDD has the form (key, ArimaModel, Vector)
        val arimaAndVectorRdd = trainTsrdd.map { line =>
          line match {
            case (key, denseVector) => {
              if (listPDQ.size >= 3) {
                (key, ARIMA.fitModel(listPDQ(0).toInt, listPDQ(1).toInt, listPDQ(2).toInt, denseVector), denseVector)
              } else {
                (key, ARIMA.autoFit(denseVector), denseVector)
              }
            }
          }
        }
    
        /** output: the fitted p, d, q values, the coefficients, the log likelihood (CSS) and the AIC **/
        val coefficients = arimaAndVectorRdd.map { line =>
          line match {
            case (key, arimaModel, denseVector) => {
              (key, (arimaModel.coefficients.mkString(","),
                (arimaModel.p.toString,
                  arimaModel.d.toString,
                  arimaModel.q.toString),
                arimaModel.logLikelihoodCSS(denseVector).toString,
                arimaModel.approxAIC(denseVector).toString))
            }
          }
        }
    
        coefficients.collect().map {
          _ match {
            case (key, (coefficients, (p, d, q), logLikelihood, aic)) =>
              println(key + " coefficients:" + coefficients + "=>" + "(p=" + p + ",d=" + d + ",q=" + q + ")")
          }
        }
    
        /** forecast the next N values **/
        val forecast = arimaAndVectorRdd.map { row =>
          row match {
            case (key, arimaModel, denseVector) => {
              (key, arimaModel.forecast(denseVector, predictedN))
            }
          }
        }
    
        //extract the predicted values
        val forecastValue = forecast.map {
          _ match {
            case (key, value) => {
              val partArray = value.toArray.mkString(",").split(",")
              var forecastArrayBuffer = new ArrayBuffer[Double]()
              var i = partArray.length - predictedN
              while (i < partArray.length) {
                forecastArrayBuffer += partArray(i).toDouble
                i = i + 1
              }
              (key, Vectors.dense(forecastArrayBuffer.toArray))
            }
          }
        }
    
        println("Arima forecast of next " + predictedN + " observations:")
        forecastValue.foreach(println)
        return (forecastValue, coefficients)
      }
    
    
      /**
        * Save the ARIMA model evaluation metrics:
        * coefficients, (p, d, q), logLikelihoodCSS, aic, mean, variance, standard_deviation, max, min, range, count
        *
        * @param sparkSession
        * @param coefficients
        * @param forecastValue
        */
      def arimaModelKeyEvaluationSave(sparkSession: SparkSession, coefficients: RDD[(String, (String, (String, String, String), String, String))], forecastValue: RDD[(String, Vector)]): Unit = {
        /** transpose the forecast vectors **/
        val forecastRdd = forecastValue.map {
          _ match {
            case (key, forecast) => forecast.toArray
          }
        }
        // Split the matrix into one number per line.
        val byColumnAndRow = forecastRdd.zipWithIndex.flatMap {
          case (row, rowIndex) => row.zipWithIndex.map {
            case (number, columnIndex) => columnIndex -> (rowIndex, number)
          }
        }
        // Build up the transposed matrix. Group and sort by column index first.
        val byColumn = byColumnAndRow.groupByKey.sortByKey().values
        // Then sort by row index.
        val transposed = byColumn.map {
          indexedRow => indexedRow.toSeq.sortBy(_._1).map(_._2)
        }
        val summary = Statistics.colStats(transposed.map(value => Vectors.dense(value(0))))
    
        /** compute statistics of the predicted values (mean, variance, standard deviation, max, min, range, count) and join them with the model evaluation data **/
        //model evaluation parameters + statistics of the predicted values
        val evaluation = coefficients.join(forecastValue.map {
          _ match {
            case (key, forecast) => {
              (key, (summary.mean.toArray(0).toString,
                summary.variance.toArray(0).toString,
                math.sqrt(summary.variance.toArray(0)).toString,
                summary.max.toArray(0).toString,
                summary.min.toArray(0).toString,
                (summary.max.toArray(0) - summary.min.toArray(0)).toString,
                summary.count.toString))
            }
          }
        })
    
        val evaluationRddRow = evaluation.map {
          _ match {
            case (key, ((coefficients, pdq, logLikelihoodCSS, aic), (mean, variance, standardDeviation, max, min, range, count))) => {
              Row(coefficients, pdq.toString, logLikelihoodCSS, aic, mean, variance, standardDeviation, max, min, range, count)
            }
          }
        }
    
        //build the evaluation dataframe
        val schemaString = "coefficients,pdq,logLikelihoodCSS,aic,mean,variance,standardDeviation,max,min,range,count"
        val schema = StructType(schemaString.split(",").map(fileName => StructField(fileName, StringType, true)))
        val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
    
        println("Evaluation in Arima:")
        evaluationDf.show()
    
        /**
          * Save this data into Hive
          */
        evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_arima_evaluation")
      }
    
    
      /**
        * Train a HoltWinters model per key (the data carries an extra key column)
        *
        * @param trainTsrdd
        * @param period
        * @param holtWintersModelType
        * @return
        */
      def holtWintersModelTrainKey(trainTsrdd: TimeSeriesRDD[String], period: Int, holtWintersModelType: String): (RDD[(String, Vector)], RDD[(String, Double)]) = {
        /** parameter setup **/
        //number of future values to predict
        val predictedN = this.predictedN
    
        /** build the HoltWinters model **/
        //build and train the HoltWinters model; the resulting RDD has the form (key, HoltWintersModel, Vector)
        val holtWintersAndVectorRdd = trainTsrdd.map { line =>
          line match {
            case (key, denseVector) =>
              (key, HoltWinters.fitModel(denseVector, period, holtWintersModelType), denseVector)
          }
        }
    
        /** forecast the next N values **/
        //build a destination vector of length N that is passed to HoltWinters' forecast method
        val predictedArrayBuffer = new ArrayBuffer[Double]()
        var i = 0
        while (i < predictedN) {
          predictedArrayBuffer += i
          i = i + 1
        }
        val predictedVectors = Vectors.dense(predictedArrayBuffer.toArray)
    
        //forecast
        val forecast = holtWintersAndVectorRdd.map { row =>
          row match {
            case (key, holtWintersModel, denseVector) => {
              (key, holtWintersModel.forecast(denseVector, predictedVectors))
            }
          }
        }
        println("HoltWinters forecast of next " + predictedN + " observations:")
        forecast.foreach(println)
    
        /** HoltWinters evaluation metric: SSE (sum of squared errors) **/
        val sse = holtWintersAndVectorRdd.map { row =>
          row match {
            case (key, holtWintersModel, denseVector) => {
              (key, holtWintersModel.sse(denseVector))
            }
          }
        }
        return (forecast, sse)
      }
    
      /**
        * Save the HoltWinters model evaluation metrics:
        * sse, mean, variance, standard_deviation, max, min, range, count
        *
        * @param sparkSession
        * @param sse
        * @param forecastValue
        */
      def holtWintersModelKeyEvaluationSave(sparkSession: SparkSession, sse: RDD[(String, Double)], forecastValue: RDD[(String, Vector)]): Unit = {
        /** transpose the forecast vectors **/
        val forecastRdd = forecastValue.map {
          _ match {
            case (key, forecast) => forecast.toArray
          }
        }
        // Split the matrix into one number per line.
        val byColumnAndRow = forecastRdd.zipWithIndex.flatMap {
          case (row, rowIndex) => row.zipWithIndex.map {
            case (number, columnIndex) => columnIndex -> (rowIndex, number)
          }
        }
        // Build up the transposed matrix. Group and sort by column index first.
        val byColumn = byColumnAndRow.groupByKey.sortByKey().values
        // Then sort by row index.
        val transposed = byColumn.map {
          indexedRow => indexedRow.toSeq.sortBy(_._1).map(_._2)
        }
        val summary = Statistics.colStats(transposed.map(value => Vectors.dense(value(0))))
    
        /** compute statistics of the predicted values (mean, variance, standard deviation, max, min, range, count) and join them with the model evaluation data **/
        //model evaluation parameters + statistics of the predicted values
        val evaluation = sse.join(forecastValue.map {
          _ match {
            case (key, forecast) => {
              (key, (summary.mean.toArray(0).toString,
                summary.variance.toArray(0).toString,
                math.sqrt(summary.variance.toArray(0)).toString,
                summary.max.toArray(0).toString,
                summary.min.toArray(0).toString,
                (summary.max.toArray(0) - summary.min.toArray(0)).toString,
                summary.count.toString))
            }
          }
        })
    
        val evaluationRddRow = evaluation.map {
          _ match {
            case (key, (sse, (mean, variance, standardDeviation, max, min, range, count))) => {
              Row(sse.toString, mean, variance, standardDeviation, max, min, range, count)
            }
          }
        }
        //build the evaluation dataframe
        val schemaString = "sse,mean,variance,standardDeviation,max,min,range,count"
        val schema = StructType(schemaString.split(",").map(fileName => StructField(fileName, StringType, true)))
        val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
    
        println("Evaluation in HoltWinters:")
        evaluationDf.show()
    
        /**
          * Save into Hive
          */
        evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_holtwinters_evaluation")
      }
    
      /**
        * Save the results into Hive
        *
        * @param sparkSession
        * @param dateDataRdd
        * @param hiveColumnName
        */
      private def keySaveInHive(sparkSession: SparkSession, dateDataRdd: RDD[Row], hiveColumnName: List[String]): Unit = {
        //convert dateDataRdd into a dataframe
        val schemaString = hiveColumnName(0) + " " + hiveColumnName(1)
        val schema = StructType(schemaString.split(" ")
          .map(fieldName => StructField(fieldName, StringType, true)))
        val dateDataDf = sparkSession.createDataFrame(dateDataRdd, schema)
    
        //save dateDataDf into Hive
        dateDataDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName)
      }
    
    
      /**
        * Merge actual and predicted values and attach the dates, forming a (Date, Data) dataframe
        *
        * @param sparkSession
        * @param trainTsrdd
        * @param forecastValue
        * @param predictedN
        * @param startTime
        * @param endTime
        * @param timeSpanType
        * @param timeSpan
        * @param sdf
        * @param hiveColumnName
        */
      def actualForcastDateKeySaveInHive(sparkSession: SparkSession, trainTsrdd: TimeSeriesRDD[String], forecastValue: RDD[(String, Vector)],
                                         predictedN: Int, startTime: String, endTime: String, timeSpanType: String, timeSpan: Int,
                                         sdf: SimpleDateFormat, hiveColumnName: List[String]): Unit = {
        //append the predicted values after the actual values
        val actualAndForcastRdd = trainTsrdd.map {
          _ match {
            case (key, actualValue) => (key, actualValue.toArray.mkString(","))
          }
        }.join(forecastValue.map {
          _ match {
            case (key, forecastValue) => (key, forecastValue.toArray.mkString(","))
          }
        })
    
        //generate the dates from the start of the data to the end of the prediction and turn them into an RDD
        val dateArray = productStartDatePredictDate(predictedN, timeSpanType, timeSpan, sdf, startTime, endTime)
    
        val dateRdd = sparkSession.sparkContext.parallelize(dateArray.toArray.mkString(",").split(",").map(date => (date)))
    
        //combine dates and data values per key, forming an RDD[Row]
        val actualAndForcastArray = actualAndForcastRdd.collect()
        for (i <- 0 until actualAndForcastArray.length) {
          val dateDataRdd = actualAndForcastArray(i) match {
            case (key, value) => {
              val actualAndForcast = sparkSession.sparkContext.parallelize(value.toString().split(",")
                .map(data => {
                  data.replaceAll("\\(", "").replaceAll("\\)", "")
                }))
              dateRdd.zip(actualAndForcast).map {
                _ match {
                  case (date, data) => Row(date, data)
                }
              }
    
            }
          }
          //save the results
          if (dateDataRdd.collect()(0).toString() != "[1]") {
            keySaveInHive(sparkSession, dateDataRdd, hiveColumnName)
          }
        }
      }
    
      /**
        * Generate the sequence of dates from the start of the training data to the end of the prediction
        *
        * @param predictedN
        * @param timeSpanType
        * @param timeSpan
        * @param format
        * @param startTime
        * @param endTime
        * @return
        */
      def productStartDatePredictDate(predictedN: Int, timeSpanType: String, timeSpan: Int,
                                      format: SimpleDateFormat, startTime: String, endTime: String): ArrayBuffer[String] = {
        //dates from the start time through the end of the predicted range
        val cal1 = Calendar.getInstance()
        cal1.setTime(format.parse(startTime))
        val cal2 = Calendar.getInstance()
        cal2.setTime(format.parse(endTime))
    
        /**
          * Compute the number of time steps from start to end (plus the predicted steps)
          */
        var field = 1
        var diff: Long = 0
        timeSpanType match {
          case "year" => {
            field = Calendar.YEAR
            diff = (cal2.getTime.getYear() - cal1.getTime.getYear()) / timeSpan + predictedN;
          }
          case "month" => {
            field = Calendar.MONTH
            diff = ((cal2.getTime.getYear() - cal1.getTime.getYear()) * 12 + (cal2.getTime.getMonth() - cal1.getTime.getMonth())) / timeSpan + predictedN
          }
          case "day" => {
            field = Calendar.DATE
            diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60 * 60 * 24) / timeSpan + predictedN
          }
          case "hour" => {
            field = Calendar.HOUR
            diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60 * 60) / timeSpan + predictedN
          }
          case "minute" => {
            field = Calendar.MINUTE
            diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60) / timeSpan + predictedN;
          }
        }
    
        var iDiff = 0L;
        var dateArrayBuffer = new ArrayBuffer[String]()
        while (iDiff <= diff) {
          //record the date
          dateArrayBuffer += format.format(cal1.getTime)
          cal1.add(field, timeSpan)
          iDiff = iDiff + 1;
        }
        dateArrayBuffer
      }
    }
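
    To compile and run the code above, the spark-ts artifact needs to be on the classpath. A sketch of the sbt dependencies (the version numbers are assumptions; check https://github.com/sryza/spark-timeseries and match your Spark installation):

    //build.sbt (sketch)
    libraryDependencies ++= Seq(
      "com.cloudera.sparkts" % "sparkts" % "0.4.1",
      "org.apache.spark" %% "spark-mllib" % "2.1.0" % "provided",
      "org.apache.spark" %% "spark-hive" % "2.1.0" % "provided"
    )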
    
