线段树模板总结

线段树

线段树,也叫区间树,是一个完全二叉树,它在各个节点保存一条线段(即“子数组”),因而常用于解决数列维护问题,它基本能保证每个操作的复杂度为O(lgN)。

他大概长这样子:

《线段树模板总结》

上面的标号代表每个节点的标号,把一段区间利用这种方式可以实现O(lgN)的时间复杂度。

要建立线段树,先说一下我的线段树的风格(HH大牛的风格):

  • maxn是题目给的最大区间,而节点数要开4倍,确切的来说节点数要开大于maxn的最小2^x的两倍
  • lson和rson分辨表示结点的左儿子和右儿子,由于每次传参数的时候都固定是这几个变量,所以可以用预定于比较方便的表示
  • 以前的写法是另外开两个个数组记录每个结点所表示的区间,其实这个区间不必保存,一边算一边传下去就行,只需要写函数的时候多两个参数,结合lson和rson的预定义可以很方便
  • PushUP(int rt)是把当前结点的信息更新到父结点
  • PushDown(int rt)是把当前结点的信息更新给儿子结点
  • rt表示当前子树的根(root),也就是当前所在的结点

区间更新是指更新某个区间内的叶子节点的值,因为涉及到的叶子节点不止一个,而叶子节点会影响其相应的非叶父节点,那么回溯需要更新的非叶子节点也会有很多,如果一次性更新完,操作的时间复杂度肯定不是O(lgn),例如当我们要更新区间[0,3]内的叶子节点时,需要更新出了叶子节点3,9外的所有其他节点。为此引入了线段树中的延迟标记概念,这也是线段树的精华所在。

延迟标记:每个节点新增加一个标记,记录这个节点是否进行了某种修改(这种修改操作会影响其子节点),对于任意区间的修改,我们先按照区间查询的方式将其划分成线段树中的节点,然后修改这些节点的信息,并给这些节点标记上代表这种修改操作的标记。在修改和查询的时候,如果我们到了一个节点p,并且决定考虑其子节点,那么我们就要看节点p是否被标记,如果有,就要按照标记修改其子节点的信息,并且给子节点都标上相同的标记,同时消掉节点p的标记

实现各个不同功能的模板:

单点更新,区间求和(HDU1166 敌兵布阵)

#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 50020
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
int sum[N<<2];
void pushup(int rt)
{
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
void build(int l,int r,int rt)
{
    if(l==r)
    {
        scanf("%d",&sum[rt]);
        return;
    }
    int m=(l+r)>>1;
    build(lson);
    build(rson);
    pushup(rt);
}
void update(int p,int add,int l,int r,int rt)
{
    if(l==r)
    {
        sum[rt]+=add;
        return;
    }
    int m=(l+r)>>1;
    if(p<=m)
        update(p,add,lson);
    else
        update(p,add,rson);
    pushup(rt);
}
int query(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R)
        return sum[rt];
    int m=(l+r)>>1;
    int ans=0;
    if(L<=m)
        ans+=query(L,R,lson);
    if(R>m)
        ans+=query(L,R,rson);
    return ans;
}
int main()
{
    int t,n,a,b,q=1;
    char s[10];
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d",&n);
        build(1,n,1);
        printf("Case %d:\n",q++);
        while(scanf("%s",s)&&s[0]!='E')
        {
            scanf("%d%d",&a,&b);
            if(s[0]=='Q')
                printf("%d\n",query(a,b,1,n,1));
            if(s[0]=='A')
                update(a,b,1,n,1);
            if(s[0]=='S')
                update(a,-b,1,n,1);
        }
    }
    return 0;
}

单点更新,区间求最值(HDU1754 I Hate It)

无非是对上面的区间求和稍微做了一些改变,把每次求和的过程变成了每次取最大值

