2015年4月29日 星期三

[HOJ 368] 矩型計數

這題我的作法幾乎跟官方解一樣,可以先看看官方解的第38頁到第60頁,雖然是日文的,不過單看圖的話其實就很好理解了。

作法:

考慮分治法,把所有點切成左右兩半,對於矩形的兩端點都落在同一區塊內的矩形只要遞迴下去處理就可以了,所以這裡只需要考慮左端點在左半部,右端點在右半部的矩形們。首先觀察到,假設 ( x1 , y1 ) 和 ( x2 , y2 ) 是分別是一個合法矩形的左下角和右上角,且一個在左半部一個在右半部,並且假設中央線(也就是把點集切成左右兩半部的那條鉛直線)的 x 座標為 x0 ,那麼這就代表左下角為 ( x1 , y1 ) ,右上角為 ( x0 , y2 ) 的矩形內沒有其他點,因此我們可以反過來想,假設 Y 是最小的數,滿足以 ( x1 , y1 ) , ( x0 , Y )  為左下、右上角的矩形中有其他的點,那麼在計算以 ( x1 , y1 ) 為左下角的合法矩形個數時,就只要考慮右半部中 y 座標介於 y1 ~ Y 的點就可以了。並且不難發現 Y 的值的計算方法,只要找「在左半部且在 ( x1 , y1 ) 右上方的點之中 y 座標的最小值」就可以了,而這就可以簡單的按照 x 座標由大到小作,用 set 的 lower_bound 就可以找到每個在左半部的點的 Y 值了。

至於 ( x2 , y2 ),類似的推導過程可以得到這次是反過來找「在右半部且在 ( x2 , y2 ) 左下方的點之中 y 座標的最大值」,假設他為 Y2 好了,那麼這次得到的區間就會是 Y2 ~ y2 。而這也可以用 set 輕鬆解決。

所以現在我們在左右兩邊都得到了一些線段,並且由這些線段的定義可以知道,如果左邊有一條線段 [ a , b ] ,右邊有一條線段 [ c , d ] ,那麼 a < c < b < d 若且唯若產生這兩條線段的點形成一個合法矩形的左下角和右上角,因此問題轉化為如何求滿足這個條件的線段組的數量。解決這個問題的想法是,對於每一條左邊的線段,都去計算這條線段和多少條右邊的線段可以滿足條件。假設這條線段為 [ a , b ] ,那麼考慮右邊線段中所有上端點 > b 的線段們,如果把他們的下端點的 y 座標都標記起來,那麼所求就會等於 [ a , b ] 之間的被標記起來的 y 座標個數,由這件事就可以得到這個算法:先離散化所有的線段端點,把左右的線段分別按照上端點高度排序,按照上端點由高到低來處理左邊的線段,每次遇到一個線段的時候,假設他為 [ a , b ] ,那麼就把右邊所有上端點 > b 的線段,把他的下端點的值 + 1 ,然後查詢 [ a , b ] 之間的和,而這就用個BIT來做就可以了。

code :



#include<bits/stdc++.h>
#define LL long long
#define INF 1000000001
#define lowbit(x) (x&-x)
using namespace std;
const int maxn=200000+10 ;
 
int bit[maxn] ;
void add(int x,int val,int ma)
{
    while(x<=ma) bit[x]+=val , x+=lowbit(x) ;
}
int query(int l,int r)
{
    int ret=0 ; l-- ;
    while(r) ret+=bit[r] , r-=lowbit(r) ;
    while(l) ret-=bit[l] , l-=lowbit(l) ;
    return ret ;
}
 
struct P
{
    int x,y ;
    bool operator < (const P &rhs) const
    {
        return x==rhs.x ? y<rhs.y : x<rhs.x ;
    }
}a[maxn];
 
LL ans=0LL ;
P segl[maxn],segr[maxn] ;
int y[2*maxn] ;
set<int> st ;
void solve(int l,int r)
{
    if(l==r) return ;
    int mid=(l+r)/2 ;
    solve(l,mid) ;
    solve(mid+1,r) ;
 
    st.clear() ; st.insert(INF) ;
    int cntl=0,cnt=0 ;
    y[cnt++]=INF ;
    for(int i=mid;i>=l;i--)
    {
        auto it=st.upper_bound(a[i].y) ;
        segl[cntl++]=(P){*it,a[i].y} ;
        st.insert(a[i].y) ;
        y[cnt++]=a[i].y ;
    }
 
    st.clear() ; st.insert(-INF) ;
    y[cnt++]=-INF ;
    int cntr=0 ;
    for(int i=mid+1;i<=r;i++)
    {
        auto it=st.upper_bound(a[i].y) ; it-- ;
        segr[cntr++]=(P){a[i].y,*it} ;
        st.insert(a[i].y) ;
        y[cnt++]=a[i].y ;
    }
 
    sort(y,y+cnt) ;
    cnt=unique(y,y+cnt)-y ;
    sort(segl,segl+cntl) ;
    sort(segr,segr+cntr) ;
    for(int i=0;i<cntl;i++)
        segl[i].x=lower_bound(y,y+cnt,segl[i].x)-y+1 ,
        segl[i].y=lower_bound(y,y+cnt,segl[i].y)-y+1 ;
    for(int i=0;i<cntr;i++)
        segr[i].x=lower_bound(y,y+cnt,segr[i].x)-y+1 ,
        segr[i].y=lower_bound(y,y+cnt,segr[i].y)-y+1 ;
 
    fill(bit,bit+cnt+1,0) ;
    for(int i=cntr-1,j=cntl-1;i>=0;i--)
    {
        while(j>=0 && segl[j].x >= segr[i].x)
            add(segl[j].y,1,cnt) , j-- ;
        ans+=query(segr[i].y,segr[i].x) ;
    }
}
 
main()
{
    int n ; scanf("%d",&n) ;
    for(int i=1;i<=n;i++) scanf("%d%d",&a[i].x,&a[i].y) ;
    sort(a+1,a+n+1) ;
    solve(1,n) ;
    printf("%lld\n",ans) ;
}
 

沒有留言:

張貼留言