20180409胡策C 线段树优化DP

xiaoxiao2021-02-28  32

题目描述:

题解:

O(n2) O ( n 2 ) DP应该一眼秒吧…… f[i] f [ i ] 表示前i的最小花费, sw[i] s w [ i ] 表示w的前缀和,那么有DP方程: f[i]=min(f[i],f[j]+(h[i]h[j])2+sw[i1]sw[j])(1<=j<i),f[1]=0 f [ i ] = m i n ( f [ i ] , f [ j ] + ( h [ i ] − h [ j ] ) 2 + s w [ i − 1 ] − s w [ j ] ) ( 1 <= j < i ) , f [ 1 ] = 0 ,然后再化一下,得到 f[i]=h[i]2+sw[i1]+min{h[i]×2h[j]+f[j]+h[j]2sw[j]} f [ i ] = h [ i ] 2 + s w [ i − 1 ] + m i n { h [ i ] × − 2 h [ j ] + f [ j ] + h [ j ] 2 − s w [ j ] } ,把 h[i] h [ i ] 看做是直线的x, 2h[j] − 2 h [ j ] 看作k, f[j]+h[j]2sw[j] f [ j ] + h [ j ] 2 − s w [ j ] 看作b,发现就是求当 x=h[i] x = h [ i ] 时,最低的直线是多少,这个就可以用线段树维护了,具体可以看bzoj3165这道题。

代码:

#include<bits/stdc++.h> using namespace std; #define LL long long #define pa pair<int,int> const int Maxn=150010; const LL inf=(1LL<<60); LL read() { LL x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar(); return x*f; } struct Seg { int lc,rc,l,r,tag; }tr[2000010]; struct Line { LL k,b; Line(LL _k=0,LL _b=0){k=_k,b=_b;} }L[Maxn];int ll=0; LL get(int p,LL x){return L[p].k*x+L[p].b;} int tot=0; void build(int l,int r) { int t=++tot; tr[t].l=l;tr[t].r=r;tr[t].tag=-1; if(l==r)return; int mid=l+r>>1; tr[t].lc=tot+1;build(l,mid); tr[t].rc=tot+1;build(mid+1,r); } void insert(int x,int p) { if(tr[x].tag==-1) { tr[x].tag=p; return; } int lc=tr[x].lc,rc=tr[x].rc,mid=tr[x].l+tr[x].r>>1; if(get(p,(LL)tr[x].l)>=get(tr[x].tag,(LL)tr[x].l)&& get(p,(LL)tr[x].r)>=get(tr[x].tag,(LL)tr[x].r))return; if(get(p,(LL)tr[x].l)<get(tr[x].tag,(LL)tr[x].l)&& get(p,(LL)tr[x].r)<get(tr[x].tag,(LL)tr[x].r)){tr[x].tag=p;return;} if(get(p,(LL)mid)<get(tr[x].tag,(LL)mid)) { if(get(p,(LL)tr[x].r)<get(tr[x].tag,(LL)tr[x].r))insert(lc,tr[x].tag); else insert(rc,tr[x].tag); tr[x].tag=p; } else { if(get(p,(LL)tr[x].l)<get(tr[x].tag,(LL)tr[x].l))insert(lc,p); else insert(rc,p); } } LL query(int x,LL X) { if(tr[x].tag==-1)return inf; if(tr[x].l==tr[x].r)return get(tr[x].tag,X); int lc=tr[x].lc,rc=tr[x].rc,mid=tr[x].l+tr[x].r>>1; LL re=get(tr[x].tag,X); if(X<=mid)return min(re,query(lc,X)); else return min(re,query(rc,X)); } LL n,h[Maxn],w[Maxn],f[Maxn]; LL sw[Maxn]; int main() { n=read();sw[0]=0; for(int i=1;i<=n;i++)h[i]=read(); for(int i=1;i<=n;i++)w[i]=read(),sw[i]=sw[i-1]+w[i]; build(0,1e6); f[1]=0; L[++ll]=Line(-2LL*h[1],f[1]+h[1]*h[1]-sw[1]); insert(1,ll); for(int i=2;i<=n;i++) { f[i]=h[i]*h[i]+sw[i-1]+query(1,h[i]); L[++ll]=Line(-2LL*h[i],f[i]+h[i]*h[i]-sw[i]); insert(1,ll); } printf("%lld",f[n]); }
转载请注明原文地址: https://www.6miu.com/read-2632010.html

最新回复(0)