#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 200040
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
int MAX[N<<2];
void pushup(int rt)
{
    MAX[rt]=max(MAX[rt<<1],MAX[rt<<1|1]);
}
void build(int l,int r,int rt)
{
    if(l==r)
    {
        scanf("%d",&MAX[rt]);
        return;
    }
    int m=(l+r)>>1;
    build(lson);
    build(rson);
    pushup(rt);
}
void update(int p,int add,int l,int r,int rt)
{
    if(l==r)
    {
        MAX[rt]=add;
        return;
    }
    int m=(l+r)>>1;
    if(p<=m)
        update(p,add,lson);
    else
        update(p,add,rson);
    pushup(rt);
}
int query(int L,int R,int l,int r,int rt)
{
    if(L<=l&&r<=R)
        return MAX[rt];
    int m=(l+r)>>1;
    int ans=0;
    if(L<=m)
        ans=max(ans,query(L,R,lson));
    if(R>m)
        ans=max(ans,query(L,R,rson));
    return ans;
}
int main()
{
    int n,m,a,b;
    char s[5];
    while(~scanf("%d%d",&n,&m))
    {
        mem(MAX,0);
        build(1,n,1);
        while(m--)
        {
            scanf("%s%d%d",s,&a,&b);
            if(s[0]=='Q')
                printf("%d\n",query(a,b,1,n,1));
            if(s[0]=='U')
                update(a,b,1,n,1);
        }
    }
    return 0;
}

重点 区间更新

  • 区间求和,把一段区间加上某一个值(POJ3468A Simple Problem with Integers )
#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 100050
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
ll sum[N<<2],lazy[N<<2];
void pushup(ll rt)
{
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];
}
void pushdown(ll rt,ll m)
{
    if(lazy[rt])
    {
        lazy[rt<<1]+=lazy[rt];
        lazy[rt<<1|1]+=lazy[rt];
        sum[rt<<1]+=lazy[rt]*(m-(m>>1));
        sum[rt<<1|1]+=lazy[rt]*(m>>1);
        lazy[rt]=0;
    }
}
void build(ll l,ll r,ll rt)
{
    lazy[rt]=0;
    if(l==r)
    {
        scanf("%lld",&sum[rt]);
        return;
    }
    ll m=(l+r)>>1;
    build(lson);
    build(rson);
    pushup(rt);
}
void update(ll L,ll R,ll c,ll l,ll r,ll rt)
{
    if(L<=l&&r<=R)
    {
        lazy[rt]+=c;
        sum[rt]+=(ll)c*(r-l+1);
        return;
    }
    pushdown(rt,r-l+1);
    ll m=(l+r)>>1;
    if(L<=m) update(L,R,c,lson);
    if(m<R) update(L,R,c,rson);
    pushup(rt);
}
ll query(ll L,ll R,ll l,ll r,ll rt)
{
    if(L<=l&&r<=R)
        return sum[rt];
    pushdown(rt,r-l+1);
    ll m=(l+r)>>1;
    ll ans=0;
    if(L<=m)
        ans+=query(L,R,lson);
    if(R>m)
        ans+=query(L,R,rson);
    return ans;
}
int main()
{
    ll n,m,a,b,c;
    char s[5];
    scanf("%lld%lld",&n,&m);
    build(1,n,1);
    while(m--)
    {
        scanf("%s",s);
        if(s[0]=='Q')
        {
            scanf("%lld%lld",&a,&b);
            printf("%lld\n",query(a,b,1,n,1));
        }
        if(s[0]=='C')
        {
            scanf("%lld%lld%lld",&a,&b,&c);
            update(a,b,c,1,n,1);
        }
    }
    return 0;
}
  • HDU1698 Just a Hook 把一段区间的值变成某一个值
