首先觀察到我們有兩種方法來計算題目所求的值,令$ \displaystyle g(u)=\sum_{i=1}^{n} d(u,i)^2 $,那麼$\displaystyle f(u,v)=g(u)-2\cdot \sum_{x\notin S(v)} d(u,x)^2=2\cdot \sum_{x\in S(v)} d(u,x)^2 - g(u)$
首先先想辦法求出$g(u)$,如果我們只要知道特定一個$u$的$g$值的話,就直接對以$u$為根建立有根樹並DP就好了,DP時對每個節點$x$紀錄三個值:以$x$為根的子樹的大小,以$x$為根的子樹中的所有點走到$x$的距離和,還有以$x$為根的子樹中的所有點走到$x$的距離平方和,這三個東西的遞推式子算是顯然的(以後簡稱他們叫作三個DP值)。但我們需要對每個$u$都算出他的$g$值,如果對每個點都DFS一次顯然會太慢。這時只要注意到,實際上對兩個不同的點DFS的時候,會有很多點的那三個值是一模一樣的,更具體來說,如果現在對$X$和$Y$作DFS,那麼對於任何一個點$u$,如果從$X$走到$u$的路徑上的最後一條邊等於從$Y$走到$u$的路徑上的最後一條邊,則兩次的DFS在$u$求出的值是一樣的,因為這兩次的$u$的子樹是一模一樣的。這啟發了我們可以對每個點分別計算「當他的父節點是他的某個鄰居時,他的三個DP值會是多少」。嚴謹來說,定義對於相鄰的兩點$u,v$,定義$T(u,v)$代表砍掉$(u,v)$這條邊後以$u$為根的子樹,那麼把這棵子樹對應的三個DP值算出來。這樣只有$O(n)$個值要算,並且也不難得知這些值的$O(n)$或是$O(nlogn)$的算法,因此這裡成功解決了遇處理每個點的$g$值的部份。
有了每個點的$g$值之後,再回到題目要求的式子。當$u$不在以$v$為根的子樹內時,那麼第二條和所求等價的式子會很好算,因為我們可以把$d(u,x)$拆成$d(u,v)+d(v,x)$,這樣所求就可以由剛才求出的$T(v,fa[v])$的三個DP值計算出來了,其中$fa[v]$代表$v$的父節點(題目給的是以$1$為根的有根樹)。至於$u$落在$v$的子樹內的情形則使用第一條和所求等價的式子,此時對應的子樹就會是$T(fa[v],v)$,因此算法也根剛才類似了。另外注意到這個算法的回答詢問複雜度會是$O(logn)$,因為會需要求兩點之間的距離。
code :
#include<bits/stdc++.h> #define LL long long #define MOD 1000000007 using namespace std; const int maxn=100000+10 ; struct P{int to ; LL dis;}; vector<P> v[maxn] ; map<int,int> mp[maxn] ; vector<LL> psum[maxn],sqsum[maxn] ; vector<int> sz[maxn] ; LL sqval[maxn] ; int anc[17][maxn],dep[maxn] ; int tin[maxn],tout[maxn],tick ; void dfs0(int x,int f,int d) { tin[x]=tick++ ; dep[x]=d ; anc[0][x]=f ; for(int i=1;i<17;i++) anc[i][x]=anc[i-1][anc[i-1][x]] ; for(auto i : v[x]) if(f!=i.to) dfs0(i.to,x,(i.dis+d)%MOD) ; tout[x]=tick++ ; } inline bool isfa(int x,int y) { return tin[x]<=tin[y] && tout[x]>=tout[y] ; } int LCA(int x,int y) { if(isfa(x,y)) return x ; if(isfa(y,x)) return y ; for(int i=16;i>=0;i--) if(!isfa(anc[i][x],y)) x=anc[i][x] ; return anc[0][x] ; } int getlen(int x,int y) { int lca=LCA(x,y) , ret=(dep[x]+dep[y]-2*dep[lca])%MOD ; if(ret<0) ret+=MOD ; return ret ; } inline LL cal_add_s(const P &i,int id) { return sqsum[i.to][id]+(i.dis*i.dis%MOD)*sz[i.to][id]+2*i.dis*psum[i.to][id] ; } inline LL cal_add_p(const P &i,int id) { return psum[i.to][id]+sz[i.to][id]*i.dis ; } int n,cnt[maxn] ; void dfs(int x,int id) { if(psum[x][id]!=-1) return ; sz[x][id]=1 ; LL &ans1=psum[x][id] ; LL &ans2=sqsum[x][id] ; ans1=0 ; if(v[x].size()==1) return ; for(auto i : v[x]) if(i.to!=v[x][id].to) { int id2=mp[i.to][x] ; dfs(i.to,id2) ; ans1=(ans1+cal_add_p(i,id2))%MOD ; ans2=(ans2+cal_add_s(i,id2))%MOD ; sz[x][id]+=sz[i.to][id2] ; } if(++cnt[x]==2) { int id2=mp[v[x][id].to][x] ; LL tot1=ans1+cal_add_p(v[x][id],id2) ; tot1%=MOD ; LL tot2=ans2+cal_add_s(v[x][id],id2) ; tot2%=MOD ; int totsz=sz[x][id]+sz[v[x][id].to][id2] ; for(int i=0;i<v[x].size();i++) if(psum[x][i]==-1) { int id3=mp[v[x][i].to][x] ; assert(psum[v[x][i].to][id3]!=-1) ; psum[x][i]=tot1-cal_add_p(v[x][i],id3)%MOD ; if(psum[x][i]<0) psum[x][i]+=MOD ; sqsum[x][i]=tot2-cal_add_s(v[x][i],id3)%MOD ; if(sqsum[x][i]<0) sqsum[x][i]+=MOD ; sz[x][i]=totsz-sz[v[x][i].to][id3] ; } } } LL getval(int x) { LL ret=0 ; for(auto i : v[x]) { int id=mp[i.to][x] ; dfs(i.to,id) ; ret=(ret+cal_add_s(i,id))%MOD ; } return ret ; } LL query(int u,int v) { LL ret=0LL ; if(!isfa(v,u)) { ret=-sqval[u]+MOD ; int id=mp[v][anc[0][v]] ; LL dis=getlen(u,v) , add=(dis*dis%MOD)*sz[v][id] ; add+=sqsum[v][id] ; add+=2*dis*psum[v][id] ; add%=MOD ; return (ret+2*add)%MOD ; } else if(v==1) return sqval[u] ; else { ret=sqval[u] ; int v2=anc[0][v] , id=mp[v2][v] ; LL dis=dep[u]-dep[v2] ; if(dis<0) dis+=MOD ; LL sub=(dis*dis%MOD)*sz[v2][id] ; sub+=sqsum[v2][id] ; sub+=2*dis*psum[v2][id] ; sub%=MOD ; return ((ret-2*sub)%MOD+MOD)%MOD ; } } main() { scanf("%d",&n) ; for(int i=1;i<n;i++) { int x,y ; LL dis ; scanf("%d%d%I64d",&x,&y,&dis) ; v[x].push_back((P){y,dis}) ; v[y].push_back((P){x,dis}) ; } for(int i=1;i<=n;i++) { psum[i].resize(v[i].size()) ; sqsum[i].resize(v[i].size()) ; sz[i].resize(v[i].size()) ; for(int j=0;j<v[i].size();j++) mp[i][v[i][j].to]=j , psum[i][j]=-1 ; } dfs0(1,1,0) ; for(int i=1;i<=n;i++) sqval[i]=getval(i) ; int Q ; scanf("%d",&Q) ; while(Q--) { int x,y ; scanf("%d%d",&x,&y) ; printf("%I64d\n",query(x,y)) ; } }
沒有留言:
張貼留言