package org.apache.spark.mllib.clustering

import scala.util.Random

import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
import org.apache.spark.mllib.linalg.Vectors
/**
 * A utility object to run K-means locally. This is private to the MLlib package because it is
 * used in the initialization of KMeans but is not meant to be publicly exposed.
 */
private[mllib] object LocalKMeans extends Logging {

  /**
   * Run K-means++ on the weighted point set `points`. This first does the K-means++
   * initialization procedure and then rounds of Lloyd's algorithm.
   *
   * @param seed          seed for the random generator used to sample the initial centers and
   *                      to re-seed centers whose cluster ends up empty
   * @param points        points to cluster; must be non-empty
   * @param weights       weight of each point, index-aligned with `points`
   * @param k             number of centers to return; must be positive
   * @param maxIterations maximum number of Lloyd iterations to run
   * @return `k` centers; the same point may appear more than once when there are fewer
   *         distinct points than requested centers
   * @throws IllegalArgumentException if `points` is empty, `k` is not positive, or `points`
   *                                  and `weights` have different lengths
   */
  def kMeansPlusPlus(
      seed: Int,
      points: Array[VectorWithNorm],
      weights: Array[Double],
      k: Int,
      maxIterations: Int
  ): Array[VectorWithNorm] = {
    // Fail fast with a clear message instead of an ArrayIndexOutOfBoundsException further down.
    require(points.nonEmpty, "LocalKMeans.kMeansPlusPlus requires at least one point.")
    require(k > 0, s"Number of centers must be positive but got k = $k.")
    require(weights.length == points.length,
      s"Expected one weight per point (${points.length}) but got ${weights.length} weights.")

    val rand = new Random(seed)
    val dimensions = points(0).vector.size
    val centers = new Array[VectorWithNorm](k)

    // Initialize centers by sampling using the k-means++ procedure.
    centers(0) = pickWeighted(rand, points, weights).toDense
    // costArray(p) is the squared distance from points(p) to the closest center chosen so far.
    val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0)))

    for (i <- 1 until k) {
      // Sample the next center with probability proportional to weight * cost.
      val sum = points.indices.iterator.map(p => costArray(p) * weights(p)).sum
      val r = rand.nextDouble() * sum
      var cumulativeScore = 0.0
      var j = 0
      while (j < points.length && cumulativeScore < r) {
        cumulativeScore += weights(j) * costArray(j)
        j += 1
      }
      if (j == 0) {
        // r was not positive (e.g. the total weighted cost is 0 because every point already
        // coincides with a chosen center), so there is no new point left to sample.
        logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
          s" Using duplicate point for center k = $i.")
        centers(i) = points(0).toDense
      } else {
        centers(i) = points(j - 1).toDense
      }

      // Update costArray with the distance to the newly chosen center.
      for (p <- points.indices) {
        costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p))
      }
    }

    // Run up to maxIterations iterations of Lloyd's algorithm.
    // oldClosest(p) is the center index points(p) was assigned to in the previous iteration.
    val oldClosest = Array.fill(points.length)(-1)
    var iteration = 0
    var moved = true
    while (moved && iteration < maxIterations) {
      moved = false
      val counts = Array.fill(k)(0.0)
      val sums = Array.fill(k)(Vectors.zeros(dimensions))
      var i = 0
      while (i < points.length) {
        val p = points(i)
        val index = KMeans.findClosest(centers, p)._1
        // Accumulate the weighted point into its cluster's running sum.
        axpy(weights(i), p.vector, sums(index))
        counts(index) += weights(i)
        if (index != oldClosest(i)) {
          moved = true
          oldClosest(i) = index
        }
        i += 1
      }

      // Update each center to the weighted mean of the points assigned to it.
      var j = 0
      while (j < k) {
        if (counts(j) == 0.0) {
          // Empty cluster: re-seed the center at a random point.
          centers(j) = points(rand.nextInt(points.length)).toDense
        } else {
          scal(1.0 / counts(j), sums(j))
          centers(j) = new VectorWithNorm(sums(j))
        }
        j += 1
      }
      iteration += 1
    }

    // `moved` is still true only if the loop stopped on the iteration cap while assignments
    // were still changing; otherwise the assignments stabilized (possibly on the last allowed
    // iteration, which the previous `iteration == maxIterations` check misreported).
    if (moved) {
      logInfo(s"Local KMeans++ reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"Local KMeans++ converged in $iteration iterations.")
    }

    centers
  }

  /**
   * Pick one element of `data` at random with probability proportional to its weight.
   *
   * If the random threshold is not positive (the weights sum to zero, or the draw is exactly
   * zero), the first element is returned instead of failing with an out-of-bounds index.
   *
   * @param rand    random generator; exactly one `nextDouble()` is consumed
   * @param data    candidates; must be non-empty
   * @param weights weight of each candidate, index-aligned with `data` (expected non-negative)
   * @return the selected element of `data`
   */
  private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = {
    val r = rand.nextDouble() * weights.sum
    var i = 0
    var curWeight = 0.0
    while (i < data.length && curWeight < r) {
      curWeight += weights(i)
      i += 1
    }
    // i == 0 only when r <= 0; fall back to the first element rather than reading data(-1).
    data(math.max(i - 1, 0))
  }
}