稀有猿诉

十年磨一剑,历炼出锋芒,说话千百句,不如码二行。

TopK问题,堆和快速选择

TopK问题是很常见的一种问题,它的描述是从一个数据集或者序列中取出前k大(或者前k小),或者说找出第k大(第k小)。最为典型的就是 题215. 数组中的第K个最大元素。解决TopK需要的是最基础的数据结构和算法,不但可以考查编码基本功,更能考查思维能力。

为了方便,后面就以找前k大为主要示例:输入长度为n的整数数组,找出前k大的数,1 <= k <= n。

排序大法

解决TopK问题,最简单也是最为暴力的做法就是排序,如果数据是有序的,无论你想找前k大或者第k大,都是非常容易的了。

问题就转化为排序问题了,至于排序有O(n2)的冒泡,选择和插入, 以及高效一些的归并和快速排序。如果是特殊数据集还可以用计数排序(也叫桶排序)。关于排序算法的教程太多了,就不重复了,可以参考Yu神的 十大排序从入门到入赘

用排序来解决TopK问题可行但并不高效,比如k特别小时,n特别大时效率就会特别差。甚至,对于序列(也就是输入数据接近无限)时,可能没有办法先排序再去选择前k大了。

Heap

堆是一个逻辑上的二叉树式的数据结构,但实现上通常用数组来实现,它保证根节点是所有元素中最大的称作最大堆或者大根堆,或者最小的称作最小堆或者小根堆。有些地方也称之为优先队列,比如在大Java中的就叫做PriorityQueue

以最大堆为例,它保证根节点永远不小于两个子节点,假如堆的大小(也即元素总数)是k,那么根节点就是这k个元素中的最大值,维护一次堆(Heapify)的代价是log(k),只需要不断的比较根节点和子节点即可,所以复杂度是二叉树的高度即log(k)。对于TopK问题,可以创建一个大小为k的最小堆,把n个数都填到堆里,当堆未满时,直接塞,如果满时了,堆顶是最小值,如果新元素小于最小值可直接跳过,它不可能成为TopK;否则先移除堆顶然后再塞,最后堆里面剩下的就是前k大元素,这样复杂度会降到nlog(k),当n特别大,k远小于n时,或者说对n接近无穷的序列时,用最小堆的效率会明显的高于排序大法。

堆(优先队列)是一种非常常见且基础的数据结构,标准库中都有,可以拿来就用,但是学习手撸一个堆更能加深理解。

堆的实现

来手撸一个最大堆。最常见的就是二叉堆,也就是说逻辑上是一个二叉树,但实际的存储一般是用数组,索引0就是根节点root(又叫堆顶),索引i它的左子节点是在索引2*i+1,右子节点是在2*i+2。

需要不断的维护堆的特性,也即是它的根节点总是大于两个子节点,要时刻保持这种性质。主要难点在于向堆中添加一个元素时,先把此元素放在数组最后,也即树中最右下的叶子节点,然后不断的向上更新:如果此元素大于其父节点,就互换直到它小于其父节点。

另外需要维护的地方就是移除堆顶,堆顶是堆中的最大元素,它大于其两个子节点。大哥没了,就要重新选大哥:因为逻辑上是一个二叉树,所以只需要解决一个最小的树即可,其余可以递归处理。从父节点,左子节点和右子节点中取最大的,与父节点互换,然后再递归处理刚刚转换过的子树,即可。

废话这么多,其实代码比较精简,也较容易理解,还是直接上代码吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/**
 * A bound Max Heap of int type elements.
 * This is a binary heap with array as the underlying container.
 */
public class MaxHeap {
    public static final int INF = Integer.MAX_VALUE;

    private final int capacity;
    private int size;
    /*
     * Put the elements into an array, but the logical relationship is a binary tree.
     * 0 is the root;
     * i's left child is 2*i + 1, right child is 2*i + 2;
     * i's parent is (i-1) / 2.
     */
    private final int[] elements;

    public MaxHeap(int capacity) {
        this.capacity = capacity;
        size = 0;
        elements = new int[capacity];
    }

