測資修正之後AC了!總共獲得了24個AC XD
作法:
因為其實是「把邊砍掉」之後才會有「哪些點被感染哪些點沒被感染」,所以可以反過來作砍掉 m 條邊之後最多有幾個點沒被感染。把沒被感染的叫白點,剩下的叫黑點,記 dp[ b ][ x ][ m ] 代表在以 x 為根的子樹中,當 x 的顏色是 b ,且砍掉 m 個邊的時候,白點的最大值。所以這樣可以寫出很大一坨的轉移式:
dp[ b ][ x ][ m ] = ( b==white ) + MAX{ sigma dp[ bi ][ vi ][ mi ] } ,其中 bi 和 mi 們滿足
sigma( mi + ( bi != b ) ) = m 。 vi 代表的是 x 的子結點們,而其實這句話的意思就是將可以砍的 m 條邊分配給 x 的子結點,並且當 x 的子結點和 x 不同色的時候必須要再多加一條邊分隔他們。
所以這個的意思就是我們必須從好幾個 DP 陣列 ( 也就是 dp[ bi ][ vi ] 們 ) 合併成一個新的陣列 dp[ b ][ x ] ,但當然不可能枚舉所有的 mi ,所以在這裡我們還需要另一個DP,設 dp2[ i ][ m ]代表目前已經考慮了 i 顆子樹了,那麼砍掉 m 條邊至多可以有幾個白點。每次多考慮一個子樹就會把那個子樹的DP陣列和現在的dp2 陣列跑過一次。
這樣看起來每個節點都會花掉 O( n^2 ) 的時間轉移,但其實只要注意到對於一個子樹 x ,他的邊數只有 x 的 size -1 ,也就是 dp[ b ][ x ][ m ] 的 m 其實只要作 0 ~ x 的 size -1 就夠了。所以設 x 的子樹的 size 分別為 n1 , n2 , ... , nt ,當在作 dp[ i+1 ] 的時候其實要跑過的陣列只有 n(i+1) * ( n1+n2+...+ni ) ,把他們 sum 起來會得到總合併時間 = sigma ni * nj ,其中 i , j 跑遍 1~t 且 i < j 。所以如果用數歸,假設 x 的子樹們都能在 O( size ^2 ) 的時間內完成,那麼完成 x 的時間就是 (sigma ni^2) + (sigma ni * nj) <= ( sigma ni ) ^2 < size[x]^2 ,得證。
code :
#include<bits/stdc++.h> #define INF 10000000 using namespace std; const int maxn=1000+50 ; int sz[maxn],dp[2][maxn][maxn] ; int dp2[maxn][maxn] ; vector<int> v[maxn] ; bool G[maxn] ; void dfs(int x,int fa) { sz[x]=1 ; if(x && (int)v[x].size()==1) { if(!G[x]) dp[0][x][0]=1 , dp[1][x][0]=0 ; else dp[0][x][0]=-INF , dp[1][x][0]=0 ; return ; } for(auto i : v[x]) if(i!=fa) dfs(i,x) , sz[x]+=sz[i] ; for(int b=0;b<2;b++) { if(b==0 && G[x]) { for(int i=0;i<sz[x];i++) dp[b][x][i]=-INF ; continue ; } for(int i=0;i<sz[x];i++) dp2[0][i]=-INF ; dp2[0][0]=0 ; int tot=0,cnt=0 ; for(int i=0;i<v[x].size();i++) if(v[x][i]!=fa) { int y=v[x][i] ; tot+=sz[y] ; for(int j=0;j<sz[x];j++) dp2[cnt+1][j]=-INF ; for(int j=(b!=0);j<=tot;j++) for(int k=0;k<sz[y] && j-k-(b!=0)>=0;k++) dp2[cnt+1][j]=max(dp2[cnt+1][j], dp2[cnt][j-k-(b!=0)]+dp[0][y][k]) ; for(int j=(b!=1);j<=tot;j++) for(int k=0;k<sz[y] && j-k-(b!=1)>=0;k++) dp2[cnt+1][j]=max(dp2[cnt+1][j], dp2[cnt][j-k-(b!=1)]+dp[1][y][k]) ; cnt++ ; } for(int i=0;i<sz[x];i++) dp[b][x][i]= (b==0) + dp2[cnt][i] ; } } main() { int n,k,num ; scanf("%d%d%d",&n,&k,&num) ; if(k+num>n) { printf("ACM rules!\n") ; return 0 ; } while(num--) { int x ; scanf("%d",&x) ; G[x]=1 ; } for(int i=1;i<n;i++) { int x,y ; scanf("%d%d",&x,&y) ; v[x].push_back(y) ; v[y].push_back(x) ; } dfs(0,-1) ; int ans1=n-1,ans2=n-1 ; for(int i=0;i<n;i++) if(dp[0][0][i]>=k) { ans1=i ; break ; } for(int i=0;i<n;i++) if(dp[1][0][i]>=k) { ans2=i ; break ; } printf("%d\n",min(ans1,ans2)) ; }