前言
如题,记录在Spark ML LR中如何解决数据不平衡。参考:Dealing with unbalanced datasets in Spark MLlib
1、数据不平衡
指label == 1和label == 0 的数据比例的很多,如80%和20%,这样导致模型的结果的准确率也不平衡,不准确。
2、setWeightCol 主要代码
val labelCol = "label"
def balanceDataset(dataset: DataFrame): DataFrame = {
// Re-balancing (weighting) of records to be used in the logistic loss objective function
val numNegatives = dataset.filter(dataset(labelCol) === 0).count
val datasetSize = dataset.count
val balancingRatio = (datasetSize - numNegatives).toDouble / datasetSize
val calculateWeights = udf { d: Double =>
if (d == 0.0) {
1 * balancingRatio
} else {
(1 * (1.0 - balancingRatio))
}
}
val weightedDataset = dataset.withColumn("classWeightCol", calculateWeights(dataset(labelCol)))
weightedDataset
}
val df_weighted = balanceDataset(df)
val lr = new LogisticRegression().setLabelCol(labelCol).setWeightCol("classWeightCol")
这样就很方便解决了数据不平衡的问题
3、其他方法
最开始不知道有setWeightCol这个方法,我是按下面的方法解决的,记录一下
下面假设label=0的数据大于label=1的数据
/**
* 将label = 0的随机抽样,使label=1数量和label=0的数量大致相同
*/
def sample(df: DataFrame): DataFrame = {
val df0 = df.where(s"${labelCol}=0")
val df1 = df.where(s"${labelCol}=1")
val y0 = df0.count()
val y1 = df1.count()
val num = 1.0 * y1 / y0
val df00 = df0.sample(false, num) //解决类别数据平衡性问题,对没有违约样本进行随机抽样
df00.union(df1)
}
或
/**
* 是将label=1 的复制多份,使label=1数量和label=0的数量大致相同
*/
def copy(df: DataFrame): DataFrame = {
var df_res = df
val df1 = df.where(s"${labelCol}=1")
val y0 = df.where(s"${labelCol}=0").count()
val y1 = df1.count()
val num = (y0 / y1).toInt - 1
for (a <- 1 to num) {
df_res = df_res.union(df1)
}
df_res
}
每日英语
- 1、laptop n. 膝上型轻便电脑,笔记本电脑
- 2、general-purpose adj. 多用途的;一般用途的 general purpose adj. 通用的
- 3、vulnerable adj. 易受攻击的,易受…的攻击;易受伤害的;有弱点的
- 4、handful n. 少数;一把;棘手事
- 5、coordinates n. [数] 坐标;相配之衣物 v. 使协调;使调和(coordinate的第三人称单数形式)
网友评论