首先把問題轉換為一維問題,記 a_i 代表落在第 i 行的東西是落在第幾列的,那麼我們要找的東西其實就是滿足以下條件的 ( i , j ) :如果 a_i ~ a_j 的最小值為 m ,最大值為 M ,那麼 m ~ M 之間的所有數字都出現在 a_i ~ a_j 中,而這其實也等價於 M - m = j - i ,這件事是個很好的充要條件,之後都會用到他。利用分治法,設現在要算所有被 [ L , R ] 包含的符合條件的區間有幾個,令 mid = ( L + R ) / 2 ,那麼對於 [ L mid ] 和 [ mid+1 , R ] ,遞迴下去處理就好了,所以現在只要處理左界 <= mid 且右界 > mid 的區間就好了。而如果 [ l , r ] 是一個橫跨 mid , mid+1 的區間,那麼這個區間內的最大和最小值就可以用 [ l , mid ] 和 [ mid+1 , r ] 組合出來,因此我們先算出以下這些陣列: lmin , lmax , rmin , rmax ,其中 lmin[ x ] = min { a[ x ] , ... , a[ mid ] } ,其餘類似。並且把目標的區間分成四種:( 以下簡稱 [ L , mid ] 為左邊,[ mid+1 , R ] 為右邊 )
1. [ l , r ] 中的最大值和最小值均落在左邊
2. [ l , r ] 中的最大值和最小值均落在右邊
3. [ l , r ] 中的最大值落在左邊,最小值落在右邊
4. [ l , r ] 中的最大值落在右邊,最小值落在左邊
我們只要會處理 1. 和 3. 就可以了,因為 2. 和 4. 只是他們左右反過來而已。首先處理 1. ,枚舉目標區間的左端點,那麼由 lmin 和 lmax 就可以知道目標區間的最大和最小值,那麼由前面講過的性質可以得到目標區間的長度了,也就是右端點也確定了,所以就可以再根據 rmin 和 rmax 來確認這個區間是不是合法的了。再來是第3種情況,一樣考慮枚舉左端點,設當前左端點為 x ,那麼這時候右端點 r 就必須滿足 rmin[ r ] < lmin[ x ] ,還有 rmax[ r ] < lmax[ x ] ,並且注意到 rmin 是個遞減的陣列,rmax 則是遞增的,因此第一個限制條件告訴我們 r 必須要夠大,他的 rmin 才可以夠小,而第二個限制條件告訴我們 r 必須要夠小,讓他的 rmax 不至於超過 lmax[ x ] ,因此所有可行的右端點們會形成一個區間。再來我們要知道這個區間裡有幾個可行的右端點。記 lmax[ x ] = M ,還有對於任意一個在可行區間內的 r ,記 rmin[ r ] = m ,那麼也就是現在有很多個候選的 ( r , m ) ,而我們需要的是 M - m = r - x 的 ( r , m ) ,也就是 r + m = M + x 的 ( r , m ) ,因此可以用一個 map 維護所有在可行區間中的 r 的 r + m 值,就可以直接查詢有幾個數會等於 M + x 了。
code :
#include<bits/stdc++.h> #define LL long long using namespace std; const int maxn=300000+10 ; int a[maxn] ; int lmi[maxn],lma[maxn],rmi[maxn],rma[maxn] ; LL ans=0LL ; map<int,int> mp ; void solve(int l,int r) { if(l==r) { ans++ ; return ; } int mid=(l+r)/2 ; solve(l,mid) ; solve(mid+1,r) ; for(int i=mid;i>=l;i--) lmi[i]=(i==mid ? a[i] : min(a[i],lmi[i+1])) , lma[i]=(i==mid ? a[i] : max(a[i],lma[i+1])) ; for(int i=mid+1;i<=r;i++) rmi[i]=(i==mid+1 ? a[i] : min(a[i],rmi[i-1])) , rma[i]=(i==mid+1 ? a[i] : max(a[i],rma[i-1])) ; for(int i=l;i<=mid;i++) { int ri=i+lma[i]-lmi[i] ; if(ri>mid && ri<=r && rmi[ri]>lmi[i] && rma[ri]<lma[i]) ans++ ; } for(int i=mid+1;i<=r;i++) { int le=i-rma[i]+rmi[i] ; if(le<=mid && le>=l && lmi[le]>rmi[i] && lma[le]<rma[i]) ans++ ; } mp.clear() ; mp[mid+1+rmi[mid+1]]++ ; for(int i=mid,L=mid+1,R=mid+1;i>=l;i--) { while(R+1<=r && rma[R+1]<lma[i]) { R++ ; if(R>=L) mp[R+rmi[R]]++ ; } while(L<=r && rmi[L]>lmi[i]) { L++ ; if(L-1<=R) mp[L-1+rmi[L-1]]-- ; } ans+=mp[lma[i]+i] ; } mp.clear() ; mp[mid+1-rma[mid+1]]++ ; for(int i=mid,L=mid+1,R=mid+1;i>=l;i--) { while(R+1<=r && rmi[R+1]>lmi[i]) { R++ ; if(R>=L) mp[R-rma[R]]++ ; } while(L<=r && rma[L]<lma[i]) { L++ ; if(L-1<=R) mp[L-1-rma[L-1]]-- ; } ans+=mp[i-lmi[i]] ; } } main() { int n ; scanf("%d",&n) ; for(int i=1;i<=n;i++) { int x,y ; scanf("%d%d",&x,&y) ; a[x]=y ; } solve(1,n) ; printf("%lld\n",ans) ; }
沒有留言:
張貼留言