    /**
     * Nothing happens if heap is full.
     */
    public void offer(int e) {
        if (isFull()) {
            // Overflowed.
            return;
        }
        /*
         * Put the new element at the end of the heap.
         * Push it up until it is less than its parent.
         */
        size++;
        int i = size - 1;
        elements[i] = e;

        while (i != 0 && elements[parent(i)] < elements[i]) {
            swap(i, parent(i));
            i = parent(i);
        }
    }

    public int heapSize() {
        return size;
    }

    public int peek() {
        if (isEmpty()) {
            return INF;
        }
        return elements[0];
    }

    public void clear() {
        size = 0;
    }

    public int poll() {
        if (isEmpty()) {
            return INF;
        }
        if (size == 1) {
            size--;
            return elements[0];
        }

        /*
         * Root is the max value in the heap, will remove and return it to caller.
         * Push down the tree and select the max of left and right as the new parent.
         */
        int root = elements[0];
        elements[0] = elements[size - 1];
        size--;
        heapify(0);
        return root;
    }

    public boolean isEmpty() {
        return size == 0;
    }

    public boolean isFull() {
        return size == capacity;
    }

    /*
     * Heapify the sub-tree rooted with index i.
     * Find the largest value of parent, left and right;
     * If the parent is the largest, we are done.
     * Swap parent with the largest node, now parent is the largest;
     * Keep heapifying the swapped sub-tree.
     */
    private void heapify(int i) {
        int l = left(i);
        int r = right(i);
        int largest = i;
        if (l < size && elements[l] > elements[i]) {
            largest = l;
        }
        if (r < size && elements[r] > elements[largest]) {
            largest = r;
        }
        if (largest != i) {
            swap(i, largest);
            heapify(largest);
        }
    }

    private void swap(int i, int j) {
        if (i == j) {
            return;
        }
        int t = elements[i];
        elements[i] = elements[j];
        elements[j] = t;
    }

    private int parent(int i) {
        return (i - 1) / 2;
    }

    private int left(int i) {
        return (i << 1) + 1;
    }

    private int right(int i) {
        return (i << 1) + 2;
    }
}

完整代码在这里

这就是最基础的一种二叉堆(Binary Heap)的实现。注意基础堆是用于快速找最大值或者最小值,是O(1)的,其他的操作如查询非最大值或者最小值,或者移除某个特定的元素,效率差会变成O(n)的。

为此,还有其他的实现方式如Binomial HeapFibonacci Heap,这两种堆除了保证堆的基本特质外,还能把其他的操作也降低到log(n)的复杂度。

堆的应用

一是用来排序,通常称作堆排序,把n个元素都入堆,然后依次把堆顶取出来,这样就能得到一个有序数组了。复杂度是nlog(n)。

另外,就是用于解决topK问题了。更为实际一点的应用就是Job Scheduling,把一坨Job加入堆中,每次取堆顶(优先级最高的Job)来执行。

Quick select

快速选择是快速排序衍生出来的一个算法,专门适用以线性复杂度O(n)来解决TopK问题。为此我们先复习快速排序算法,然后再解释快速选择原理。

快速排序

这是一个非常经典又基础的算法,是算法入门的必讲算法。快速排序的核心思想是分治(Divide and Conquer),核心技巧是分区(partition),选取一个轴元素作为分界点(pivot),把小于轴的元素都放在它左边,把大于它的元素都放在其右边,然后再用同样的方法处理左边和右边。伪码如下:

1
2
3
4
5
6
void quickSort(int[] arr, int start, int end) {
    if (start == end) return
    int p = partition(arr, start, end);
    quickSort(arr, start, p-1);
    quickSort(arr, p, end);
}

分区

