(模板)Splay 平衡树

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作: 1.插入 xx 数 2.删除 xx 数(若有多个相同的数,因只删除一个) 3.查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 。若有多个相同的数,因输4.出最小的排名) 5.查询排名为 xx 的数 6.求 xx 的前驱(前驱定义为小于 xx ,且最大的数) 7.求 xx 的后继(后继定义为大于 xx ,且最小的数)

洛谷P3369 不讲解,直接上代码:

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int MAXN=1000000;
int ch[MAXN][2],f[MAXN],size[MAXN],cnt[MAXN],key[MAXN];
int nodecnt,root;
void clear(int x)//清除(当这个点被删除之后)
{
    ch[x][0]=ch[x][1]=f[x]=size[x]=cnt[x]=key[x]=0;
}
bool get(int x)//判断是父节点的左节点还是右节点
{
    return ch[f[x]][1]==x;
}
void update(int x)//更新当前点以下的元素个数(发生修改以后更新)
{
    if (x){
        size[x]=cnt[x];
        if (ch[x][0]) size[x]+=size[ch[x][0]];
        if (ch[x][1]) size[x]+=size[ch[x][1]];
    }
}
void rotate(int x)//旋转
{
    int old=f[x],oldf=f[old],whichx=get(x);
    ch[old][whichx]=ch[x][whichx^1]; f[ch[old][whichx]]=old;
    ch[x][whichx^1]=old; f[old]=x;
    f[x]=oldf;
    if (oldf)
        ch[oldf][ch[oldf][1]==old]=x;
    update(old); update(x);
}
void splay(int x)//splay,不停地rotate直到满足要求
{
    for (int fa;fa=f[x];rotate(x))
        if (f[fa])
            rotate((get(x)==get(fa))?fa:x);
    root=x;
}
void insert(int x)//插入
{
    if (root==0) {
        nodecnt++;
        ch[nodecnt][0] = ch[nodecnt][1] = f[nodecnt] = 0;
        root = nodecnt;
        size[nodecnt] = cnt[nodecnt] = 1;
        key[nodecnt] = x;
        return;
    }
    int now=root,fa=0;
    while(1){
        if (x==key[now]){
            cnt[now]++; update(now); update(fa); splay(now); break;
        }
        fa=now;
        now=ch[now][key[now]<x];
        if (now==0){
            nodecnt++;
            ch[nodecnt][0]=ch[nodecnt][1]=0;
            f[nodecnt]=fa;
            size[nodecnt]=cnt[nodecnt]=1;
            ch[fa][key[fa]<x]=nodecnt;
            key[nodecnt]=x;
            update(fa);
            splay(nodecnt);
            break;
        }
    }
}
int find(int x)//查询x的排名(排名定义为比当前数小的数的个数 +1+1 。若有多个相同的数,应输出最小的排名)
{
    int now=root,ans=0;
    while(1){
        if (x<key[now])
            now=ch[now][0];
        else{
            ans+=(ch[now][0]?size[ch[now][0]]:0);
            if (x==key[now]){
                splay(now); return ans+1;
            }
            ans+=cnt[now];
            now=ch[now][1];
        }
    }
}
int findx(int x)//寻找排名为x的数
{
    int now=root;
    while(1){
        if (ch[now][0]&&x<=size[ch[now][0]])
            now=ch[now][0];
        else{
            int temp=(ch[now][0]?size[ch[now][0]]:0)+cnt[now];
            if (x<=temp) return key[now];
            x-=temp; now=ch[now][1];
        }
    }
}
int pre()//求 xx 的前驱(前驱定义为小于 xx ,且最大的数),注意这个数不一定在树中
{
    int now=ch[root][0];
    while (ch[now][1]) now=ch[now][1];
    return now;
}
//注意,查找某个数的前驱后驱时,要先把这个数插进树中,然后把这个数转到根节点(insert中已集成),
//求x的前驱其实就是求x的左子树的最右边的一个结点,后继是求x的右子树的左边一个结点,求完后要删除这个点
int next1()//求 xx 的后继(后继定义为大于 xx ,且最小的数)
{
    int now=ch[root][1];
    while (ch[now][0]) now=ch[now][0];
    return now;
}
void del(int x)//删除值为x的节点
{
    int whatever=find(x);//主要作用就是把x旋转到根节点...不能省略
    if (cnt[root]>1) {
        cnt[root]--;
        update(root);
        return;
    }
    if (!ch[root][0]&&!ch[root][1]) {
        clear(root);
        root = 0;
        return;
    }
    if (!ch[root][0]) {
        int oldroot = root;
        root = ch[root][1];
        f[root] = 0;
        clear(oldroot);
        return;
    }
    else if (!ch[root][1]) {
        int oldroot = root;
        root = ch[root][0];
        f[root] = 0;
        clear(oldroot);
        return;
    }
    int leftbig=pre(),oldroot=root;
    splay(leftbig);
    ch[root][1]=ch[oldroot][1];
    f[ch[oldroot][1]]=root;
    clear(oldroot);
    update(root);
}
 int main()
 {
     int n,i,j,k;
     cin>>n;
     for(i=1;i<=n;i++){
         scanf("%d%d",&j,&k);
         if(j==1){
             insert(k);
         }
         else if(j==2){
             del(k);
         }
         else if(j==3){
             cout<<find(k)<<endl;
         }
         else if(j==4){
             cout<<findx(k)<<endl;
         }
         else if(j==5){
             insert(k);cout<<key[pre()]<<endl;del(k);
         }
         else if(j==6){
             insert(k);cout<<key[next1()]<<endl;del(k);
         }
     }

     return 0;
 }