Hdu 5593 ZYB's Tree【树型Dp】经典题

xiaoxiao2021-02-28  116

ZYB's Tree

   Accepts: 77    Submissions: 513  Time Limit: 3000/1500 MS (Java/Others)    Memory Limit: 131072/131072 K (Java/Others) 问题描述 ZYBZYB有一颗NN个节点的树,现在他希望你对于每一个点,求出离每个点距离不超过KK的点的个数. 两个点(x,y)(x,y)在树上的距离定义为两个点树上最短路径经过的边数, 为了节约读入和输出的时间,我们采用如下方式进行读入输出: 读入:读入两个数A,BA,B,令fa_ifai为节点ii的父亲,fa_1=0fa1=0;fa_i=(A*i+B)\%(i-1)+1fai=(Ai+B)%(i1)+1 i \in [2,N]i[2,N] . 输出:输出时只需输出NN个点的答案的xorxor和即可。 输入描述 第一行一个整数TT表示数据组数。 接下来每组数据: 一行四个正整数N,K,A,BN,K,A,B. 最终数据中只有两组N \geq 100000N1000001 \leq T \leq 51T5,1 \leq N \leq 5000001N500000,1 \leq K \leq 101K10,1 \leq A,B \leq 10000001A,B1000000 输出描述 TT行每行一个整数表示答案. 输入样例 1 3 1 1 1 输出样例 3

思路:

观察到K的范围不大,那么我们考虑对每个点进行Dp计数。

这类Dp基本上都是,分两个方向去Dp。很套路。

①设定Dp【i】【j】表示以点i为根的子树中,到点i距离为j的点的个数。那么不难写出有:Dp【i】【j】+=Dp【son【i】】【j-1】;

②再设定F【i】【j】表示以点i为中心,非子树方向,到点i距离为j的点的个数,其实这部分的转移也很好想,除了根方向以外的部分,都要转移过来即可。

那么只有两种可能,一种是从父亲节点转移过来:F【i】【j】+=F【fa】【j-1】;

另一种可能就是从兄弟节点转移过来:F【i】【j】+=Dp【brother】【j-2】;

很显然直接去转移兄弟节点会TLE掉,因为一个节点的兄弟节点会很多,那么每个兄弟节点都处理一次的话,任务量实在是太大了。所以我们优化一下,设定Sum【i】【j】表示ΣDp【son【i】】【j】;那么我们就可以优化最后一个转移方程为:F【i】【j】+=Sum【fa【i】】【j-2】-Dp【i】【j-2】;

过程维护统计一下即可。

#include<stdio.h> #include<vector> #include<string.h> using namespace std; #define ll long long int vector<int>mp[550000]; int dp[550000][12]; int sum[550000][12]; int F[550000][12]; int n,k,A,B; void Dfs(int u,int from) { for(int i=0; i<mp[u].size(); i++) { int v=mp[u][i]; if(v==from)continue; Dfs(v,u); for(int j=1; j<=k; j++) { dp[u][j]+=dp[v][j-1]; } } for(int i=0; i<mp[u].size(); i++) { int v=mp[u][i]; if(v==from)continue; for(int j=0; j<=k; j++) { sum[u][j]+=dp[v][j]; } } } void dfs(int u,int from) { if(from!=-1)for(int j=1; j<=k; j++)F[u][j]+=F[from][j-1]; for(int j=2; j<=k; j++) { if(j-2>=0&&from!=-1) { F[u][j]+=sum[from][j-2]-dp[u][j-2]; } } for(int i=0; i<mp[u].size(); i++) { int v=mp[u][i]; if(v==from)continue; dfs(v,u); } } int main() { int t; scanf("%d",&t); while(t--) { memset(sum,0,sizeof(sum)); memset(dp,0,sizeof(dp)); memset(F,0,sizeof(F)); scanf("%d%d%d%d",&n,&k,&A,&B); for(int i=1; i<=n; i++)mp[i].clear(); for(int i=2; i<=n; i++) { ll fa=(ll)((ll)A*(ll)i+B)%(ll)(i-1)+1; mp[i].push_back(fa); mp[fa].push_back(i); } for(int i=1; i<=n; i++)dp[i][0]=1,F[i][0]=1; Dfs(1,-1); dfs(1,-1); int ans=0; for(int i=1; i<=n; i++) { int sum=0; for(int j=0; j<=k; j++) { sum+=F[i][j]+dp[i][j]; } ans^=(sum-1); } printf("%d\n",ans); } }

转载请注明原文地址: https://www.6miu.com/read-47565.html

最新回复(0)