给定两个有序数组,让我们从每个数组中任意取出一个数字来组成不同的数字对,返回前k个和最小的数字对。
直接的思路就是直接将所有pair排序,复杂度O(nlogn),n为两个数组的平均长度; 因为只求前k个最小的,所以稍微改进一点的就是类似347. Top K Frequent Elements的思路, 利用堆复杂度O(nlogk),利用快排复杂度O(n),可参考347题解。
上面的思路复杂度最小可至O(n),但是都没有利用题目给的两个数组都是有序的这个条件,所以还有优化空间。 参考讨论区, 我们可以根据数组有序这个条件以O(klogk)的复杂度解此题。
由于数组都是有序的,所以和最小的pair只能是从
(nums1[0], nums2[0]), (nums1[1], nums2[0]), (nums1[2], nums2[0]).....(nums1[k-1], nums2[0])
这k个候选者中取 (这里我们暂时假设k不超过nums1的大小),假设就是(nums1[i], nums2[0]);那么下一个最小的pair该怎么选呢:应该是从刚刚剩下的k-1个候选者和 (nums1[i], nums2[1])这k个候选者中选最小的那个,假设就是(nums1[j], nums2[k]);那么再下一个最小的pair该怎么选呢: 应该是从刚刚剩下的k-1个候选者和(nums1[j], nums2[k+1])这k个候选者中选最小的那个, 以此类推直到选取了k个。 上述过程我们只需要维护一个大小为k的最小堆就可以了,建堆时间复杂度O(k),然后有k次pop和push共O(klogk),所以总的复杂度为O(klogk)。
上述思想的核心就是将所有的pair分成k类:第一个数为nums1[0]、第一个数为nums1[k]、......第一个数为nums1[k-1];每一类又是一个从小到大排好序的队列。 我们每次只需要将这k个队的队头进行比较然后选取最小的那个就可以了,这有点类似多路归并排序的归并过程。
几个注意点:
- 自定义结构体如何使用
priority_queue
,有三种常见的写法,见代码。 - k有可能比较大以至于超过了所有pair数。
时间复杂度为O(klogk),空间复杂度O(k)。
struct mypair{
int a, b, sum;
int b_idx; // 在nums2中的idx
mypair(const int a_, const int b_, const int b_idx_){
a = a_; b = b_; b_idx = b_idx_;
sum = a_ + b_;
}
// // 方法一 重载运算符<,写在结构体里面(注意后面必须加const修饰!!!)
// // 后面必须加const修饰, 否则报错!!!
// bool operator < (const mypair &p) const{
// return sum >= p.sum;
// }
};
// // 方法二 重载运算符<,写在结构体外面
// bool operator < (const mypair &p1, const mypair &p2){
// return p1.sum >= p2.sum;
// }
// 方法三 定义一个比较类, 重载operator()
struct cmp{
bool operator () (const mypair &p1, const mypair &p2){
return p1.sum >= p2.sum;
}
};
class Solution {
public:
vector<vector<int>> kSmallestPairs(vector<int>& nums1, vector<int>& nums2, int k) {
int size1 = nums1.size(), size2 = nums2.size();
int p1 = 0, p2 = 0;
k = min(k, size1 * size2); // k有可能比较大以至于超过了所有pair数
// priority_queue<mypair>min_heap; // 方法一、二
priority_queue<mypair, vector<mypair>, cmp>min_heap; // 方法三
for(int i = 0; i < size1 && i < k; i++)
min_heap.push(mypair(nums1[i], nums2[0], 0));
vector<vector<int>>res;
while(k--){
mypair tmp = min_heap.top(); min_heap.pop();
res.push_back({tmp.a, tmp.b});
int next_b_idx = tmp.b_idx + 1;
if(next_b_idx < size2)
min_heap.push(mypair(tmp.a, nums2[next_b_idx], next_b_idx));
}
return res;
}
};