#include <cstdio>  
#include <cstring>  
#include <cctype>  
#include <string>  
#include <set>  
#include <iostream>  
#include <stack>  
#include <cmath>  
#include <queue>  
#include <vector>  
#include <algorithm>  
#define mem(a,b) memset(a,b,sizeof(a))  
#define inf 0x3f3f3f3f  
#define mod 10000007  
#define debug() puts("what the fuck!!!")  
#define N 111111  
#define M 1000000  
#define ll long long  
using namespace std;  
#define lson l,m,rt<<1  
#define rson m+1,r,rt<<1|1  
int lazy[4*N];//用来标记,为0表示没有被标记,以要更新的值来做标记  
int sum[4*N];//sum[i]代表以i为根节点的和  
void pushup(int rt)//向上更新和  
{  
    sum[rt]=sum[rt<<1]+sum[rt<<1|1];  
  
}  
void pushdown(int rt,int m)//对某一个区间进行改变,如果被标记了,在查询的时候就得把改变传给子节点,因为查询的并不一定是当前区间  
{  
    //m为区间长度  
    if(lazy[rt])  
    {  
        lazy[rt<<1]=lazy[rt<<1|1]=lazy[rt];//传递给子节点  
        //更新左儿子和右儿子的和  
        sum[rt<<1]=(m-(m>>1))*lazy[rt];  
        sum[rt<<1|1]=(m>>1)*lazy[rt];  
        lazy[rt]=0;//取消对当前节点的标记  
    }  
}  
void build(int l,int r,int rt)  
{  
    lazy[rt]=0;//初始化左右节点都没有被标记  
    sum[rt]=1;//初始值都为1  
    if(l==r) return;  
    int m=(l+r)>>1;  
    build(lson);  
    build(rson);  
    pushup(rt);  
}  
void update(int L,int R,int c,int l,int r,int rt)  
{  
    if(L<=l&&r<=R)  
    {  
        lazy[rt]=c;  
        sum[rt]=c*(r-l+1);//更新代表某个区间的节点和  
        //printf("sum[%d]=%d,L=%d,R=%d,c=%d,lazy[rt]=%d\n",rt,sum[rt],L,R,c,lazy[rt]);  
        return;  
    }  
    pushdown(rt,r-l+1);//向下传递  
    int m=(l+r)>>1;  
    if(L<=m)  update(L,R,c,lson);  
    if(R>m)   update(L,R,c,rson);  
    pushup(rt);//向上传递更新和  
}  
int n,m,t,a,b,c,q=1;  
int main()  
{  
    scanf("%d",&t);  
    while(t--)  
    {  
        scanf("%d%d",&n,&m);  
        build(1,n,1);//建立线段树  
        while(m--)  
        {  
            scanf("%d%d%d",&a,&b,&c);  
            update(a,b,c,1,n,1);//更新区间  
        }  
        printf("Case %d: The total value of the hook is %d.\n",q++,sum[1]);  
    }  
    return 0;  
}

线段树的区间染色问题 (ZOJ1610 Count the Colors)

#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 10050x
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
int sum[N<<2],num[N],tot;
void pushdown(int rt)
{
    if(sum[rt]!=-1)
    {
        sum[rt<<1]=sum[rt<<1|1]=sum[rt];
        sum[rt]=-1;
    }
}
void update(int L,int R,int c,int l,int r,int rt)
{
    if(L<=l&&r<=R)
    {
        sum[rt]=c;
        return;
    }
    pushdown(rt);
    int m=(l+r)>>1;
    if(L<=m) update(L,R,c,lson);
    if(m<R) update(L,R,c,rson);
}
void query(int l,int r,int rt)
{
    if(l==r)
    {
        if(sum[rt]>=0&&sum[rt]!=tot)
            num[sum[rt]]++;
        tot=sum[rt];//如果区间连续就不记录
        return;
    }
    pushdown(rt);
    int m=(l+r)>>1;
    query(lson);
    query(rson);
}
int main()
{
    int n,a,b,c;
    while(~scanf("%d",&n))
    {
        mem(num,0);
        mem(sum,-1);
        tot=-1;
        for(int i=1; i<=n; i++)
        {
            scanf("%d%d%d",&a,&b,&c);
            update(a+1,b,c,1,8000,1);
        }
        query(1,8000,1);
        for(int i=0; i<=8000; i++)
            if(num[i])
                printf("%d %d\n",i,num[i]);
        puts("");
    }
    return 0;
}

扫描线,离散化,矩形面积并POJ1151 Atlantis

