描述
一个长度为 的非负整数序列,定义
表示从该序列的区间
选择若干不相邻的数的和的最大值。
求所有 的和,最终结果对
取模
分析
这道题我前前后后提交了将近 20 次才过,心态都崩了。首先,对于给定的 我们很容易想到用动态规划去求
(leetcode打家劫舍问题),我们可以用这种方式,分别求出所有的
然后计算它们的和。但是这种做法的复杂度为
,当数据范围达到
时会超时。考虑分治,对于区间
,令
,我们可以将其子区间分为三类:
- 左端点和右端点都在
上。
- 左端点和右端点都在
上。
- 左端点在
上,右端点在
上。
对于第一、二种情况,我们直接递归到该子区间以同样的方法求解,于是只需要考虑第三种情况。分别求出左/右边每个位置为左/右端点时 选不选时的最优答案,记为
。对于一对
的答案即为
。首先可以考虑
的情况,这种情况下始终可以用
替代
,因为
对
没有限制,所以
选
与
中较大的即可。当
,我们比较的就是
这两个值的大小,移项之后其实就是比较
的大小。分析到这一步思路就很清晰了,我们可以事先计算出所有的
,将其按照差值的大小排序,之后采用双指针法。先处理
的情况,移动左指针
直至
,对于这
个子区间,我们始终选择
。而对于右边的区间我们选择两者中的较大值,同样移动右指针
直至
,前面
个区间,我们选择
,后面
个区间选择
,引入两个变量
分别记录
之前所有
之和,和
及其之后所有
之和。这一轮处理对结果的贡献就是
(
为右半区间元素个数)。剩下的就是
的情况,遍历左指针,每次将右指针移动到
的地方,那么对于右指针
之前的子区间,我们选择
及
之后的选择
,这一次循环对结果的贡献就是
。每次移动右指针时都要更新
与
的值,注意他们是区间的和的和,结果会非常大,记得要取模。分治解法复杂度为
代码
// 动态规划求出每一个 f(l, r),O(n^2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
int n; cin >> n;
int a[n];
for(int& x: a) cin >> x;
ll dp[n+1][n+1];
memset(dp, 0, sizeof dp);
ll ans = 0;
for(int i = 1; i <= n; i++) {
dp[i][i] = a[i-1];
ans = (ans + dp[i][i]) % mod;
}
for(int len = 2; len <= n; len++) {
for(int i = 1; i <= n-len+1; i++) {
int j = i+len-1;
dp[i][j] = max(dp[i][j-1], dp[i][j-2] + a[j-1]);
ans = (ans + dp[i][j]) % mod;
}
}
cout << ans;
return 0;
}
// 分治法,O(n*(log_n)^2)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
ll ans;
bool cmp(pair<ll, ll>& p1, pair<ll, ll>& p2) {
return p1.second - p1.first < p2.second - p2.first;
}
void solve(const int* a, int lp, int rp) {
if(lp == rp) {
ans = (ans + a[lp]) % mod;
return;
}
int mid = (lp+rp) / 2;
solve(a, lp, mid);
solve(a, mid+1, rp);
int n1 = mid-lp+1, n2 = rp-mid;
pair<ll, ll> f1[n1], f2[n2];
int k1 = 0, k2 = 0;
// 分别计算 f_{l, 0/1} 和 f_{r, 0/1}
f1[k1].first = 0; f1[k1++].second = a[mid];
for(int i = mid-1; i >= lp; i--) {
if(i == mid-1) {
f1[k1].first = a[mid-1];
f1[k1].second = max(a[mid-1], a[mid]);
}
else {
f1[k1].first = max(f1[k1-2].first + a[i], f1[k1-1].first);
f1[k1].second = max(f1[k1-2].second + a[i], f1[k1-1].second);
}
k1++;
}
f2[k2].first = 0; f2[k2++].second = a[mid+1];
for(int i = mid+2; i <= rp; i++) {
if(i == mid+2) {
f2[k2].first = a[mid+2];
f2[k2].second = max(a[mid+1], a[mid+2]);
}
else {
f2[k2].first = max(f2[k2-2].first + a[i], f2[k2-1].first);
f2[k2].second = max(f2[k2-2].second + a[i], f2[k2-1].second);
}
k2++;
}
// 根据差值大小排序
sort(f1, f1+n1, cmp);
sort(f2, f2+n2, cmp);
ll preSum = 0, postSum = 0;
// 移动右指针
k2 = 0;
while(k2 < n2 && f2[k2].first >= f2[k2].second) {
preSum = (preSum+f2[k2++].first) % mod;
}
for(int i = k2; i < n2; i++) {
postSum = (postSum+f2[i].second) % mod;
}
// 移动左指针
k1 = 0;
while(f1[k1].first >= f1[k1].second) {
ans = (ans + f1[k1].first*n2%mod + preSum + postSum) % mod;
k1++;
}
// 遍历左指针
for(; k1 < n1; k1++) {
while(k2 < n2 && f2[k2].second-f2[k2].first <= f1[k1].second-f1[k1].first) {
preSum = (preSum+f2[k2].first) % mod;
postSum = (postSum - f2[k2++].second + mod) % mod;
}
ans = (ans + k2*f1[k1].second%mod + (n2-k2)*f1[k1].first%mod + preSum + postSum) % mod;
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr); cout.tie(nullptr);
int n; cin >> n;
int a[n];
for(int& x: a) cin >> x;
solve(a, 0, n-1);
cout << ans;
return 0;
}
网友评论