题目:洛谷P5858 「SWTR-03」Golden Sword
看完题,就感觉是用dp解决,和背包问题有点像。如果简单地用f[i]表示放入第i种原料时宝剑的最大耐久度,显然不满足无后效性,因为对于一个耐久度是负数的材料,它需要锅里的材料尽可能少;反之对于大耐久度的材料,它需要锅里的材料尽可能多。如果遇上负耐久的一股脑的全拿走,后面有大的正耐久度的材料,总耐久度就不大了;如果一直不取,后面来一个负的大耐久,总耐久也不是最大。
如何消除后效性,做法之一就是给状态加一维,变成f(i,j),表示第i个材料,当锅里有j个材料时,宝剑最大耐久度。这样的话后效性就被消除了,变成了无后效性,最终结果就是max(f(n,k)),k属于[1,w]。转移方程:
f(i,j)=max(f(i-1,k)+a[i]*j),k属于[max(j-1,1),min(i-1, j+s-1,w)]
因为在放入第i个材料时,如果一个都不取出,那么放第i-1个材料时候锅里有j-1个,加上新放入的第i个,正好是j个;如果取出s个,那么说明原来有j-1+s个。
j最大就是w,因为锅里只能放w个,先手画图感受一下样例3:
7 4 2
-5 3 -1 -4 7 -6 5
x | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | -5 | |||
2 | -2 | 1 | ||
3 | 0 | -1 | -2 | |
4 | -4 | -8 | -13 | -18 |
5 | 3 | 10 | 13 | 15 |
6 | 4 | 1 | -3 | -9 |
7 | 9 | 14 | 16 | 17 |
根据此方程写出代码:
#include <cstdio>
#include <climits>
#include <algorithm>
#define MAX 5500
typedef long long ll;
using namespace std;
int n, w, s, a[MAX+1];
ll f[MAX+1][MAX+1];
void dp(){
f[1][1] = a[1];
for(int i = 2; i <= n; i++){
for(int j = 1; j <= min(i, w); j++){
//当第i个材料放到锅里以后,形成了j个,那么原来可能有多少个?
//原来有j-1个,不用取,直接放入1个,就形成了j,原来不可能有0个
//原来有j-1+s个,取走s,再放入1个,就形成了j,原来不可能有大于w个
ll m = LLONG_MIN;
for(int k = max(j-1, 1); k <= min(i-1, min(j-1+s, w)); k++){
ll tmp = f[i-1][k];
if (m < tmp){
m = tmp;
}
}
f[i][j] = m + (ll)a[i] * j;
}
}
}
int main(){
scanf("%d%d%d", &n, &w, &s);
for(int i = 1; i <= n; i++){
scanf("%d", a + i);
}
dp();
ll ans = LLONG_MIN;
for(int i = 1; i <= w; i++){
ans = max(ans, f[n][i]);
}
printf("%lld\n", ans);
return 0;
}
单调队列优化
上面做法由于3层循环,复杂度是O(n(w平方)),只能85分,当s和w很大时候,出不来结果。
考虑优化,最外层循环肯定不能优化,总要从1看到n,内层两个循环,可以看成是一个大小为s+1的窗口(取的数目是[0,s],所以窗口大小是s+1)不断向右滑动,每次设置f(i,j)时都是取出这个窗口里的最大值,使用单调队列,可以将内层两个循环的复杂度从w平方变成w。
#include <cstdio>
#include <climits>
#include <algorithm>
#define MAX 5500
typedef long long ll;
using namespace std;
int n, w, s, a[MAX+1];
ll f[MAX+1][MAX+1];
int q[MAX], qi, qj;
void dp(){
f[1][1] = a[1];
for(int i = 2; i <= n; i++){
qi = qj = 0;//初始化单调队列
//把1到s-1入队,肯定不需要队头出队
int last;
for(last = 1; last < s && last <= i - 1; last++){
while(qi < qj && f[i-1][q[qj-1]] <= f[i-1][last]){
qj--;
}
q[qj++] = last;
}
last--;
for(int j = 1; j <= min(i, w); j++){
int st = max(j-1, 1);
if (qi < qj && q[qi] < st){
qi++;
}
int end = min(i-1, min(w, j - 1 + s));
if (last < end){
while(qi < qj && f[i-1][q[qj-1]] <= f[i-1][last+1]){
qj--;
}
q[qj++] = (++last);
}
f[i][j] = f[i-1][q[qi]] + (ll)a[i] * j;
}
}
}
int main(){
scanf("%d%d%d", &n, &w, &s);
for(int i = 1; i <= n; i++){
scanf("%d", a + i);
}
dp();
ll ans = LLONG_MIN;
for(int i = 1; i <= w; i++){
ans = max(ans, f[n][i]);
}
printf("%lld\n", ans);
return 0;
}