Description 给出一棵n个节点的树和一个1~n的排列,要求把该排列分成k个连续的段,使得每段点在树上的LCA深度之和最小 Input 多组用例,每组用例首先输入两个整数n和k,之后输入一个1~n的排列,最后n-1行每行两个整数u和v表示u和v之间有一条树边,以文件尾结束输入 (1<=k<=n<=3e5,n*k<=3e5) Output 输出将该排列分成k段后每段LCA深度之和最小值 Sample Input 6 3 4 6 2 5 1 3 1 2 2 3 3 4 4 5 4 6 Sample Output 6 Solution1 dp[i][j]表示把前i个数分成j段的LCA深度之和最小值,则有转移方程 dp[i][j]=min( dp[k][j-1] + LCA(a[k+1],…,a[i] ),j-1<=k < i 直接转移时间复杂度O(k*n^2),考虑用CDQ分治加速转移,即考虑dp[l][j-1],…,dp[mid][j-1]对dp[mid+1][j],…,dp[r][j]的影响,由于LCA有单调性,在更新[mid+1,r]中t点时,[l,mid]存在一个分界点pos,使得在pos~mid-1的任意处作为前j-1段和第j段的分界点,第j段的LCA是定值,而l~pos-1的任意处作为前j-1段和第j段的分界点,第j段的LCA和[mid+1,t]这些点无关,所以可以通过记录一些前缀后缀的最小值来O(1)转移,时间复杂度O(nklogn) Code1
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=50001,INF=4e5; int n,k,a[maxn]; struct node { int v,next; }e[2*maxn]; int head[maxn],tot; int ST[maxn<<1][20],Id[maxn<<1],Pos[maxn<<1],Deep[maxn],index,lg[maxn<<1]; void add(int u,int v) { e[tot].v=v,e[tot].next=head[u],head[u]=tot++; } void dfs(int u,int fa) { Id[++index]=u,ST[index][0]=u,Pos[u]=index; for(int i=head[u];~i;i=e[i].next) { int v=e[i].v; if(v==fa)continue; Deep[v]=Deep[u]+1; dfs(v,u); Id[++index]=u,ST[index][0]=u; } } void lca_init() { lg[1]=0; for(int i=2;i<2*maxn;i++)lg[i]=lg[i>>1]+1; Deep[1]=1; index=0; dfs(1,0); for(int j=1;j<=lg[index];j++) for(int i=1;i+(1<<j)-1<=index;i++) { int x=ST[i][j-1],y=ST[i+(1<<(j-1))][j-1]; ST[i][j]=(Deep[x]<Deep[y]?x:y); } } int lca(int a,int b) { if(Pos[a]>Pos[b])swap(a,b); int t=lg[Pos[b]-Pos[a]+1]; int x=ST[Pos[a]][t],y=ST[Pos[b]-(1<<t)+1][t]; return (Deep[x]<Deep[y]?x:y); } int ans[maxn],pre[maxn],A[maxn],B[maxn],C[maxn]; //i<=mid ,A[i]=lca(a[i],...,a[mid]) //i>mid,A[i]=lca(a[mid+1],...,a[i]) //B[i]=min(ans[j-1]+deep(A[j])),l+1<=j<=i,l+1<=i<=mid //C[i]=min(ans[j-1]),i<=j<=mid,l+1<=i<=mid void CDQ(int l,int r) { if(l==r)return ; int mid=(l+r)/2; CDQ(l,mid); A[mid]=a[mid],A[mid+1]=a[mid+1]; for(int i=mid-1;i>=l;i--)A[i]=lca(A[i+1],a[i]); for(int i=mid+2;i<=r;i++)A[i]=lca(A[i-1],a[i]); //1~mid分成k-1段,mid~i分成1段 for(int i=mid+1;i<=r;i++)pre[i]=min(pre[i],ans[mid]+Deep[A[i]]); B[l]=INF; for(int i=l+1;i<=mid;i++)B[i]=min(B[i-1],ans[i-1]+Deep[A[i]]); C[mid+1]=INF; for(int i=mid;i>=l+1;i--)C[i]=min(C[i+1],ans[i-1]); int temp=lca(a[mid],a[mid+1]),pos=mid; while(pos>l&&Deep[A[pos]]>Deep[temp])pos--; for(int i=mid+1;i<=r;i++) if(Deep[A[i]]>Deep[temp]) { //以pos~mid-1中任一点j结束,j~i的lca均为temp,故要选择ans[pos]~ans[mid-1]的最小值 pre[i]=min(pre[i],C[pos+1]+Deep[temp]); //以l~pos中任一点j结束,j~i的lca为A[j+1],故要选择g[j]+deep(A[j+1])的最小值 pre[i]=min(pre[i],B[pos]); } else { while(pos>l&&Deep[A[pos]]>Deep[A[i]])pos--; //以po~smid-1中任一点j结束,j~i的lca均为A[i],故要选择ans[pos]~ans[mid-1]的最小值 pre[i]=min(pre[i],C[pos+1]+Deep[A[i]]); //以l~pos中任一点j结束,j~i的lca为A[j+1],故要选择g[j]+deep(A[j+1])的最小值 pre[i]=min(pre[i],B[pos]); } CDQ(mid+1,r); } int main() { while(~scanf("%d%d",&n,&k)) { tot=0; memset(head,-1,sizeof(head)); for(int i=1;i<=n;i++)scanf("%d",&a[i]); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); add(u,v),add(v,u); } lca_init(); for(int i=1;i<=k;i++) { if(i==1) { int temp=a[1]; for(int j=1;j<=n;j++) temp=lca(temp,a[j]),ans[j]=Deep[temp]; } else { for(int j=1;j<=n;j++)pre[j]=INF; CDQ(1,n); for(int j=1;j<=n;j++)ans[j]=pre[j]; } } printf("%d\n",ans[n]); } return 0; }Solution2 可以用数学归纳法证明m个点的LCA其实就是m-1个相邻两点LCA中深度最小的那个,同样的,设dp[i][j]为前i个数分成j段每段LCA深度之和最小值,那么对于每个点,有三种选择 1.该点单独作为一段,dp[i][j]=min(dp[i][j],dp[i-1][j-1]+deep[i]) 2.该点与前一个点的LCA作为新的一段的最小值,dp[i][j]=min(dp[i][j],dp[i-2]j-1]+LCA(a[i-1],a[i])) 3.该点并入前一段,dp[i][j]=min(dp[i][j],dp[i-1][j]) 离线Tarjan求出相邻两点的LCA,每次转移O(1),用滚动数组优化掉dp数组第一维,时间复杂度O(nk) Code2
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int maxn=50011; namespace fastIO { #define BUF_SIZE 100000 //fread -> read bool IOerror=0; inline char nc() { static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE; if(p1==pend) { p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin); if(pend==p1) { IOerror=1; return -1; } } return *p1++; } inline bool blank(char ch) { return ch==' '||ch=='\n'||ch=='\r'||ch=='\t'; } inline void read(int &x) { char ch; while(blank(ch=nc())); if(IOerror)return; for(x=ch-'0';(ch=nc())>='0'&&ch<='9';x=x*10+ch-'0'); } #undef BUF_SIZE }; using namespace fastIO; struct node { int v,id,next; }e[2*maxn],q[2*maxn]; int n,k,tot1,tot2,head1[maxn],head2[maxn],f[maxn],Deep[maxn],vis[maxn],ans[maxn],lca[maxn]; int a[maxn],dp[3][maxn]; void add_edge(int u,int v) { e[tot1].v=v,e[tot1].next=head1[u],head1[u]=tot1++; } void add_query(int u,int v,int id) { q[tot2].v=v,q[tot2].id=id,q[tot2].next=head2[u],head2[u]=tot2++; } int find(int x) { if(f[x]==x)return x; return f[x]=find(f[x]); } void unite(int x,int y) { x=find(x),y=find(y); if(x!=y)f[x]=y; } void dfs(int u,int fa) { ans[u]=u; for(int i=head1[u];~i;i=e[i].next) { int v=e[i].v; if(v==fa)continue; Deep[v]=Deep[u]+1; dfs(v,u); unite(u,v); ans[find(u)]=u; } vis[u]=1; for(int i=head2[u];~i;i=q[i].next) { int v=q[i].v,id=q[i].id; if(vis[v])lca[id]=ans[find(v)]; } } int main() { using namespace fastIO; while(read(n),read(k),!IOerror) { tot1=tot2=0; memset(head1,-1,sizeof(head1)); memset(head2,-1,sizeof(head2)); memset(vis,0,sizeof(vis)); for(int i=1;i<=n;i++)f[i]=i; for(int i=1;i<=n;i++) { read(a[i]); if(i>1)add_query(a[i-1],a[i],i),add_query(a[i],a[i-1],i); } for(int i=1;i<n;i++) { int u,v; read(u),read(v); add_edge(u,v),add_edge(v,u); } Deep[1]=1; dfs(1,0); int cur=-1; for(int i=1;i<=n;i++) { cur=(cur+1)%3; int m=min(i,k); for(int j=1;j<=m;j++) { int temp=dp[(cur-1+3)%3][j-1]+Deep[a[i]]; if(i-1>=j)temp=min(temp,dp[(cur-1+3)%3][j]); if(i>=2&&j>=1&&i-2>=j-1)temp=min(temp,dp[(cur-2+3)%3][j-1]+Deep[lca[i]]); dp[cur][j]=temp; } } printf("%d\n",dp[cur][k]); } return 0; }