简介
通常我们需要在一大堆数中求前k大的数。比如在搜索引擎中求当天用户点击次数排名前10000的热词,在文本特征选择中求值按从大到小排名前k个文本
等问题,都涉及到一个核心问题,即TOP-K问题。
那么这种问题就会有一个比较好的算法,叫做BFPTR算法,又称为中位数的中位数算法,它的最坏时间复杂度为O(N),它是由Blum、Floyd、Pratt、Rivest、Tarjan提出。该算法的思想是修改快速选择算法的主元选取方法,提高算法在最坏情况下的时间复杂度。
参考——BFPRT算法原理
问题描述
给定一个数组arr和k,返回数组中第k小的数
常规思路——堆排序
- 创建一个长度为K的数组minArr[k]
- 将此数组构造成为大顶堆,那么数组中第一个元素就是数组中最大的元素,也就是我们要求的第k小的元素。
- 以上是思考方法,那么怎么首先这个大顶堆
- 首先将数组arr的前k个元素构造成为大顶堆
- 然后从数组的第k+1个数遍历,向大顶堆数组minArr中插入元素,不断的调整大顶堆,如果插入的元素小于堆顶,则插入,否则不插入。
- 整体思路上借鉴了堆排序的思想。
- 时间复杂度: 等同于堆排的时间复杂度 O(N * log(k))
代码
// ############################----------手撕BFPRT----------###################################
// --------------------方法一: 堆排(对比BFPRT)---------------------------
// 获取前k小的数构成数组返回
public static int[] getMinKNumsByHeap(int[] arr, int k) {
if(k < 1 || k > arr.length) return null; // k越界返回null
int[] minArr = new int[k];
// 抽取arr的前k个元素构成大顶堆
for (int i = 0; i < k; i++) {
heapInsert(minArr,arr[i],i);
}
// 然后从arr数组的第k个元素开始
for (int i = k; i < arr.length; i++) {
// 如果当前元素小于堆顶,则插入
if(arr[i] < minArr[0]) {
minArr[0] = arr[i];
// 调整数组,重新构造大顶堆
heapfiy(minArr,0,k);
}
}
return minArr;
}
// 调整数组为大顶堆
private static void heapfiy(int[] minArr, int index, int heapSize) {
int left = 2 * index + 1;
int right = 2 * index + 2;
int largest = index;
while(left < heapSize){
if(minArr[index] < minArr[left]){
largest = left;
}
if(right < heapSize && minArr[right] > minArr[largest]){
largest = right;
}
if(largest == index){
// 没有调整
break;
}
swap(minArr,index,largest);
// 更新各个指针的位置
index = largest;
left = index * 2 + 1;
right = index * 2 + 2;
}
}
private static void heapInsert(int[] minArr, int value, int index) {
minArr[index] = value;
while(index != 0){
int parent = (index - 1) / 2;
if(minArr[index] > minArr[parent]){
swap(minArr,index,parent);
index = parent;
}else{
break;
}
}
}
private static void swap(int[] arr, int x, int y){
int temp = arr[x];
arr[x] = arr[y];
arr[y] = temp;
}
//****************** 测试 ************************
public static void main(String[] args) {
// 生成随机数组,用作测试
int[] arr = ArrayTestUntil.generateRandomArray(10000,10000);
long s1 = System.currentTimeMillis();
int[] res = getMinKNumsByHeap(arr, 20);
System.out.println(Arrays.toString(res));
System.out.println("普通归并排序用时: " + (System.currentTimeMillis() - s1) + "ms");
}
BFPRT算法
基本思路:
- BFPRT的主要思想就是修改快排中的随机阈值pivot,将这个阈值尽可能的定在中间,加快快排的搜索过程。
- 确定阈值pivot的方法: 将数组以5个一组的方式进行分组,计算出每5个元素的中位数,然后对子数组以这种方式递归求解,知道最后只剩余一个元素,就是最终的pivot。
- 然后以pivot作为阈值进行快速排序。
- 然后确定第k小的值。
/**
* 解法二: 利用BFPRT算法实现O(N)级别的时间复杂度
*/
/**
* 函数入口: 获取前k小的元素数组
* @param arr
* @param k
* @return
*/
public static int[] getMinKNumsByBFPRT(int[] arr, int k) {
// 首先同样进行判断,如果K越界,则直接返回数组
if (k < 1 || k > arr.length) {
return arr;
}
// 获取第k小的元素值
int minKth = getMinKthByBFPRT(arr, k);
// 创建可以容纳k这么大的数组
int[] res = new int[k];
int index = 0;
// 将数组元素向res中添
for (int i = 0; i != arr.length; i++) {
if (arr[i] < minKth) {
res[index++] = arr[i];
}
}
// 有可能走到arr[i] 刚好等于minKth的位置,则后面一路相等即可
for (; index != res.length; index++) {
res[index] = minKth;
}
// 返回结果
return res;
}
/**
* 获取数组中第k小的元素值(同时将数组元素已经按照中间的值拍好序了)
* @param arr
* @param K
* @return
*/
public static int getMinKthByBFPRT(int[] arr, int K) {
// 利用复制好的数组执行
int[] copyArr = copyArray(arr);
// 挑选第k小的值
return select(copyArr, 0, copyArr.length - 1, K - 1);
}
/**
* 复制数组(不破坏原来数组的结构)
* @param arr
* @return
*/
public static int[] copyArray(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i != res.length; i++) {
res[i] = arr[i];
}
return res;
}
/**
* 从给定数组中挑选出中位数
* @param arr 指定数组
* @param begin 起始位置
* @param end 终止位置
* @param i 选取第i小的元素
* @return
*/
public static int select(int[] arr, int begin, int end, int i) {
if (begin == end) {
return arr[begin];
}
// 实现了一个递归调用,获取中位数(全局最好的pivot)
int pivot = medianOfMedians(arr, begin, end);
// 用这个中位数实现快排
int[] pivotRange = partition(arr, begin, end, pivot);
// 如果刚刚好查找的数就在中间位置,直接返回arr[i]
if (i >= pivotRange[0] && i <= pivotRange[1]) {
return arr[i];
} else if (i < pivotRange[0]) {
// 如果i位置小于less,则向左进行递归调用
return select(arr, begin, pivotRange[0] - 1, i);
} else {
// 如果位置大于more,则向右进行递归调用
return select(arr, pivotRange[1] + 1, end, i);
}
}
/**
* 快速获取中位数(全局最优的快排输入值)
* @param arr
* @param begin
* @param end
* @return
*/
public static int medianOfMedians(int[] arr, int begin, int end) {
// 数组总长度
int num = end - begin + 1;
// 每五个一组,查看是否有多余的数,有的话则单独成一位
int offset = num % 5 == 0 ? 0 : 1;
// 创建存储每五个数据排序后中位数的数组
int[] mArr = new int[num / 5 + offset];
// 遍历此数组
for (int i = 0; i < mArr.length; i++) {
// 当前mArr来源自原来数组中的起始位置
int beginI = begin + i * 5;
// 当前mArr来源自原来数组中的终止位置
int endI = beginI + 4;
// 计算出当前i位置5个数排序后的中位数
mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
}
// 在这些中位数的点中,挑选出排好序之后的中位数返回
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}
/**
* 快排主体: 小于pivotValue放在左边,等于pivotValue放在中间,大于pivotValue放在右边
* @param arr
* @param begin
* @param end
* @param pivotValue 快排选取的元素
* @return
*/
public static int[] partition(int[] arr, int begin, int end, int pivotValue) {
int small = begin - 1;
int cur = begin;
int big = end + 1;
while (cur != big) {
if (arr[cur] < pivotValue) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivotValue) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1;
range[1] = big - 1;
return range;
}
/**
* 获取数组排序后的中位数
* @param arr
* @param begin
* @param end
* @return
*/
public static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);
int sum = end + begin;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}
/**
* 实现简单的插入排序
* @param arr
* @param begin
* @param end
*/
public static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i != end + 1; i++) {
for (int j = i; j != begin; j--) {
if (arr[j - 1] > arr[j]) {
swap(arr, j - 1, j);
} else {
break;
}
}
}
}
/**
* 交换
* @param arr
* @param index1
* @param index2
*/
public static void swap(int[] arr, int index1, int index2) {
int tmp = arr[index1];
arr[index1] = arr[index2];
arr[index2] = tmp;
}
public static void printArray(int[] arr) {
for (int i = 0; i != arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] arr = ArrayTestUntil.generateRandomArray(10000,10000);
// sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
long s1 = System.currentTimeMillis();
printArray(getMinKNumsByHeap(arr, 2000));
System.out.println("普通归并排序用时: " + (System.currentTimeMillis() - s1) + "ms");
long s2 = System.currentTimeMillis();
printArray(getMinKNumsByBFPRT(arr, 2000));
System.out.println("BFPRT用时: " + (System.currentTimeMillis() - s2) + "ms");
}
网友评论