Scala Spark UDF: Computing User Stay Time at Base Stations (Multi-Column Input)

Author: FredricZhu | Published 2020-09-13 11:17

Data

StopInfo.log (columns: phone number, timestamp in yyyyMMddHHmmss format, base station name, and an isIn flag where 1 marks the start of a stay at the station and 0 marks the end):

17316288888,20160327082400,JZ1,1
17316288887,20160327082500,JZ1,1
17316288888,20160327180000,JZ1,0
17316288887,20160327180005,JZ1,0

Location.log (columns: base station name, longitude, latitude):

JZ1,116.3,40.1
JZ2,115.2,30.1
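
From the sample data, the expected result can be checked by hand: phone 17316288888 enters JZ1 at 2016-03-27 08:24:00 (isIn=1) and leaves at 18:00:00 (isIn=0), a stay of 9 hours 36 minutes, i.e. 9.6 hours; phone 17316288887 stays from 08:25:00 to 18:00:05, roughly 9.585 hours. A quick standalone sanity check of that arithmetic (a sketch, not part of the article's code):

import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.time.Duration

// Expected stay time for phone 17316288888 at JZ1, from the sample data above.
val fmt   = DateTimeFormatter.ofPattern("yyyyMMddHHmmss")
val enter = LocalDateTime.parse("20160327082400", fmt)
val leave = LocalDateTime.parse("20160327180000", fmt)
println(Duration.between(enter, leave).toMinutes / 60.0) // 9.6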

Code

package com.sensetime.userstoptime

import java.sql.Timestamp
import java.util.TimeZone

import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.sql.functions.{col, collect_list, struct, udf}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

case class StopInfo(phoneNumber: String, time: Timestamp, jizhanName: String, isIn: Int)
case class LocInfo(jizhanName: String, longitude: Double, latitude: Double)


// Compute how long each user stayed at each base station
object UserStopTime {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").config("spark.sql.warehouse.dir","file:///")
      .getOrCreate
    val sc = spark.sparkContext
    sc.setLogLevel("ERROR")
    val stopInfoRDD = sc.textFile("src/main/resources/userstoptime/stopinfo.log")
    val locRDD = sc.textFile("src/main/resources/userstoptime/location.log")

    val timeFormat = FastDateFormat.getInstance(
      "yyyyMMddHHmmss",
      TimeZone.getTimeZone("UTC")
    )

    val stopRDDFunc = (words: Iterator[String]) => {
      val lb = ListBuffer[StopInfo]()
      for (word <- words) {
        val wArr = word.split(",")
        val timeStr = wArr(1)
        val time = new Timestamp(timeFormat.parse(timeStr).getTime)

        val isIn = wArr(3).toInt
        lb += StopInfo(wArr(0), time, wArr(2), isIn)
      }
      lb.toIterator
    }

    // RDD of parsed stop records
    val stopInfoSet = stopInfoRDD.mapPartitions(stopRDDFunc)

    val locRDDFunc = (locs: Iterator[String]) => {
      val lb = ListBuffer[LocInfo]()
      for(loc <- locs) {
        val wArr = loc.split(",")
        lb += LocInfo(wArr(0), wArr(1).toDouble, wArr(2).toDouble)
      }
      lb.toIterator
    }

    // RDD of parsed base-station location records
    val locInfoSet = locRDD.mapPartitions(locRDDFunc)


    import spark.implicits._
    val stopDF = stopInfoSet.toDF
    val locDF = locInfoSet.toDF

    val groupedDF = stopDF.groupBy("phoneNumber", "jizhanName")
      .agg(collect_list(struct("time", "isIn")).alias("jzInfo"))
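    // groupedDF now has one row per (phoneNumber, jizhanName); the jzInfo column
    // is an array<struct<time: timestamp, isIn: int>> holding all of that user's
    // records at that station (for the sample data, one isIn = 1 "enter" entry
    // and one isIn = 0 "leave" entry per user at JZ1).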

    // UDF body: pair each "enter" record (isIn = 1) with the following "leave"
    // record (isIn = 0) and sum the intervals, in hours.
    def getStayTime(jzInfo: mutable.WrappedArray[Row]): Float = {
      val stayInfoList = jzInfo.map { row =>
        val isIn = row.getInt(row.fieldIndex("isIn"))
        val time = row.getTimestamp(row.fieldIndex("time"))
        isIn -> time
      }.toList.sortBy(_._2.getTime) // collect_list does not guarantee ordering, so sort by time

      var startTime = new Timestamp(System.currentTimeMillis())
      var totalStay = 0.0f
      for ((_in, _time) <- stayInfoList) {
        if (_in == 1) {
          startTime = _time
        }
        if (_in == 0) {
          totalStay += (_time.getTime - startTime.getTime).toFloat / (1000 * 3600)
        }
      }
      totalStay
    }
    val stayTimeUDF = udf(getStayTime _)

    val stayTimeDF = groupedDF.withColumn("stayTime", stayTimeUDF(col("jzInfo")))
      .drop("jzInfo")

    val resDF = stayTimeDF.join(locDF, Seq("jizhanName"), "inner").sort(col("stayTime").desc)
    resDF.show(10, false)
  }
}
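
As a side note, the same stay time can also be computed without a UDF by pairing each enter record with the following leave record using a window function. The sketch below only illustrates that idea and is not part of the original program; it reuses the stopDF DataFrame from the code above, introduces the names w and stayTimeDF2 for illustration, and assumes that enter (isIn = 1) and leave (isIn = 0) records strictly alternate within each (phoneNumber, jizhanName) group.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lead, sum}

// One window per user and base station, ordered by event time.
val w = Window.partitionBy("phoneNumber", "jizhanName").orderBy("time")

val stayTimeDF2 = stopDF
  .withColumn("nextTime", lead(col("time"), 1).over(w))  // time of the following record
  .withColumn("nextIsIn", lead(col("isIn"), 1).over(w))  // isIn flag of the following record
  .where(col("isIn") === 1 && col("nextIsIn") === 0)     // keep enter rows directly followed by a leave row
  .withColumn("stayHours",
    (col("nextTime").cast("long") - col("time").cast("long")) / 3600.0) // epoch seconds -> hours
  .groupBy("phoneNumber", "jizhanName")
  .agg(sum("stayHours").alias("stayTime"))

Joining stayTimeDF2 with locDF on jizhanName then yields a result comparable to resDF above.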

The program output is the top 10 rows of resDF: jizhanName, phoneNumber, stayTime (in hours), longitude, and latitude, sorted by stayTime in descending order.

[Output screenshot: 图片.png]
