基本上和這篇一模一樣。TIOJ 上的那題沒有強制在線,所以可以用莫隊作。而且 TIOJ 還有保證數字是介於 1 ~ n 之間的兩兩相異的數,這裡則沒有,所以在一些步驟上會麻煩許多。以下先解釋 TIOJ 那個版本的作法。
一樣是把整個序列分成好幾塊,每塊的長度都是 x ( 最後一塊除外 ),以下記第 i 塊為 A_i,還有原始的序列為 a[ i ] 。我們先看要如何求出一個詢問的答案,再來決定應該要預處理好甚麼東西。
假設現在在詢問的左界和右界分別落在第 i 塊和第 j 塊中,首先我們求出第 i 塊到第 j 塊這一整條的逆序數對數有多少,然後再扣掉不該算的。把要算的東西分成落在同一塊內和落在不同塊內,如果是落在同一塊內的,那麼就可以預處理好每一塊裡面的逆序數對數有多少,但從第 i 塊一直加到第 j 塊太慢了,所以要預處理的應該是每塊裡面的逆序數對數的前綴和。至於落在不同塊內的,如果先算好對於任意兩塊,有多少逆序數對是由他們兩個產生的,那麼在算答案時需要的就是某種這個陣列的二維前綴和,所以預處理好的東西應該也是某種二維前綴和,這樣才可以 O( 1 ) 得到他的個數。
再來是要扣掉不該算的,圖中的兩塊紅色區塊代表不在詢問範圍內的數字。首先對每個落在左邊紅色區塊的數,扣掉「在他右邊且和他同塊且值小於他」的個數,對每個落在右邊紅色區塊的數則是扣掉「在他左邊且和他同塊且值大於他」的個數。再來對於每個在左邊紅色區塊的數,還要扣掉「在 i + 1 ~ j 塊中和他形成逆序數對的數的個數」,在右邊紅色區塊的則是扣掉「在 i ~ j - 1 塊中和他形成逆序數對的數的個數」。而一塊一塊加也太慢了,所以預處理好的應該是前綴和陣列。
但這樣會多扣掉一些東西,也就是一個數落在左邊紅色區塊,一個數落在右邊紅色區塊的逆序數對數,必須把它們加回來。而我們會用 merge sort 算一塊裡的逆序數對數,所以可以順便得到每一塊排序之後的結果,這樣就可以用 O( 區間長度 ) 的時間把兩邊都 sort 好了,詳細作法應該是參考 code 會比較好理解。最後再對它雙指標即可。
綜合以上,我們會需要預處理以下這些值:
1. x[ i ] : sigma ( j = 1 ~ i ) ( 第 j 塊內的逆序數對數 )
2. y[ i ][ j ] : sigma ( p = 1 ~ i ) sigma ( q = p+1 ~ j ) INV( p , q ) ,其中 INV( p , q ) 代表A_p 和 A_q 之間形成的逆序數對數。
3. z[ i ][ j ] : sigma ( p = 1 ~ j ) ( 第 i 個數對第 p 塊產生的逆序數對數 )
( 如果 i 在第 p 塊的話那那一項就是 0 ,但其實不會影響結果,因為之後查詢是查詢前綴和相減。 )
4. u[ i ] : 在 i 左邊且和 i 同塊且值小於 a[ i ] 的個數
5. v[ i ] : 在 i 右邊且和 i 同塊且值大於 a[ i ] 的個數
x 陣列的算法就是直接 merge sort ,並且分開記錄原始的 a 陣列和對每塊排序過後的新陣列,之後會用到。 y 的算法則是對兩塊排序好的陣列雙指標就好了,最後再把前綴和處理起來。 z 的算法比較神奇,作法是先枚舉每一塊,然後對於每個值 p ,去看看這一塊裡有幾個數大於 p 和小於 p ,有了這個資訊之後去看看值為 p 的位置,假設 a[ q ] = p 好了,就可以在「 q 對這一塊產生的逆序數對數」 的值加上這一塊所貢獻的值,最後再處理前綴和就好。至於 u[ i ] 和 v[ i ] 就是簡單的掃過去而已。
到這裡就成功解決了 TIOJ 版的問題 ( 建議先讀懂上面那個網站的 code 再繼續往下看 ),回到 TOJ 的,首先要先把數字離散化,並且在算 z 陣列的值的時候,會需要一個數字出現在哪些位置,所以需要用 n 個 vector 記錄每個數字分別出現在哪些位置。再來則是最後在 O( L ) sort 的部分,因為這時候一個數字可能有很多個位置,所以如果按照原本的方法只標記數值為 1 或是 2 的話會分不清楚,只好改成標記位置。我們需要知道到底應該要標記哪個位置的數,所以只好在 merge sort 的地方多維護一個 id 值,這樣才能知道在 sort 完的陣列裡面的每個數原本的 index 是多少,那麼取這個陣列的反函數就可以得到應該要標記哪個位置了。
code :
#include<bits/stdc++.h> using namespace std; const int maxn=30000+10,maxm=450 ; int n,K,num ; int a[maxn],s[maxn] ; vector<int> pos[maxn],vec ; struct P{int id,val;}a2[maxn],tmp[maxn] ; int per[maxn] ; void cal() { sort(vec.begin(),vec.end()) ; vec.resize(unique(vec.begin(),vec.end())-vec.begin()) ; for(int i=1;i<=n;i++) a[i]= upper_bound(vec.begin(),vec.end(),a[i])-vec.begin() , pos[a[i]].push_back(i) , a2[i]=(P){i,a[i]} ; } int inv_cnt ; void merge(int l,int r) { if(l==r) return ; int mid=(l+r)/2 ; merge(l,mid) ; merge(mid+1,r) ; for(int i=l,j=mid+1,cnt=l ; i<=mid || j<=r ; ) { if(j==r+1 || (i!=mid+1 && a2[i].val<=a2[j].val)) tmp[cnt++]=a2[i++] ; else tmp[cnt++]=a2[j++] , inv_cnt+=mid+1-i ; } for(int i=l;i<=r;i++) a2[i]=tmp[i] ; } inline void get(int t,int &st,int &ed,int &id) { id= (t-1)/K+1 ; st= K*(id-1)+1 ; ed= min(K*id,n) ; } int x[maxn],y[maxm][maxm],z[maxn][maxm] ; int u[maxn],v[maxn] ; int type[maxn] ; int tmpl[maxn],tmpr[maxn] ; main() { scanf("%d",&n) ; for(int i=1;i<=n;i++) scanf("%d",&a[i]) , vec.push_back(a[i]) ; cal() ; int Q ; scanf("%d",&Q) ; K= (int)(n/sqrt(Q+0.5)) ; num= (n%K==0 ? n/K : n/K+1) ; for(int i=1;i<=num;i++) { int st=K*(i-1)+1 , ed=min(K*i,n) ; inv_cnt=0 ; merge(st,ed) ; for(int j=st;j<=ed;j++) s[j]=a2[j].val , per[a2[j].id]=j ; x[i]=x[i-1]+inv_cnt ; } for(int i=1;i<=num;i++) for(int j=i+1;j<=num;j++) { int cnt=0 , ed=min(j*K,n) ; for(int i2=(i-1)*K+1 , j2=(j-1)*K ; i2<=i*K ; i2++) { while(j2<ed && s[j2+1]<s[i2]) j2++ ; cnt+= j2-(j-1)*K ; } y[i][j]=y[i][j-1]+cnt ; } for(int i=1;i<=num;i++) for(int j=i+1;j<=num;j++) y[i][j]+=y[i-1][j] ; for(int i=1;i<=num;i++) { int now1=K*(i-1) , now2=now1+1 , ed=min(n,K*i) ; for(int j=1;j<=vec.size();j++) { while(now1<ed && s[now1+1]<j) now1++ ; while(now2<=ed && s[now2]<=j) now2++ ; for(auto k : pos[j]) { if(k<=ed && k> K*(i-1)) continue ; if(k>ed) z[k][i]+=(ed-now2+1) ; else z[k][i]+=(now1-K*(i-1)) ; } } } for(int i=1;i<=n;i++) for(int j=1;j<=num;j++) z[i][j]+=z[i][j-1] ; for(int i=1;i<=num;i++) { int st=K*(i-1)+1 , ed=min(n,K*i) ; for(int j=st;j<=ed;j++) { for(int k=st;k<j;k++) if(a[k]>a[j]) u[j]++ ; for(int k=j+1;k<=ed;k++) if(a[k]<a[j]) v[j]++ ; } } int L,R , ans ; for(int q0=1;q0<=Q;q0++) { if(q0==1) scanf("%d%d",&L,&R) ; else { ans %= n ; L=(ans+2217+q0)%n+1 ; R=(ans*2217+q0)%n+1 ; if(L>R) swap(L,R) ; } if(L==R) { printf("%d\n",ans=0) ; continue ; } int stl,edl,str,edr,idl,idr ; get(L,stl,edl,idl) ; get(R,str,edr,idr) ; ans=0 ; for(int i=stl;i<=edl;i++) type[i]=0 ; for(int i=str;i<=edr;i++) type[i]=0 ; ans+= x[idr]-x[idl-1] ; ans+= y[idr-1][idr]-y[idl-1][idr] ; for(int i=stl;i<L;i++) ans-=(z[i][idr]-z[i][idl]) , ans-=v[i] , type[per[i]]=1 ; for(int i=R+1;i<=edr;i++) ans-=(z[i][idr-1]-z[i][idl-1]) , ans-=u[i] , type[per[i]]=2 ; int nl=0 , nr=0 ; for(int i=stl;i<=edl;i++) if(type[i]==1) tmpl[nl++]=s[i] ; for(int i=str;i<=edr;i++) if(type[i]==2) tmpr[nr++]=s[i] ; for(int i=0 , j=-1;i<nl;i++) { while(j+1<nr && tmpr[j+1]<tmpl[i]) j++ ; ans+=j+1 ; } printf("%d\n",ans) ; } }
沒有留言:
張貼留言