這題我是照官方解寫的,而有很多人是直接想辦法加速位元運算來過的,但我沒有仔細看他們的寫法,詳細可以去看看別人寫的 code 。總之我寫了FFT,首先把原始的第一個字串拆成好幾塊,每一塊都做一次FFT求得那一塊和第二個字串反過來的多項式乘積。對每個詢問都把他拆成遇處理好的幾塊根左右的一些小渣渣,那兩小塊就暴力做就可以了,但不能一個一個比,一個簡單的加速方法就是把這兩個字串的$64$個 bit 壓成一個數,然後直接對他位元運算,然後用$ \_\_builtin\_popcountll() $這個函式來算出一個 long long 裡有幾個 bit 是$1$。這樣就能讓暴力那部份的常數少$64$倍。具體來說,只要先對兩個字串的每個位置都遇處理一個數,代表「從這個數開始往後數64個(若到底了就停),這些 bit 壓起來的值是多少」就可以了。
對了,這次的FFT我換成了模大質數和原根版本的FFT,以一個質數的原根來取代$1$的$n$次單位根,並且所有數字和運算都是在模那個大質數意義下進行的,詳細可以參考這篇。(這種寫法好像叫作NTT)
code :
#include<bits/stdc++.h> #define LL long long #define ULL unsigned long long using namespace std; const int maxn=1<<18 ; const int MOD=2013265921,MAX=MOD/2 ; /// MOD=15*2^27+1, which is a prime const int gen=440564289 ; /// 31 is primitive root of MOD , gen = (31^15)%MOD LL power(LL x,int n) { if(n<=1) return n ? x : 1LL ; LL t=power(x,n/2) ; if(n&1) return (t*t%MOD)*x%MOD ; else return t*t%MOD ; } LL tmp[maxn] ; void DFT(LL *a,LL x,int n) { if(n==1) { a[0]%=MOD ; if(a[0]<0) a[0]+=MOD ; return ; } for(int i=0;i<n;i++) tmp[i]=a[i] ; for(int i=0;i<n;i++) a[i%2 ? n/2+i/2 : i/2]=tmp[i] ; LL *a1=a , *a2=a+n/2 , xx=x*x%MOD ; DFT(a1,xx,n/2) ; DFT(a2,xx,n/2) ; LL now=1 ; for(int i=0;i<n/2;i++,now=now*x%MOD) { LL val=now*a2[i]%MOD ; tmp[i]=a1[i]+val-MOD ; tmp[i+n/2]=a1[i]-val ; } for(int i=0;i<n;i++) a[i]=(tmp[i]<0?tmp[i]+MOD:tmp[i]) ; } void mul(LL *a,LL *b,LL *c,int n) { LL x=power(gen,(1<<27)/n) , xinv=power(x,n-1) ; LL ninv=power(n,MOD-2) ; DFT(a,x,n) ; DFT(b,x,n) ; for(int i=0;i<n;i++) c[i]=a[i]*b[i]%MOD ; DFT(c,xinv,n) ; for(int i=0;i<n;i++) { c[i]=c[i]*ninv%MOD ; if(c[i]>MAX) c[i]-=MOD ; } } int n,m,Q,L ; char s[maxn],t[maxn] ; ULL sbit[maxn],tbit[maxn] ; int match(int x,int y,int len) { int ret=0 ; for(;len>=64;len-=64,x+=64,y+=64) ret+=__builtin_popcountll(sbit[x]^tbit[y]) ; if(len) ret+=__builtin_popcountll( (sbit[x]<<(64-len))^(tbit[y]<<(64-len))) ; return ret ; } LL *poly[maxn],a[maxn],b[maxn] ; void precal() { n=strlen(s) , m=strlen(t) ; for(int i=0;i<n;i++) for(int j=0;j<64&&i+j<n;j++) if(s[i+j]=='1') sbit[i]|=(1ULL<<j) ; for(int i=0;i<m;i++) for(int j=0;j<64&&i+j<m;j++) if(t[i+j]=='1') tbit[i]|=(1ULL<<j) ; L=n*sqrt(150*log2(n)/Q) ; int n2=1 ; while(n2<L+m) n2*=2 ; for(int i=0;i<n;i+=L) { poly[i/L]=new LL[n2] ; memset(a,0,sizeof(a)) ; memset(b,0,sizeof(b)) ; for(int j=0;j<L && i+j<n;j++) a[j]=(s[i+j]=='0' ? -1 : 1) ; for(int j=0;j<m;j++) b[m-1-j]=(t[j]=='0' ? -1 : 1) ; mul(a,b,poly[i/L],n2) ; } } main() { scanf("%s%s%d",s,t,&Q) ; precal() ; while(Q--) { int x,y,l ; scanf("%d%d%d",&x,&y,&l) ; if(l<=L){printf("%d\n",match(x,y,l)) ; continue ;} int ans=0 ; if(x%L) { int l2=L-x%L ; ans+=match(x,y,l2) ; x+=l2 ; y+=l2 ; l-=l2 ; } if((x+l)%L) { int l2=(x+l-1)%L+1 ; ans+=match(x+l-l2,y+l-l2,l2) ; l-=l2 ; } if(!l){printf("%d\n",ans) ; continue ;} for(int i=x,j=y;l;i+=L,j+=L,l-=L) ans+=(L-poly[i/L][m-1-j])/2 ; printf("%d\n",ans) ; } }
沒有留言:
張貼留言