作法:
以下記每個點的保衛值為 val[ x ] 。
考慮樹分治,對於一個子樹,先選出他的重心當根,把它叫做 x 好了,然後我們的目的是要把根拔掉,分出好幾顆子樹,再分別做這些子樹。所以第一個我們需要把根的答案加上「這顆子樹中和他距離 <= val[ x ] 的點數個數 」,並且因為接下來要把各個子樹拆開處理,假設子樹們為 T1 , ... , Tk ,那麼對於 T1 裡的某個點 P 而言,在 T1 裡面並且和他距離 <= val[ P ] 的個數會在計算子問題 ( T1 ) 的時候被算完,而 T2 ~ Tk 裡面的點對 P 的貢獻在之後就不會被算到了,所以這時候必須把 T2 ~ Tk ( 還有 x )提供的和 P 距離 <= val[ P ] 的點個數先加到 P 的答案裡面。
所以現在的問題變成要如何對每個子樹計算那個值。設 T1 中的點 P 的深度為 d[ P ] ( 其中 x 的 d 值是 0 ),那麼要算的 P 的答案就是用「以 x 為根的子樹中和 P 的距離 <= val[ P ] - d[ P ] 」的點數量,扣掉「 T1 中和 T1 的根的距離 <= val[ P ] - d[ P ] - 1 」 的點數量,所以我們需要先 DFS 求出以 x 為根的子樹中距離 <= 某個值 k 的點數量有多少( 在我的 code 裡是用 sum 陣列紀錄 ),並且在算下一個子樹 Ti 裡的所求值時候,先算出 Ti 中到 Ti 的根的距離 <= 某個值 k 的點數量有多少 ( 用 sum2 陣列紀錄 ),就可以算出每個點要加多少值了。
但如果在每次做樹分治的時候都對 sum 陣列和 sum2 陣列直接用 memset 歸零 ,那麼時間就會爆掉。但事實上當某個子樹的大小 = size 的時候,到子樹的根的距離頂多是 size ,也就是 sum ( 或 sum2 ) 陣列其實從某一個點之後他的值就不在變了,所以當我們在求 sum 陣列的時候,第一步是先算出「到子樹的根的距離恰為 k 的樹有幾個」,再做他的前綴和,所以只需要把前 size 個都設成0就好了 ( 用 fill 函式 )。但這樣要注意到,之後要查詢 sum 陣列裡的值的時候,如果查詢的 index > size ,那麼這次查詢的答案就會是 sum[ size ] ,因為 sum[ size ] 以後的值都是錯的,我們只有維護 0 ~ size 的值,所以要特判。
code :
#include<bits/stdc++.h> using namespace std; const int maxn=100000+10 ; int val[maxn] ; vector<int> v[maxn] ; bool vis[maxn] ; int sum[maxn],sum2[maxn],sz1,sz2 ; int cnt ; int d[maxn] ; void dfs0(int x,int &M,int dep) { d[x]=dep ; vis[x]=1 ; cnt++ ; if(d[x]>d[M]) M=x ; for(auto i : v[x]) if(!vis[i]) dfs0(i,M,dep+1) ; vis[x]=0 ; } int get_cent(int x,int &sz) { int y=x ; cnt=0 ; dfs0(x,y,0) ; sz=cnt ; x=y ; dfs0(y,x,0) ; int maxd=d[x] ; if(!maxd) return x ; for(int i=x;;) { for(auto j : v[i]) if(!vis[j] && d[j]==d[i]-1) { i=j ; break ; } if(d[i]==maxd/2) return i ; } } void dfs_dis(int x,int dep,int *sm) { sm[dep]++ ; vis[x]=1 ; for(auto i : v[x]) if(!vis[i]) dfs_dis(i,dep+1,sm) ; vis[x]=0 ; } int ans[maxn] ; void dfs_cal(int x,int dep) { vis[x]=1 ; for(auto i : v[x]) if(!vis[i]) dfs_cal(i,dep+1) ; vis[x]=0 ; int val2=val[x]-dep-1 ; if(val2>=0) ans[x]+= ( val2<=sz1 ? sum[val2] : sum[sz1] ) , ans[x]-= (val2<=sz2 ? sum2[val2] : sum2[sz2]) ; } void solve(int y) { int x=get_cent(y,sz1) ; fill(sum,sum+sz1+1,0) ; dfs_dis(x,0,sum) ; for(int i=1;i<=sz1;i++) sum[i]+=sum[i-1] ; ans[x]+= (val[x]<=sz1 ? sum[val[x]] : sum[sz1]) ; vis[x]=1 ; for(auto i : v[x]) if(!vis[i]) { get_cent(i,sz2) ; fill(sum2,sum2+sz2+1,0) ; dfs_dis(i,1,sum2) ; for(int i=1;i<=sz2;i++) sum2[i]+=sum2[i-1] ; dfs_cal(i,0) ; } for(auto i : v[x]) if(!vis[i]) solve(i) ; } main() { int n ; scanf("%d",&n) ; for(int i=1;i<=n;i++) scanf("%d",&val[i]) ; for(int i=1;i<n;i++) { int x,y ; scanf("%d%d",&x,&y) ; v[x].push_back(y) ; v[y].push_back(x) ; } solve(1) ; for(int i=1;i<=n;i++) printf("%d\n",ans[i]) ; }
沒有留言:
張貼留言