斜率优化DP模板题--HDU3507 Print Article

xiaoxiao2021-02-28  13

题目大意:给出一串正数,将这一串数分成若干段,每一段的代价为这一段数的加和的平方+一个常数m,求最小代价。

不难写出DP转移方程: f[i]=max(f[j]+(sum[i]sum[j])2+m) 但是,这样复杂度是 N2 的,对于 N500,000 显然是不行的,对于这种转移方程与i,j都有关的,可以用斜率优化来优化。

考虑k < <script type="math/tex" id="MathJax-Element-4887"><</script>j < <script type="math/tex" id="MathJax-Element-4888"><</script>i,若j比k更优,则

f[j]+(sum[i]sum[j])2+m<f[k]+(sum[i]sum[k])2+m ,f[j]+sum[j]2(f[k])+sum[k]22sum[j]2sum[k]<sum[i] 如果将 f[j]+sum[j]2 看成纵坐标, 2sum[j] 看成横坐标,那么原来的不等式就可以表示为 y[j]y[k]x[j]x[k]<sum[i] 没错,这样不等式左边就是斜率的形式,如果斜率小于sum[i],那么就说明j状态比k状态更优。

由于sum[i]是递增的,每次新加入的点一定在最右边,所以可以很方便的利用队列优化到 O(N) 的复杂度。还是考虑k < <script type="math/tex" id="MathJax-Element-4895"><</script>j < <script type="math/tex" id="MathJax-Element-4896"><</script>i,若 k(k,j)>k(j,i) 那么j这个状态就无用了,理由如下: 对于之后的任意一个状态u,若 k(j,i)<sum[u] ,那么状态i显然比状态j优 若 k(j,i)sum[u] ,那么虽然j比i更优,但是 k(k,j)>k(j,i)sum[u] ,k比j更优 这样,队列中的点一定是下凸的(斜率随x的增大而增大),每次新加入一个点,就把所有k(x-1,x)>k(x,u)的节点x删除。由于曲线是下凸的,所以把前面的所有k小于sum[u]的节点删除以后,剩下的队列头就是最优状态。

代码如下:

#include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define maxn 500006 #define LL long long using namespace std; inline char nc(){ static char buf[100000],*p1=buf,*p2=buf; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } inline LL _read(){ LL sum=0;char ch=nc(); while(!(ch>='0'&&ch<='9')){ch=nc();if(ch==EOF)return -1e9;} while(ch>='0'&&ch<='9')sum=sum*10+ch-48,ch=nc(); return sum; } LL n,m,hed,tal,que[maxn],px[maxn],py[maxn],f[maxn],a[maxn],sum[maxn]; double k(int x,int y){ if(px[x]==px[y])return py[x]>py[y]?-1e9:1e9; return (py[y]-py[x])/(px[y]-px[x]); } void pop(double x){while(hed<tal&&k(que[hed],que[hed+1])<x)hed++;} void push(int x){ while(tal>hed&&k(que[tal-1],que[tal])>k(que[tal],x))tal--; que[++tal]=x; } int main(){ freopen("print.in","r",stdin); freopen("print.out","w",stdout); while((n=_read())!=-1e9){ m=_read(); for(int i=1;i<=n;i++)sum[i]=sum[i-1]+(a[i]=_read()); f[1]=a[1]*a[1]+m;hed=0;que[tal=1]=1;px[1]=2*sum[1];py[1]=f[1]+sum[1]*sum[1]; for(int i=2;i<=n;i++){ pop(sum[i]);f[i]=f[que[hed]]+(sum[i]-sum[que[hed]])*(sum[i]-sum[que[hed]])+m; px[i]=2*sum[i];py[i]=f[i]+sum[i]*sum[i]; push(i); } printf("%lld\n",f[n]); } return 0; }
转载请注明原文地址: https://www.6miu.com/read-1100101.html

最新回复(0)