最近研究了一下时间序列预测的使用。网上找到的大部分资源都是用python实现的。python虽然能满足大部分需求,但有一个缺点:只能使用单台机器的计算资源,数据量大的时候就可能无法胜任。这种情况虽然很少,但仍有可能发生,因此我查了一下spark有没有这方面的资料,没想到还真有。使用spark集群进行计算,速度提升明显。
项目地址:https://github.com/sryza/spark-timeseries
首先非常感谢这位博主,我是在学习了他的代码之后,才更好地理解了spark-timeseries的使用。
博客链接:http://blog.csdn.net/qq_30232405/article/details/70622400
下面是我对代码的改进,主要调整了两点:一是提高时间类型的通用性,使其能适配大部分时间格式;二是让arima模型可以自定义p、d、q参数。
TimeFormatUtils.java
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.regex.Pattern;
public class TimeFormatUtils {
/**
* 获取时间类型格式
*
* @param timeStr
* @return
*/
public static String getDateType(String timeStr) {
HashMap<String, String> dateRegFormat = new HashMap<String, String>();
dateRegFormat.put("^\\d{4}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D*$", "yyyy-MM-dd HH:mm:ss");//2014年3月12日 13时5分34秒,2014-03-12 12:05:34,2014/3/12 12:5:34
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH:mm");//2014-03-12 12:05
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH");//2014-03-12 12
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd");//2014-03-12
dateRegFormat.put("^\\d{4}\\D+\\d{2}$", "yyyy-MM");//2014-03
dateRegFormat.put("^\\d{4}$", "yyyy");//2014
dateRegFormat.put("^\\d{14}$", "yyyyMMddHHmmss");//20140312120534
dateRegFormat.put("^\\d{12}$", "yyyyMMddHHmm");//201403121205
dateRegFormat.put("^\\d{10}$", "yyyyMMddHH");//2014031212
dateRegFormat.put("^\\d{8}$", "yyyyMMdd");//20140312
dateRegFormat.put("^\\d{6}$", "yyyyMM");//201403
try {
for (String key : dateRegFormat.keySet()) {
if (Pattern.compile(key).matcher(timeStr).matches()) {
String formater = "";
if (timeStr.contains("/"))
return dateRegFormat.get(key).replaceAll("-", "/");
else
return dateRegFormat.get(key);
}
}
} catch (Exception e) {
System.err.println("-----------------日期格式无效:" + timeStr);
e.printStackTrace();
}
return null;
}
public static String fromatData(String time, SimpleDateFormat format) {
try {
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
return formatter.format(format.parse(time));
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
TimeSeriesTrain.scala
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.{ZoneId, ZonedDateTime}
import com.cloudera.sparkts._
import com.sendi.TimeSeries.Util.TimeFormatUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
 * Builds a time-series forecasting job (ARIMA or Holt-Winters) on top of spark-timeseries:
 * reads a Hive table, trains the chosen model and writes evaluation and forecast tables.
 */
object TimeSeriesTrain {

  /**
   * Runs the whole pipeline.
   *
   * Positional arguments:
   *  - `args(0)`: Hive table to read (e.g. `db.table`)
   *  - `args(1)`: comma-separated column names, time column first, data column second
   *  - `args(2)`, `args(3)`: start and end time, written in the same format as the time column
   *  - `args(4)`: time span type: `year`, `month`, `day`, `hour` or `minute`
   *  - `args(5)`: time span (step) between two observations
   *  - `args(6)`: number of values to forecast
   *  - `args(7)`: output table name
   *  - `args(8)`: model name: `arima` or `holtwinters`
   *  - `args(9)`: arima: comma-separated p,d,q (fewer than three values => automatic fit);
   *               holtwinters: seasonal period (e.g. 12 or 4)
   *  - `args(10)`: holtwinters only: model type (additive or multiplicative)
   *
   * @throws UnsupportedOperationException if the model name is neither `arima` nor `holtwinters`
   * @throws IllegalArgumentException if the start time format is not recognised
   */
  def timeSeries(args: Array[String]): Unit = {
    args.foreach(println)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

    // Validate the model name before any expensive work. Previously an unknown name fell into
    // the Holt-Winters argument parsing and failed with an unrelated NumberFormatException.
    val modelName = args(8)
    if (modelName != "arima" && modelName != "holtwinters") {
      throw new UnsupportedOperationException("Currently only supports 'arima' and 'holtwinters'")
    }

    /**
     * 1. Parameters.
     */
    // Hive table to read.
    val databaseTableName = args(0)
    // Column names: the time column first, the data column second.
    val hiveColumnName = args(1).split(",").toList
    // Start and end of the training interval.
    val startTime = args(2)
    val endTime = args(3)
    // Date pattern guessed from the start time; it is also used for the end time and the data.
    val dateType = TimeFormatUtils.getDateType(startTime)
    require(dateType != null, s"Unsupported time format: $startTime")
    val sdf = new SimpleDateFormat(dateType)
    // Step between two observations.
    val timeSpanType = args(4)
    val timeSpan = args(5).toInt
    // Number of future values to forecast.
    val predictedN = args(6).toInt
    // Output table.
    val outputTableName = args(7)
    // ARIMA: p,d,q (fewer than three entries => automatic fit).
    val listPDQ: List[String] = if (modelName == "arima") args(9).split(",").toList else List("")
    // Holt-Winters: seasonal period and model type (additive / multiplicative).
    val period = if (modelName == "arima") 0 else args(9).toInt
    val holtWintersModelType = if (modelName == "arima") "" else args(10)

    /**
     * 2. Spark environment (Hive support is needed to read and write the tables).
     */
    val sparkSession = SparkSession.builder
      .master("local[4]").appName("SparkTest")
      .enableHiveSupport()
      .getOrCreate()

    /**
     * 3. Read the source data as a (time, key, data) DataFrame with typed columns.
     */
    val timeDataKeyDf = readHiveData(sparkSession, databaseTableName, hiveColumnName)
    val zonedDateDataDf = timeChangeToDate(sparkSession, timeDataKeyDf, hiveColumnName, startTime, sdf)

    /**
     * 4. Uniform time index from start to end with the requested step.
     */
    val dtIndex = getTimeSpan(startTime, endTime, timeSpanType, timeSpan, sdf)

    /**
     * 5. Training data, with missing values filled by linear interpolation.
     */
    val trainTsrdd = TimeSeriesRDD.timeSeriesRDDFromObservations(dtIndex, zonedDateDataDf,
      hiveColumnName(0), hiveColumnName(0) + "Key", hiveColumnName(1))
    trainTsrdd.cache()
    val filledTrainTsrdd = trainTsrdd.fill("linear")

    /**
     * 6. Train the selected model, save its evaluation and keep its forecasts.
     */
    val timeSeriesKeyModel = new TimeSeriesKeyModel(predictedN, outputTableName)
    val forecastValue: RDD[(String, Vector)] = modelName match {
      case "arima" =>
        val (forecast, coefficients) = timeSeriesKeyModel.arimaModelTrainKey(filledTrainTsrdd, listPDQ)
        timeSeriesKeyModel.arimaModelKeyEvaluationSave(sparkSession, coefficients, forecast)
        forecast
      case "holtwinters" =>
        val (forecast, sse) = timeSeriesKeyModel.holtWintersModelTrainKey(filledTrainTsrdd, period, holtWintersModelType)
        timeSeriesKeyModel.holtWintersModelKeyEvaluationSave(sparkSession, sse, forecast)
        forecast
      case _ =>
        throw new UnsupportedOperationException("Currently only supports 'arima' and 'holtwinters'")
    }

    /**
     * 7. Append forecasts to the actual values, label them with dates and save (Date, Data).
     */
    timeSeriesKeyModel.actualForcastDateKeySaveInHive(sparkSession, filledTrainTsrdd, forecastValue, predictedN, startTime,
      endTime, timeSpanType, timeSpan, sdf, hiveColumnName)
  }

  /**
   * Reads the time and data columns from Hive and appends a constant key column.
   *
   * Rows whose time column equals the column name (a header row loaded as data) and rows whose
   * data column is an empty string are dropped.
   * NOTE(review): the SQL is concatenated from the arguments; only trusted table/column names
   * should be passed.
   *
   * @param sparkSession      Hive-enabled session
   * @param databaseTableName table to read
   * @param hiveColumnName    time column name followed by the data column name
   * @return DataFrame with columns (time, data, `<time>Key`), the key being the constant "0"
   */
  def readHiveData(sparkSession: SparkSession, databaseTableName: String, hiveColumnName: List[String]): DataFrame = {
    val timeColumn = hiveColumnName(0)
    val dataColumn = hiveColumnName(1)
    val keyColumn = timeColumn + "Key"
    val hiveDataDf = sparkSession.sql("select * from " + databaseTableName + " where " + timeColumn + " !='" + timeColumn + "'")
      .select(hiveColumnName.head, hiveColumnName.tail: _*)
      .filter(dataColumn + " != ''")
    // A literal key keeps every row in one series; the former `data * 0` trick produced a "NaN"
    // key (and therefore a separate series) for NaN or infinite data values.
    hiveDataDf
      .withColumn(keyColumn, org.apache.spark.sql.functions.lit("0"))
      .select(timeColumn, dataColumn, keyColumn)
  }

  /**
   * Converts the (time, data, key) DataFrame into typed columns: the time text is parsed with
   * `sdf` into a Timestamp, the key becomes a String and the data a Double.
   *
   * @param sparkSession   active session
   * @param timeDataKeyDf  DataFrame produced by [[readHiveData]]
   * @param hiveColumnName time column name followed by the data column name
   * @param startTime      not used by this method (kept for interface compatibility)
   * @param sdf            pattern the time column is written in
   * @return DataFrame with schema (time: Timestamp, `<time>Key`: String, data: Double)
   */
  def timeChangeToDate(sparkSession: SparkSession, timeDataKeyDf: DataFrame, hiveColumnName: List[String], startTime: String,
                       sdf: SimpleDateFormat): DataFrame = {
    val rowRDD: RDD[Row] = timeDataKeyDf.rdd.map {
      case Row(time, data, key) =>
        val timestamp = new Timestamp(sdf.parse(time.toString).getTime)
        Row(timestamp, key.toString, data.toString.toDouble)
    }
    val schema = StructType(Seq(
      StructField(hiveColumnName(0), TimestampType, true),
      StructField(hiveColumnName(0) + "Key", StringType, true),
      StructField(hiveColumnName(1), DoubleType, true)))
    sparkSession.createDataFrame(rowRDD, schema)
  }

  /**
   * Builds the uniform time index from `startTime` to `endTime` in the system default zone.
   *
   * @param startTime    first point of the index, written in `sdf`'s pattern
   * @param endTime      last point of the interval, written in `sdf`'s pattern
   * @param timeSpanType `year`, `month`, `day`, `hour` or `minute`
   * @param timeSpan     step size in units of `timeSpanType`
   * @param sdf          pattern of `startTime` / `endTime`
   * @throws IllegalArgumentException for an unsupported span type or an unparseable time
   */
  def getTimeSpan(startTime: String, endTime: String, timeSpanType: String, timeSpan: Int, sdf: SimpleDateFormat): UniformDateTimeIndex = {
    val zone = ZoneId.systemDefault()
    val frequency = timeSpanType match {
      case "year" => new YearFrequency(timeSpan)
      case "month" => new MonthFrequency(timeSpan)
      case "day" => new DayFrequency(timeSpan)
      case "hour" => new HourFrequency(timeSpan)
      case "minute" => new MinuteFrequency(timeSpan)
      case other => throw new IllegalArgumentException("Unsupported time span type: " + other)
    }
    DateTimeIndex.uniformFromInterval(
      toZonedDateTime(startTime, sdf, zone),
      toZonedDateTime(endTime, sdf, zone),
      frequency)
  }

  /**
   * Parses `time` with `sdf` into a ZonedDateTime in `zone`, keeping the seconds: the data
   * timestamps and the output date labels keep their seconds, so the index must too.
   */
  private def toZonedDateTime(time: String, sdf: SimpleDateFormat, zone: ZoneId): ZonedDateTime = {
    // fromatData returns "yyyy-MM-dd HH:mm:ss", or null when `time` does not match `sdf`.
    val normalized = TimeFormatUtils.fromatData(time, sdf)
    require(normalized != null, "Cannot parse time '" + time + "' with pattern '" + sdf.toPattern + "'")
    ZonedDateTime.of(
      normalized.substring(0, 4).toInt, normalized.substring(5, 7).toInt, normalized.substring(8, 10).toInt,
      normalized.substring(11, 13).toInt, normalized.substring(14, 16).toInt, normalized.substring(17, 19).toInt,
      0, zone)
  }
}
TimeSeriesKeyModel.scala
import java.text.SimpleDateFormat
import java.util.Calendar
import com.cloudera.sparkts.TimeSeriesRDD
import com.cloudera.sparkts.models.{ARIMA}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import scala.collection.mutable.ArrayBuffer
/**
 * Time-series models (ARIMA and Holt-Winters) over a TimeSeriesRDD whose rows carry a key.
 * Created by llq on 2017/5/3.
 *
 * @param predictedN      number of future values to forecast
 * @param outputTableName Hive table for the (date, data) output; evaluation tables are written
 *                        to the same name with `_arima_evaluation` / `_holtwinters_evaluation`
 */
class TimeSeriesKeyModel(private val predictedN: Int = 1,
                         private val outputTableName: String = "time_series.timeseries_output") {

  /**
   * Fits one ARIMA model per key and forecasts the next `predictedN` values.
   *
   * @param trainTsrdd training series, one vector per key
   * @param listPDQ    p, d, q as strings; with fewer than three entries `ARIMA.autoFit` is used
   * @return (forecasts per key, (coefficients, (p, d, q), logLikelihoodCSS, approxAIC) per key)
   */
  def arimaModelTrainKey(trainTsrdd: TimeSeriesRDD[String], listPDQ: List[String]): (RDD[(String, Vector)], RDD[(String, (String, (String, String, String), String, String))]) = {
    // Local copy so the Spark closures below do not capture `this`.
    val predictedN = this.predictedN
    // Fitted models feed the coefficients, the printout and the forecasts; cache them so every
    // action does not refit the models.
    val arimaAndVectorRdd = trainTsrdd.map { case (key, denseVector) =>
      val arimaModel =
        if (listPDQ.size >= 3) ARIMA.fitModel(listPDQ(0).toInt, listPDQ(1).toInt, listPDQ(2).toInt, denseVector)
        else ARIMA.autoFit(denseVector)
      (key, arimaModel, denseVector)
    }.cache()

    // Fitted p, d, q with their coefficients, CSS log-likelihood and approximate AIC.
    val coefficients = arimaAndVectorRdd.map { case (key, arimaModel, denseVector) =>
      (key, (arimaModel.coefficients.mkString(","),
        (arimaModel.p.toString, arimaModel.d.toString, arimaModel.q.toString),
        arimaModel.logLikelihoodCSS(denseVector).toString,
        arimaModel.approxAIC(denseVector).toString))
    }
    coefficients.collect().foreach { case (key, (coefficientText, (p, d, q), _, _)) =>
      println(key + " coefficients:" + coefficientText + "=>" + "(p=" + p + ",d=" + d + ",q=" + q + ")")
    }

    // The vector returned by forecast() is longer than predictedN; only its last predictedN
    // entries are the future values kept here.
    val forecastValue = arimaAndVectorRdd.map { case (key, arimaModel, denseVector) =>
      (key, Vectors.dense(arimaModel.forecast(denseVector, predictedN).toArray.takeRight(predictedN)))
    }
    println("Arima forecast of next " + predictedN + " observations:")
    // Collect first so the output appears on the driver, not in executor logs.
    forecastValue.collect().foreach(println)
    (forecastValue, coefficients)
  }

  /**
   * Saves the ARIMA evaluation table: coefficients, (p, d, q), logLikelihoodCSS, AIC and the
   * mean, sample variance, standard deviation, max, min, range and count of each key's forecast.
   *
   * @param sparkSession  Hive-enabled session
   * @param coefficients  second element returned by [[arimaModelTrainKey]]
   * @param forecastValue forecasts per key
   */
  def arimaModelKeyEvaluationSave(sparkSession: SparkSession, coefficients: RDD[(String, (String, (String, String, String), String, String))], forecastValue: RDD[(String, Vector)]): Unit = {
    val evaluationRddRow = coefficients
      .join(forecastValue.mapValues(forecast => TimeSeriesKeyModel.forecastStatistics(forecast)))
      .map { case (_, ((coefficientText, pdq, logLikelihoodCSS, aic), (mean, variance, standardDeviation, max, min, range, count))) =>
        Row(coefficientText, pdq.toString, logLikelihoodCSS, aic, mean, variance, standardDeviation, max, min, range, count)
      }
    val schemaString = "coefficients,pdq,logLikelihoodCSS,aic,mean,variance,standardDeviation,max,min,range,count"
    val schema = StructType(schemaString.split(",").map(fieldName => StructField(fieldName, StringType, true)))
    val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
    println("Evaluation in Arima:")
    evaluationDf.show()
    evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_arima_evaluation")
  }

  /**
   * Fits one Holt-Winters model per key and forecasts the next `predictedN` values.
   *
   * @param trainTsrdd           training series, one vector per key
   * @param period               seasonal period
   * @param holtWintersModelType model type passed to `HoltWinters.fitModel` (additive / multiplicative)
   * @return (forecasts per key, SSE of the fit per key)
   */
  def holtWintersModelTrainKey(trainTsrdd: TimeSeriesRDD[String], period: Int, holtWintersModelType: String): (RDD[(String, Vector)], RDD[(String, Double)]) = {
    // The file's import block only brings in ARIMA.
    import com.cloudera.sparkts.models.HoltWinters
    // Local copy so the Spark closures below do not capture `this`.
    val predictedN = this.predictedN
    // Cached: the models feed the forecasts, the printout and the SSE.
    val holtWintersAndVectorRdd = trainTsrdd.map { case (key, denseVector) =>
      (key, HoltWinters.fitModel(denseVector, period, holtWintersModelType), denseVector)
    }.cache()

    val forecast = holtWintersAndVectorRdd.map { case (key, holtWintersModel, denseVector) =>
      // NOTE(review): sparkts' forecast(ts, dest) fills `dest` in place and returns a vector over
      // the same array, so each key gets its own buffer instead of one vector shared by every
      // row of a partition.
      val destination = Vectors.dense(Array.tabulate(predictedN)(_.toDouble))
      (key, holtWintersModel.forecast(denseVector, destination))
    }
    println("HoltWinters forecast of next " + predictedN + " observations:")
    forecast.collect().foreach(println)

    // Sum of squared errors of each fitted model.
    val sse = holtWintersAndVectorRdd.map { case (key, holtWintersModel, denseVector) =>
      (key, holtWintersModel.sse(denseVector))
    }
    (forecast, sse)
  }

  /**
   * Saves the Holt-Winters evaluation table: SSE and the mean, sample variance, standard
   * deviation, max, min, range and count of each key's forecast.
   *
   * @param sparkSession  Hive-enabled session
   * @param sse           second element returned by [[holtWintersModelTrainKey]]
   * @param forecastValue forecasts per key
   */
  def holtWintersModelKeyEvaluationSave(sparkSession: SparkSession, sse: RDD[(String, Double)], forecastValue: RDD[(String, Vector)]): Unit = {
    val evaluationRddRow = sse
      .join(forecastValue.mapValues(forecast => TimeSeriesKeyModel.forecastStatistics(forecast)))
      .map { case (_, (sseValue, (mean, variance, standardDeviation, max, min, range, count))) =>
        Row(sseValue.toString, mean, variance, standardDeviation, max, min, range, count)
      }
    val schemaString = "sse,mean,variance,standardDeviation,max,min,range,count"
    val schema = StructType(schemaString.split(",").map(fieldName => StructField(fieldName, StringType, true)))
    val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
    println("Evaluation in HoltWinters:")
    evaluationDf.show()
    evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_holtwinters_evaluation")
  }

  /**
   * Writes (date, data) rows to `outputTableName`, overwriting the table.
   *
   * @param sparkSession   Hive-enabled session
   * @param dateDataRdd    rows of two strings: date and value
   * @param hiveColumnName the first two entries name the two output columns
   */
  private def keySaveInHive(sparkSession: SparkSession, dateDataRdd: RDD[Row], hiveColumnName: List[String]): Unit = {
    val schema = StructType(Seq(hiveColumnName(0), hiveColumnName(1))
      .map(fieldName => StructField(fieldName, StringType, true)))
    val dateDataDf = sparkSession.createDataFrame(dateDataRdd, schema)
    dateDataDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName)
  }

  /**
   * Appends each key's forecasts to its actual values, labels every value with its date and
   * saves the (Date, Data) rows to Hive.
   *
   * NOTE(review): every key overwrites the same table, so with several keys only the last one
   * collected survives (unchanged from the original behaviour).
   *
   * @param sparkSession   Hive-enabled session
   * @param trainTsrdd     actual (filled) series per key
   * @param forecastValue  forecasts per key
   * @param predictedN     number of forecast values
   * @param startTime      first training date
   * @param endTime        last training date
   * @param timeSpanType   `year`, `month`, `day`, `hour` or `minute`
   * @param timeSpan       step size
   * @param sdf            date pattern used to parse and print the dates
   * @param hiveColumnName output column names
   * @throws IllegalArgumentException if a key's value count differs from the number of dates
   */
  def actualForcastDateKeySaveInHive(sparkSession: SparkSession, trainTsrdd: TimeSeriesRDD[String], forecastValue: RDD[(String, Vector)],
                                     predictedN: Int, startTime: String, endTime: String, timeSpanType: String, timeSpan: Int,
                                     sdf: SimpleDateFormat, hiveColumnName: List[String]): Unit = {
    // (actual values, forecasts) per key; this is one small row per key, so it is merged on the
    // driver instead of zipping separately parallelized RDDs.
    val actualAndForecast = trainTsrdd.map { case (key, actualValue) => (key, actualValue.toArray) }
      .join(forecastValue.mapValues(_.toArray))
      .collect()
    val dateArray = productStartDatePredictDate(predictedN, timeSpanType, timeSpan, sdf, startTime, endTime)
    actualAndForecast.foreach { case (key, (actual, forecast)) =>
      val values = (actual ++ forecast).map(_.toString)
      require(values.length == dateArray.length,
        "Key " + key + ": " + values.length + " values but " + dateArray.length + " dates")
      val rows = dateArray.zip(values).map { case (date, data) => Row(date, data) }
      if (rows.nonEmpty) {
        keySaveInHive(sparkSession, sparkSession.sparkContext.parallelize(rows), hiveColumnName)
      }
    }
  }

  /**
   * Generates the date labels from the first training date to the last forecast date: every
   * date `start + i * timeSpan` that is not after `endTime`, followed by `predictedN` more.
   *
   * Each date is computed from the start rather than by repeatedly adding to a running
   * calendar, so month/year steps starting on e.g. the 31st do not drift, and counting by
   * calendar steps avoids the DST errors of dividing millisecond differences.
   *
   * @param predictedN   number of dates to add after the training interval
   * @param timeSpanType `year`, `month`, `day`, `hour` or `minute`
   * @param timeSpan     step size, must be positive
   * @param format       pattern used to parse the bounds and to print the dates
   * @param startTime    first date
   * @param endTime      last training date (inclusive)
   * @return the formatted dates
   * @throws IllegalArgumentException for an unsupported span type or a non-positive step
   */
  def productStartDatePredictDate(predictedN: Int, timeSpanType: String, timeSpan: Int,
                                  format: SimpleDateFormat, startTime: String, endTime: String): ArrayBuffer[String] = {
    require(timeSpan > 0, "timeSpan must be positive, got " + timeSpan)
    val field = timeSpanType match {
      case "year" => Calendar.YEAR
      case "month" => Calendar.MONTH
      case "day" => Calendar.DATE
      case "hour" => Calendar.HOUR_OF_DAY
      case "minute" => Calendar.MINUTE
      case other => throw new IllegalArgumentException("Unsupported time span type: " + other)
    }
    val start = format.parse(startTime)
    val end = format.parse(endTime)
    def dateAt(step: Int) = {
      val calendar = Calendar.getInstance()
      calendar.setTime(start)
      calendar.add(field, step * timeSpan)
      calendar.getTime
    }
    // Number of index points inside [start, end].
    val actualCount = Iterator.from(0).takeWhile(step => !dateAt(step).after(end)).size
    val dateArrayBuffer = new ArrayBuffer[String]()
    for (step <- 0 until actualCount + predictedN) {
      dateArrayBuffer += format.format(dateAt(step))
    }
    dateArrayBuffer
  }
}

/**
 * Helpers for [[TimeSeriesKeyModel]]. They live in the companion object so Spark closures that
 * call them do not capture (and have to serialize) a model instance.
 */
object TimeSeriesKeyModel {

  /**
   * Descriptive statistics of one forecast vector, as strings:
   * (mean, sample variance, standard deviation, max, min, range, count).
   * The variance uses the n - 1 denominator and is 0 for a single value; an empty vector
   * yields NaN statistics and a count of 0.
   */
  private def forecastStatistics(forecast: Vector): (String, String, String, String, String, String, String) = {
    val values = forecast.toArray
    val count = values.length
    if (count == 0) {
      val nan = Double.NaN.toString
      (nan, nan, nan, nan, nan, nan, "0")
    } else {
      val mean = values.sum / count
      val variance =
        if (count > 1) values.map(v => (v - mean) * (v - mean)).sum / (count - 1)
        else 0.0
      val max = values.max
      val min = values.min
      (mean.toString, variance.toString, math.sqrt(variance).toString,
        max.toString, min.toString, (max - min).toString, count.toLong.toString)
    }
  }
}
网友评论