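Method 1: divide and conquer (merge-sort style) — recursively process both halves, count the pairs that straddle the midpoint, then leave the range sorted.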
class Solution {
public:
    /// Counts "reverse pairs": index pairs (i, j) with i < j and
    /// nums[i] > 2 * nums[j].
    ///
    /// @param nums Input values. Sorted into ascending order as a side effect.
    /// @return Number of reverse pairs in the original ordering.
    int reversePairs(std::vector<int>& nums) {
        // Convert the size to int before subtracting so an empty vector yields
        // end == -1 without relying on unsigned wraparound.
        return mergeSort(nums, 0, static_cast<int>(nums.size()) - 1);
    }

    /// Sorts nums[start..end] (inclusive) and returns the number of reverse
    /// pairs whose two indices both lie inside that range.
    int mergeSort(std::vector<int>& nums, int start, int end) {
        if (start >= end) {
            return 0;
        }
        // Written as start + (end - start) / 2 to avoid overflowing start + end.
        const int mid = start + (end - start) / 2;
        const int leftPairs = mergeSort(nums, start, mid);
        const int rightPairs = mergeSort(nums, mid + 1, end);
        const int crossPairs = count(nums, start, mid, end);
        return leftPairs + rightPairs + crossPairs;
    }

    /// Counts pairs (i, j) with i in [start, mid], j in [mid + 1, end] and
    /// nums[i] > 2 * nums[j], then merges the two halves into one sorted run.
    ///
    /// Precondition: nums[start..mid] and nums[mid + 1..end] are each sorted in
    /// ascending order (mergeSort guarantees this before calling).
    int count(std::vector<int>& nums, int start, int mid, int end) {
        int pairs = 0;
        int l = start;
        int r = mid + 1;
        while (l <= mid && r <= end) {
            // long long is at least 64 bits on every platform, so doubling an
            // int can never overflow (plain long is only 32 bits on LLP64).
            if (static_cast<long long>(nums[l]) > 2LL * nums[r]) {
                // The left half is sorted, so every element from l through mid
                // also exceeds 2 * nums[r].
                pairs += mid - l + 1;
                ++r;
            } else {
                ++l;
            }
        }
        // Both halves are already sorted; a linear merge produces the same
        // result as re-sorting the whole range.
        std::inplace_merge(nums.begin() + start,
                           nums.begin() + mid + 1,
                           nums.begin() + end + 1);
        return pairs;
    }
};
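Method 2: binary search tree — scan nums from right to left; for each value, count the already-inserted values smaller than nums[i] / 2, then insert nums[i] into the tree.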
class Solution {
public:
    /// Node of an unbalanced binary search tree keyed by value.
    class BSTNode {
    public:
        BSTNode(int val)
            : val(val), less(0), count(1), left(nullptr), right(nullptr) {}
        int val;         ///< Key stored in this node.
        int less;        ///< Number of values inserted into the left subtree.
        int count;       ///< Number of times `val` itself has been inserted.
        BSTNode* left;   ///< Subtree of keys smaller than `val`.
        BSTNode* right;  ///< Subtree of keys greater than `val`.
    };

    /// Counts "reverse pairs": index pairs (i, j) with i < j and
    /// nums[i] > 2 * nums[j].
    ///
    /// Walks nums from right to left. For each element it first counts the
    /// already-inserted values (those at larger indices) strictly smaller than
    /// nums[i] / 2, then inserts nums[i] into the tree.
    ///
    /// @param nums Input values; not modified.
    /// @return Number of reverse pairs.
    int reversePairs(std::vector<int>& nums) {
        int rst = 0;
        BSTNode* root = nullptr;
        for (int i = static_cast<int>(nums.size()) - 1; i >= 0; --i) {
            // Halving as a double is exact for any int, so odd values compare
            // correctly against the integer keys (e.g. 3 / 2.0 == 1.5).
            rst += searchLesser(root, nums[i] / 2.0);
            root = buildBST(root, nums[i]);
        }
        // The tree is only needed while counting; free it so nothing leaks.
        destroyTree(root);
        return rst;
    }

    /// Inserts val into the tree rooted at node and returns the tree's root.
    ///
    /// The caller owns the returned tree. Nodes are allocated with `new`.
    /// Implemented as a loop so a degenerate (list-shaped) tree cannot exhaust
    /// the call stack.
    BSTNode* buildBST(BSTNode* node, int val) {
        if (!node) {
            return new BSTNode(val);
        }
        BSTNode* cur = node;
        while (true) {
            if (val < cur->val) {
                // val ends up somewhere in cur's left subtree.
                cur->less++;
                if (!cur->left) {
                    cur->left = new BSTNode(val);
                    break;
                }
                cur = cur->left;
            } else if (val > cur->val) {
                if (!cur->right) {
                    cur->right = new BSTNode(val);
                    break;
                }
                cur = cur->right;
            } else {
                cur->count++;
                break;
            }
        }
        return node;
    }

    /// Returns how many inserted values are strictly smaller than target.
    int searchLesser(BSTNode* node, double target) {
        int total = 0;
        while (node) {
            if (target > node->val) {
                // This node's key and its whole left subtree are below target.
                total += node->less + node->count;
                node = node->right;
            } else if (target < node->val) {
                node = node->left;
            } else {
                // Only the left subtree is strictly smaller than target.
                total += node->less;
                break;
            }
        }
        return total;
    }

private:
    /// Deletes every node reachable from root, using an explicit stack so the
    /// depth of the tree does not matter.
    static void destroyTree(BSTNode* root) {
        std::vector<BSTNode*> pending;
        if (root) {
            pending.push_back(root);
        }
        while (!pending.empty()) {
            BSTNode* node = pending.back();
            pending.pop_back();
            if (node->left) {
                pending.push_back(node->left);
            }
            if (node->right) {
                pending.push_back(node->right);
            }
            delete node;
        }
    }
};
网友评论