分区partition是快排的核心技巧,当然也是快速选择的核心,它是先选出一个轴元素pivot,然后以它为界把数组分成两段。比如说数组arr = [5,3,7,1,8,2,9,4]。如果选择索引位置0,元素5作为pivot,那么partition之后的数组会变为arr=[3,1,2,4,5,7,8,9],partition的返回值,是pivot元素在分区之后的新索引p,即此例中的索引4。可以看出经过partition后,数组左区[0,p-1]都是小于pivot的,而右区[p,n-1]则是大于等于pivot的。这就是分区的作用。

分区算法轴元素的选择至关重要,为了达到最好的效果,在区间内随机选择一个索引位置的元素作为pivot是最理想的,摊还分析后可以达到O(n)。快排的复杂则是nlog(n)。

对于数组arr,做partition的具体做法是:

  1. 随机选择一个元素为轴元素,记其索引为pivot
  2. 先把pivot与最后一个元素交换swap(arr, pivot, end),注意交换后轴元素在end,即arr[end]
  3. 用双指针,左指针left总是指针向小于轴元素arr[end]的最后一个元素,也即分区好了时的左边界的最后一个位置。
  4. 右指针right则从start开始,遍历到end - 1,如果arr[right]小于轴,即arr[right]<arr[end],则交换并更新左指针
  5. 最后left索引即是轴应该在的索引,与轴交换swap(arr, left, end)
  6. 返回left。这是分区后的轴所在的位置。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private int partition(Random random, int[] nums, int start, int end) {
        int pivot = random.nextInt(end - start + 1) + start;
        swap(nums, pivot, end);
        int left = start - 1;
        for (int right = start; right < end; right++) {
            if (nums[right] < nums[end]) {
                left++;
                swap(nums, left, right);
            }
        }

        left++;
        swap(nums, left, end);

        return left;
    }

记住,分区返回一个轴的索引,轴左边的元素都小于轴,轴右边的元素都大于轴。这是快速排序和快速选择的核心奥妙精华所在。

快速选择

基于分区就能开发出快速选择算法。对于长度为n的数组arr,进行partition后,得到一个轴的位置pivot,[0,pivot-1]都小于arr[pivot],而[pivot+1,n-1]都大于arr[pivot]。那么,对于想找出前k大的TopK问题而方,如果pivot=n-k,那么[pivot, n - 1]分区后的右边部分不就刚好前k大元素么?

有同学举手问了,咋可能那么巧嘛。这位同学请先坐下,不巧也没关系,如果pivot大于n-k,说明比pivot大的数不够k个,就得往左找,所以在左部分递归处理就可以了;同理,如果pivot小于n-k,说明右部分太多了,往右找即可。代码大概这样子的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public int findKthLargest(int[] nums, int k) {
        Random random = new Random();
        int target = nums.length - k;
        int start = 0;
        int end = nums.length - 1;
        int index = partition(random, nums, start, end);
        while (index != target) {
            if (index > target) {
                end = index - 1;
            } else {
                start = index + 1;
            }
            index = partition(random, nums, start, end);
        }

        return nums[index];
    }

这是迭代式的,看起来可能不那么直观,我们用递归来写,就相当直观了:

1
2
3
4
5
6
7
8
9
10
11
12
13
int quickSelect(arr, start, end, k) {
    if (start == end) {
         return start;
    }
    int p = partition(arr, start, end);
    if (p == n - k) {
        return p;
    } else if (p < n - k) {
        return quickSelect(arr, p + 1, end, k);
    } else {
        return quickSelect(arr, start, p - 1, k - p);
    }
}

总结

TopK问题是非常常见且基础的一个问题,通常是融合在了其他问题里面,不会以比较直观的方式求TopK。如果是问题中的一个子问题,那么通常用堆来当作辅助数据结构是最优的做法。如果TopK问题是最问题的最后一步的话,那么排序或者用快速选择也是可以的。

典型问题

题目 题解 说明
215. 数组中的第K个最大元素 题解 典型TopK问题
23. 合并 K 个升序链表 题解
239. 滑动窗口最大值 题解
347. 前 K 个高频元素 题解
973. 最接近原点的 K 个点 题解
324. 摆动排序 II 题解
2462. 雇佣 K 位工人的总代价 题解
题解
题解

参考资料

Comments