点此看题面
大致题意: 给定一棵树,每个节点有一个颜色,定义
s
(
i
,
j
)
s(i,j)
s(i,j)为
i
i
i到
j
j
j路径上颜色数量,请你对于每一个
i
i
i求出
∑
i
=
1
n
s
(
i
,
j
)
\sum_{i=1}^n s(i,j)
∑i=1ns(i,j)。
点分治
这种题目比较显然是点分治吧… …
L
i
n
k
Link
Link
点分治 详见博客 初学点分治
大致思路
首先,按照点分治的基本套路,对于一棵子树内的路径,我们分两种情况讨论:
经过根节点的不经过根节点的
呃,第二种情况就按照点分治继续处理下去就可以了,因此忽略不提。
而对于第一种情况的边,我们又可以分为两类:
以根节点为一个端点的不以根节点为端点的
接下来就按此分类处理即可。
先是一波定义
首先,我们需要先来一波定义。
s
u
m
sum
sum:所有颜色造成的总贡献。
c
o
l
i
col_i
coli:节点
i
i
i的颜色。
V
a
l
i
Val_i
Vali:颜色
i
i
i对答案造成的贡献。
S
i
z
e
i
Size_i
Sizei:以节点
i
i
i为根的子树大小。
具体流程
首先是一遍
d
f
s
dfs
dfs预处理。在这一遍与处理中,我们主要求出两个东西:
S
i
z
e
Size
Size和
V
a
l
Val
Val两个数组,当然,在求
V
a
l
Val
Val数组的同时,也要顺带求出
s
u
m
sum
sum。
然后,我们就可以对第一类路径进行处理,更新
a
n
s
r
t
ans_{rt}
ansrt。
具体操作:
a
n
s
r
t
+
=
s
u
m
−
V
a
l
c
o
l
r
t
+
S
i
z
e
r
t
ans_{rt}+=sum-Val_{col_{rt}}+Size_{rt}
ansrt+=sum−Valcolrt+Sizert。
这样做的理由是先将含有
r
t
rt
rt的颜色所造成的贡献从
s
u
m
sum
sum中删去(不然会造成重复计算),然后加上
S
i
z
e
r
t
Size_{rt}
Sizert(即含有
r
t
rt
rt的路径条数,
r
t
rt
rt的颜色在这些路径中会各对答案造成一点贡献)。
接下来便是对第二类路径的处理。我们考虑对于根节点的每一棵子树如何计算其贡献值。
第一步自然是将这棵子树中的贡献清空,不然会造成重复计算。接下来,从该子节点开始,一直往下遍历,更新每一个节点的
a
n
s
ans
ans。首先,我们要将沿途经过的在该子树中出现的颜色的贡献从
s
u
m
sum
sum中减去,避免重复计算,并用一个变量
c
o
l
o
r
_
t
o
t
color\_tot
color_tot记录该子树中颜色种数。则对于当前节点
x
x
x,我们可以将
a
n
s
x
ans_x
ansx加上
s
u
m
+
c
o
l
o
r
_
t
o
t
∗
O
t
h
e
r
sum+color\_tot*Other
sum+color_tot∗Other。其中
O
t
h
e
r
Other
Other表示除该棵子树外其他子树以及根节点的总点数,因此将剩余贡献加上
c
o
l
o
r
_
t
o
t
∗
O
t
h
e
r
color\_tot*Other
color_tot∗Other就是当前节点能造成的贡献了(和对第一类路径的操作同理)。
好吧,还有一个很重要的细节:注意清空数组。
但是,如何清空数组也是有学问的。如果直接
m
e
m
s
e
t
memset
memset,显然复杂度成了
O
(
n
)
O(n)
O(n),可能会
T
L
E
TLE
TLE。所以,我们直接将该子树中出现的所有颜色的相关值全部清空即可,就是一个遍历子树的过程,时间复杂度为
O
(
子
树
大
小
)
O(子树大小)
O(子树大小),从而保证了时间复杂度。
代码
#include<bits/stdc++.h>
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
#define uint unsigned int
#define LL long long
#define ull unsigned long long
#define swap(x,y) (x^=y,y^=x,x^=y)
#define abs(x) ((x)<0?-(x):(x))
#define INF 1e9
#define Inc(x,y) ((x+=(y))>=MOD&&(x-=MOD))
#define ten(x) (((x)<<3)+((x)<<1))
#define N 100000
#define add(x,y) (e[++ee].nxt=lnk[x],e[lnk[x]=ee].to=y)
using namespace std
;
LL n
,ee
=0,lnk
[N
+5],col
[N
+5];
struct edge
{
LL to
,nxt
;
}e
[(N
<<1)+5];
class FIO
{
private:
#define Fsize 100000
#define tc() (FinNow==FinEnd&&(FinEnd=(FinNow=Fin)+fread(Fin,1,Fsize,stdin),FinNow==FinEnd)?EOF:*FinNow++)
#define pc(ch) (FoutSize<Fsize?Fout[FoutSize++]=ch:(fwrite(Fout,1,FoutSize,stdout),Fout[(FoutSize=0)++]=ch))
LL f
,FoutSize
,OutputTop
;char ch
,Fin
[Fsize
],*FinNow
,*FinEnd
,Fout
[Fsize
],OutputStack
[Fsize
];
public:
FIO() {FinNow
=FinEnd
=Fin
;}
inline void read(LL
&x
) {x
=0,f
=1;while(!isdigit(ch
=tc())) f
=ch
^'-'?1:-1;while(x
=ten(x
)+(ch
&15),isdigit(ch
=tc()));x
*=f
;}
inline void read_char(char &x
) {while(isspace(x
=tc()));}
inline void read_string(string
&x
) {x
="";while(isspace(ch
=tc()));while(x
+=ch
,!isspace(ch
=tc())) if(!~ch
) return;}
inline void write(LL x
) {if(!x
) return (void)pc('0');if(x
<0) pc('-'),x
=-x
;while(x
) OutputStack
[++OutputTop
]=x
%10+48,x
/=10;while(OutputTop
) pc(OutputStack
[OutputTop
]),--OutputTop
;}
inline void write_char(char x
) {pc(x
);}
inline void write_string(string x
) {register LL i
,len
=x
.length();for(i
=0;i
<len
;++i
) pc(x
[i
]);}
inline void end() {fwrite(Fout
,1,FoutSize
,stdout);}
}F
;
class Class_DotSolver
{
private:
LL rt
,sum
,color_tot
,used
[N
+5],Size
[N
+5],Max
[N
+5],cnt
[N
+5],Val
[N
+5];
inline void GetRt(LL x
,LL lst
,LL tot
)
{
register LL i
;
for(i
=lnk
[x
],Size
[x
]=1,Max
[x
]=0;i
;i
=e
[i
].nxt
)
if(e
[i
].to
^lst
&&!used
[e
[i
].to
]) GetRt(e
[i
].to
,x
,tot
),Size
[x
]+=Size
[e
[i
].to
],Max
[x
]=max(Max
[x
],Size
[e
[i
].to
]);
if((Max
[x
]=max(Max
[x
],tot
-Size
[x
]))<Max
[rt
]) rt
=x
;
}
inline void Clear(LL x
,LL lst
)
{
for(register LL i
=lnk
[x
];i
;i
=e
[i
].nxt
) if(e
[i
].to
^lst
&&!used
[e
[i
].to
]) Clear(e
[i
].to
,x
);
cnt
[col
[x
]]=Val
[col
[x
]]=0;
}
inline void Init(LL x
,LL lst
)
{
register LL i
;
for(i
=lnk
[x
],Size
[x
]=1,++cnt
[col
[x
]];i
;i
=e
[i
].nxt
) if(e
[i
].to
^lst
&&!used
[e
[i
].to
]) Init(e
[i
].to
,x
),Size
[x
]+=Size
[e
[i
].to
];
if(!--cnt
[col
[x
]]) sum
+=Size
[x
],Val
[col
[x
]]+=Size
[x
];
}
inline void Change(LL x
,LL lst
,LL flag
)
{
register LL i
;
for(i
=lnk
[x
],++cnt
[col
[x
]];i
;i
=e
[i
].nxt
) if(e
[i
].to
^lst
&&!used
[e
[i
].to
]) Change(e
[i
].to
,x
,flag
);
if(!--cnt
[col
[x
]]) flag
?(sum
+=Size
[x
],Val
[col
[x
]]+=Size
[x
]):(sum
-=Size
[x
],Val
[col
[x
]]-=Size
[x
]);
}
inline void F5(LL x
,LL lst
,LL Other
)
{
register LL i
;
if(!cnt
[col
[x
]]++) sum
-=Val
[col
[x
]],++color_tot
;
for(i
=lnk
[x
],ans
[x
]+=sum
+color_tot
*Other
;i
;i
=e
[i
].nxt
) if(e
[i
].to
^lst
&&!used
[e
[i
].to
]) F5(e
[i
].to
,x
,Other
);
if(!--cnt
[col
[x
]]) sum
+=Val
[col
[x
]],--color_tot
;
}
inline void Solve(LL x
)
{
register LL i
,j
;
for(i
=lnk
[x
],used
[x
]=1,Init(x
,0),ans
[x
]+=sum
-Val
[col
[x
]]+Size
[x
];i
;i
=e
[i
].nxt
)
{
if(used
[e
[i
].to
]) continue;
++cnt
[col
[x
]],sum
-=Size
[e
[i
].to
],Val
[col
[x
]]-=Size
[e
[i
].to
],Change(e
[i
].to
,x
,0),
--cnt
[col
[x
]],F5(e
[i
].to
,x
,Size
[x
]-Size
[e
[i
].to
]),++cnt
[col
[x
]],
Change(e
[i
].to
,x
,1),Val
[col
[x
]]+=Size
[e
[i
].to
],sum
+=Size
[e
[i
].to
],--cnt
[col
[x
]];
}
for(i
=lnk
[x
],sum
=color_tot
=0,Clear(x
,0);i
;i
=e
[i
].nxt
)
if(!used
[e
[i
].to
]) GetRt(e
[i
].to
,rt
=0,Size
[e
[i
].to
]),Solve(rt
);
}
public:
Class_DotSolver() {Max
[0]=INF
;}
LL ans
[N
+5];
inline void GetAns() {GetRt(1,rt
=0,n
),Solve(rt
);}
}DotSolver
;
int main()
{
register LL i
,x
,y
;
for(F
.read(n
),i
=1;i
<=n
;++i
) F
.read(col
[i
]);
for(i
=1;i
<n
;++i
) F
.read(x
),F
.read(y
),add(x
,y
),add(y
,x
);
for(DotSolver
.GetAns(),i
=1;i
<=n
;++i
) F
.write(DotSolver
.ans
[i
]),F
.write_char('\n');
return F
.end(),0;
}