首先考慮一定會對的DP方法:設$dp[S][i]$代表當目前猜的位置的二進制表示為$S$,且所選的答案為第$i$個字串時,期望還要再猜幾次。計算這個DP值的方法並不難,如果$S$裡的這些位置已經可以唯一確定第$i$個字串的話,他的DP值就是0,否則$\displaystyle dp[S][i]=1+\frac{1}{k}\sum_{j} dp[S| (2^j)][i]$,其中$j$滿足他是還沒被猜過的位置,$k$是還沒被猜過的位置個數。最後答案就是$\displaystyle \frac{1}{n}\sum_{i=0}^{n-1} dp[0][i]$。但這樣時間和空間都會爆掉,所以需要改進。
事實上我們可以把這些DP值的第2維壓起來,也就是令$dp2[S]=dp[S][0]+...+dp[S][n-1]$,而因為當$S$能唯一辨別第$i$個字串時$dp[S][i]=0$,也就是我們可以把剛才的轉移式改寫成:若$S$能唯一辨別第$i$個字串,則$\displaystyle dp[S][i]=\frac{1}{k}\sum_{j} dp[S| (2^j)][i]$(其實他就是$0=0+...+0$)。那麼由這條式子就可以推出$dp2$的轉移式了,他會是$\displaystyle dp2[S]=num[S]+\frac{1}{k}\sum_{j} dp[S| (2^j)]$,其中$num[S]$代表$S$辨別不出來的字串的個數。
所以問題轉換為要如何求出所有的$num[S]$,考慮任兩個字串$s_i,s_j$,設$s_i$和$s_j$共同的部份的集合為$S0$,那麼所有$S0$的子集合都沒辦法辨識出$s_i$和$s_j$。因此如果對於每個$i$,找出所有$j$對應的$S0$(當然$j\neq i$),那麼所有這些集合的子集合的聯集就是所有沒辦法辨識出$s_i$的集合。可以用DFS來處理這個問題,設$in[S]$代表$S$這個集合已經在剛才那些集合的聯集裡了,那麼所有$S$的子集合也會在裡面,因此當我們多加入一個集合,準備把他的所有子集合的$in$值設成$1$的時候,就直接DFS下去,並且遇到$in$值已經是$1$的數就直接return 就好了,這樣複雜度會是$O(n2^m)$,其中$m$是字串的長度。
但這樣傳上去TLE了,後來看解才知道,其實可以令$d[S]$代表「用$S$沒辦法辨識出的字串集合」,那麼當我們找到$s_i$和$s_j$的$S0$時,讓$d[S0]|=(2^i)$根$(2^j)$,這樣就得到了初步的$d$值們,但我們還要求,如果$S2$是$S1$的子集合,且$S1$不能辨識$i$,那麼$S2$也不行,而這只要按照數字大到小,把$d[S]$的值拿去 $or$ 上所有$d[S']$的值就可以了,其中$S'$是包含$S$且他們只差一個bit的集合。最後看$d[S]$有幾位是$1$就知道他不能辨別多少字串了。
code :
#include<bits/stdc++.h> #define DB double #define LL long long using namespace std; const int maxn=20,maxm=50 ; int n,m,num[1<<maxn] ; char s[maxm][maxn] ; LL bit[1<<maxn] ; void getnum() { for(int i=0;i<(1<<n);i++) num[i]=m ; for(int i=0;i<m;i++) for(int j=i+1;j<m;j++) { int same=0 ; for(int k=0;k<n;k++) if(s[i][k]==s[j][k]) same|=(1<<k) ; bit[same]|=(1LL<<i) ; bit[same]|=(1LL<<j) ; } for(int i=((1<<n)-1);i>=0;i--) { for(int j=0;j<n;j++) if(!(i&(1<<j))) bit[i]|=bit[i^(1<<j)] ; num[i]=__builtin_popcountll(bit[i]) ; } } DB dp[1<<maxn] ; DB DP(int S) { if(dp[S]>=0) return dp[S] ; DB &ans=dp[S] ; ans=0 ; int k=0 ; for(int i=0;i<n;i++) if(!(S&(1<<i))) k++ , ans+=DP(S^(1<<i)) ; ans=ans/k+num[S] ; return ans ; } main() { scanf("%d",&m) ; for(int i=0;i<m;i++) { scanf("%s",s[i]) ; if(!i) n=strlen(s[i]) ; } getnum() ; fill(dp,dp+(1<<n),-1) ; dp[(1<<n)-1]=0 ; printf("%.15f\n",DP(0)/m) ; }
沒有留言:
張貼留言