圍繞這題的東西是「一個點到離他最遠的葉子的距離」,那麼我們首先來看這個東西有什麼性質。假設離某個點 A 最遠的葉節點為 P ,且 A 走到 P 的路上經過的第一個點為 B ,那麼可以推到對於和 A 相鄰的除了 B 以外的點們來說,離他們最遠的葉子都是 P ,證明不難證。而這也就代表了如果以 A 為樹的根的話,那麼會得到 B 不在的所有子樹裡的點,他們的離最遠的葉節點都會是 P 了。但在處理這個問題之前,我們還是要先知道要怎麼求出每個點的離他最遠的葉子距離,而實際上不難證明:如果 A1 A2 為這棵樹的直徑的兩端點 (直徑就是樹上離最遠的兩點),那麼對於每個點,他到 A1 的距離和他到 A2 的距離之一就會是所求的東西。
以下記一個點 x 到離他最遠的葉節點的距離為 val[ x ] ,考慮 val 值最小的點 X ,那麼就可以得到一個重要的性質:把這棵樹以 X 為根建成有根樹,那麼對於每一個點 Y , Y 的父親的 val 值 < Y 的 val 值。這件事情的證明只要用前面提到的 val 的性質,加上 X 是 val 最小的點就可以簡單的證明了。
有了這個重要的性質之後,這就代表了我們可以枚舉所求點集中深度最淺的那個點,記他為 P 好了,那麼當 P 確定之後,就代表整個點集的最小值確定了,那麼這時候所有「 是 P 的子孫且其 val 值 <= val[ P ] + L 」的點就會是一個滿足題目條件的集合,用這個集合的大小去更新答案。而這就代表我們必須對每個節點都維護一個 multiset ,裡面放的是這個節點和他的子孫的所有 val 值,並且在算 P 的答案的時候必須把所有 val 值太大的從 multiset 裡面移除掉。然後我們還需要合併兩個 multiset ,於是就會馬上想到啟發式合併(注意到當在計算 P 的答案時,如果把某個 val 值移除掉了,那麼也代表之後在算 P 的祖先的答案的時候這個 val 值也是沒有用的,因為 P 的祖先的 val 值比 P 的還要小,因此這個合併不會漏掉任何可能性)。但啟發式合併的複雜度是 O(n * log^2 n) ,我自己寫完傳上去就 TLE 了,把 multiset 改成 map 一樣 TLE ,畢竟會有 50 筆詢問,應該就是用來卡掉 O(n * log^2 n) 解用的。
仔細想想其實可以把這個解改進到 O(nlogn) ,因為事實上對於一個點 Q 的 val 值,我們可以二分搜出他會在處理到哪個點的時候被移除掉!也就是如果 R 是深度最淺的點,滿足 val[ R ] < val[ Q ] - L ,那麼 val[ Q ] 就會在處理 val[ R ] 的時候被移除掉(其實應該不太算二分搜,我在找 R 的作法是類似 LCA 的倍增法,紀錄 2^k 級祖先然後一直往上跳這樣)。因此我們可以換個角度作,既然都知道 Q 會在什麼時候被移除掉了,那這就代表從 Q 一路沿著他的父親走,走到 R 之前的這些點們都會收到 Q 的貢獻。那麼這樣就可以轉換為經典問題了!也就是現在問題變成:對於每個點 Q ,找出他能夠貢獻的點們形成的路徑之後,把這條路徑上的每個點的權值都加上 1 ,最後詢問所有點中的點權最大值為多少。這樣就可以透過樹壓平轉換成簡單的一維問題了!
code :
#include<bits/stdc++.h> #define LL long long #define INF (1LL<<60) using namespace std; const int maxn=100000+10 ; struct P{int to,dis;}; vector<P> v[maxn] ; LL d1[maxn],d2[maxn],val[maxn] ; int n ; void dfs(int x,int f,LL *d) { for(auto i : v[x]) if(i.to!=f) d[i.to]=d[x]+i.dis , dfs(i.to,x,d) ; } void getdia(int &a,int &b) { d1[1]=0 ; dfs(1,-1,d1) ; LL M=-1 ; for(int i=1;i<=n;i++) if(d1[i]>M) M=d1[a=i] ; d1[a]=0 ; dfs(a,-1,d1) ; M=-1 ; for(int i=1;i<=n;i++) if(d1[i]>M) M=d1[b=i] ; } void getval() { int a,b ; getdia(a,b) ; d1[a]=0 ; dfs(a,-1,d1) ; d2[b]=0 ; dfs(b,-1,d2) ; for(int i=1;i<=n;i++) val[i]=max(d1[i],d2[i]) ; } int pos[maxn],ri[maxn],t ; int anc[18][maxn] ; void dfs2(int x,int f) { pos[x]=++t ; anc[0][x]=f ; for(int i=1;i<18;i++) anc[i][x]=anc[i-1][anc[i-1][x]] ; for(auto i : v[x]) if(i.to!=f) dfs2(i.to,x) ; ri[x]=t ; } int getinterval() { LL mi=INF ; int root ; for(int i=1;i<=n;i++) if(val[i]<mi) mi=val[root=i] ; dfs2(root,root) ; return root ; } int sum[maxn] ; main() { scanf("%d",&n) ; for(int i=1;i<n;i++) { int x,y,dis ; scanf("%d%d%d",&x,&y,&dis) ; v[x].push_back((P){y,dis}) ; v[y].push_back((P){x,dis}) ; } getval() ; int root=getinterval() ; int Q ; scanf("%d",&Q) ; while(Q--) { LL L ; scanf("%I64d",&L) ; memset(sum,0,sizeof(sum)) ; for(int i=1;i<=n;i++) { if(val[i]-val[root]<=L) { sum[pos[i]]++ ; continue ; } int j=i ; for(int k=17;k>=0;k--) if(val[i]-val[anc[k][j]]<=L) j=anc[k][j] ; j=anc[0][j] ; sum[pos[i]]++ ; sum[pos[j]]-- ; } for(int i=1;i<=n;i++) sum[i]+=sum[i-1] ; int ans=0 ; for(int i=1;i<=n;i++) ans=max(ans,sum[ri[i]]-sum[pos[i]-1]) ; printf("%d\n",ans) ; } }
沒有留言:
張貼留言