#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 220
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
struct Seg
{
    double l,r,h;
    int f;
    Seg() {}
    Seg(double a,double b,double c,int d):l(a),r(b),h(c),f(d) {}
    bool operator < (const Seg &cmp) const
    {
        return h<cmp.h;
    }
} e[N];
struct node
{
    int cnt;
    double len;
} t[N<<2];
double X[N];
void pushdown(int l,int r,int rt)
{
    if(t[rt].cnt)
        t[rt].len=X[r+1]-X[l];
    else if(l==r)
        t[rt].len=0;
    else
        t[rt].len=t[rt<<1].len+t[rt<<1|1].len;
}
void update(int L,int R,int l,int r,int rt,int val)
{
    if(L<=l&&r<=R)
    {
        t[rt].cnt+=val;
        pushdown(l,r,rt);
        return;
    }
    int m=(l+r)>>1;
    if(L<=m) update(L,R,lson,val);
    if(R>m) update(L,R,rson,val);
    pushdown(l,r,rt);
}
int main()
{
    int n,q=1;
    double a,b,c,d;
    while(~scanf("%d",&n)&&n)
    {
        mem(t,0);
        int num=0;
        for(int i=0; i<n; i++)
        {
            scanf("%lf%lf%lf%lf",&a,&b,&c,&d);
            X[num]=a;
            e[num++]=Seg(a,c,b,1);
            X[num]=c;
            e[num++]=Seg(a,c,d,-1);
        }
        sort(X,X+num);
        sort(e,e+num);
        int m=unique(X,X+num)-X;
        double ans=0;
        for(int i=0; i<num; i++)
        {
            int l=lower_bound(X,X+m,e[i].l)-X;
            int r=lower_bound(X,X+m,e[i].r)-X-1;
            update(l,r,0,m,1,e[i].f);
            ans+=t[1].len*(e[i+1].h-e[i].h);
        }
        printf("Test case #%d\nTotal explored area: %.2lf\n\n",q++,ans);
    }
    return 0;
}

扫描线,离散化,矩形面积交( HDU1255 覆盖的面积)

#include <cstdio>
#include <cstring>
#include <cctype>
#include <string>
#include <set>
#include <iostream>
#include <stack>
#include <cmath>
#include <queue>
#include <vector>
#include <algorithm>
#define mem(a,b) memset(a,b,sizeof(a))
#define inf 0x3f3f3f3f
#define N 2200
#define ll long long
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
struct Seg
{
    double l,r,h;
    int f;
    Seg() {}
    Seg(double a,double b,double c,int d):l(a),r(b),h(c),f(d) {}
    bool operator < (const Seg &cmp) const
    {
        return h<cmp.h;
    }
} e[N];
struct node
{
    int cnt;
    double len,s;
} t[N<<2];
double X[N];
void pushdown(int l,int r,int rt)
{
    if(t[rt].cnt)
        t[rt].len=X[r+1]-X[l];
    else if(l==r)
        t[rt].len=0;
    else
        t[rt].len=t[rt<<1].len+t[rt<<1|1].len;
    if(t[rt].cnt>1)
        t[rt].s=X[r+1]-X[l];
    else if(l==r)
        t[rt].s=0;
    else if(t[rt].cnt==1)
        t[rt].s=t[rt<<1].len+t[rt<<1|1].len;
    else
        t[rt].s=t[rt<<1].s+t[rt<<1|1].s;
}
void update(int L,int R,int l,int r,int rt,int val)
{
    if(L<=l&&r<=R)
    {
        t[rt].cnt+=val;
        pushdown(l,r,rt);
        return;
    }
    int m=(l+r)>>1;
    if(L<=m) update(L,R,lson,val);
    if(R>m) update(L,R,rson,val);
    pushdown(l,r,rt);
}
int main()
{
    int n,q;
    double a,b,c,d;
    scanf("%d",&q);
    while(q--)
    {
        scanf("%d",&n);
        mem(t,0);
        int num=0;
        for(int i=0; i<n; i++)
        {
            scanf("%lf%lf%lf%lf",&a,&b,&c,&d);
            X[num]=a;
            e[num++]=Seg(a,c,b,1);
            X[num]=c;
            e[num++]=Seg(a,c,d,-1);
        }
        sort(X,X+num);
        sort(e,e+num);
        int m=unique(X,X+num)-X;
        double ans=0;
        for(int i=0; i<num; i++)
        {
            int l=lower_bound(X,X+m,e[i].l)-X;
            int r=lower_bound(X,X+m,e[i].r)-X-1;
            update(l,r,0,m,1,e[i].f);
            ans+=t[1].s*(e[i+1].h-e[i].h);
        }
        printf("%.2lf\n",ans);
    }
    return 0;
}

 

 

 

 

 

 

 

点赞

发表评论

电子邮件地址不会被公开。 必填项已用*标注