指针是个好东西，不过解引用之前一定得判空（本代码用哨兵节点 nil 代替空指针来统一判空）。
还有别忘了传引用：凡是需要改动指针本身的函数（如旋转、伸展、分裂、合并），参数都应写成 node* &，否则修改只作用于副本。
// Splay tree over int keys with multiplicities, driven by stdin queries:
//   1 x: insert x         2 x: erase one copy of x   3 x: rank of x
//   4 k: k-th smallest    5 x: largest key < x       6 x: smallest key > x
// A shared sentinel `nil` (size 0, num 0) stands in for the null pointer.
#include <cstdio>
#include <algorithm>
#include <utility>

int inf = 0x7fffffff;  // INT_MAX; splaying toward it brings a tree's maximum to the root

struct node* nil;  // shared sentinel; contributes 0 to every size

struct node {
    int num;      // multiplicity of val
    int val;      // key
    int size;     // total multiplicity stored in this subtree
    node* ch[2];  // ch[0]: smaller keys, ch[1]: larger keys

    node(int v) : val(v) {
        size = 1;
        num = 1;
        ch[0] = ch[1] = nil;
    }

    // Recompute size from this node's multiplicity and its children's sizes.
    void sum() {
        size = num;
        if (ch[0] != nil) size += ch[0]->size;
        if (ch[1] != nil) size += ch[1]->size;
    }

    // Direction toward key v: -1 if this node holds v, 0 = go left, 1 = go right.
    int cmp(int v) {
        if (val == v) return -1;
        return (val > v ? 0 : 1);
    }

    // Direction toward the k-th smallest (1-based) element of this subtree:
    // -1 if it is this node, 0 = left subtree, 1 = right subtree.
    int cmpkth(int k) {
        int s = (ch[0] == nil ? 0 : ch[0]->size);
        if (k > s && k <= s + num) return -1;
        if (k <= s) return 0;
        return 1;
    }
};

node* root;

// Debug helper: print the keys of subtree x in order (each distinct key once).
void visit(node* x) {
    if (x == nil) return;
    visit(x->ch[0]);
    printf("%d ", x->val);
    visit(x->ch[1]);
}

// Single rotation: lifts x->ch[base^1] into x's place, x becomes its ch[base].
// Sizes of both nodes are refreshed; x (by reference) now points at the new root.
void rotato(node* &x, int base) {
    node* k = x->ch[base ^ 1];
    x->ch[base ^ 1] = k->ch[base];
    k->ch[base] = x;
    x->sum();
    k->sum();
    x = k;
}

// Splay the node holding v -- or, if v is absent, the last node on v's search
// path -- to the root of subtree x, two levels per step (zig-zig / zig-zag) and a
// single rotation (zig) when only one level remains.
void splay(node* &x, int v) {
    int d = x->cmp(v);
    if (d == -1 || x->ch[d] == nil) return;  // x is already the target
    int d2 = x->ch[d]->cmp(v);
    if (d2 != -1 && x->ch[d]->ch[d2] != nil) {
        splay(x->ch[d]->ch[d2], v);
        if (d == d2) {
            rotato(x, d2 ^ 1);  // zig-zig: rotate the grandparent first
            rotato(x, d ^ 1);
        } else {
            rotato(x->ch[d], d2 ^ 1);  // zig-zag: rotate the parent first
            rotato(x, d ^ 1);
        }
    } else {
        rotato(x, d ^ 1);  // zig
    }
}

// Splay the k-th smallest (1-based) element of subtree x to its root.
// An out-of-range k no longer recurses into nil; it just stops early.
void splaykth(node* &x, int k) {
    int d = x->cmpkth(k);
    if (d == -1) return;
    if (d == 1) k = k - x->ch[0]->size - x->num;  // rank inside the right subtree (nil->size is 0)
    if (x->ch[d] == nil) return;                  // k out of range: nothing to descend into
    int d2 = x->ch[d]->cmpkth(k);
    if (d2 != -1 && x->ch[d]->ch[d2] != nil) {
        int k2 = (d2 == 1 ? k - x->ch[d]->ch[0]->size - x->ch[d]->num : k);
        splaykth(x->ch[d]->ch[d2], k2);
        if (d == d2) {
            rotato(x, d2 ^ 1);
            rotato(x, d ^ 1);
        } else {
            rotato(x->ch[d], d2 ^ 1);
            rotato(x, d ^ 1);
        }
    } else {
        rotato(x, d ^ 1);
    }
}

// Raise ans to the largest key in subtree x that is strictly smaller than val.
// ans is left unchanged if no key in x qualifies or exceeds it.
void pre(node* x, int val, int &ans) {
    if (x == nil) return;
    if (x->val < val) {
        if (x->val > ans) ans = x->val;
        pre(x->ch[1], val, ans);  // larger candidates can only be on the right
    } else {
        pre(x->ch[0], val, ans);
    }
}

// Lower ans to the smallest key in subtree x that is strictly greater than val.
// ans is left unchanged if no key in x qualifies or is below it.
void nxt(node* x, int val, int &ans) {
    if (x == nil) return;
    if (x->val > val) {
        if (x->val < ans) ans = x->val;
        nxt(x->ch[0], val, ans);  // smaller candidates can only be on the left
    } else {
        nxt(x->ch[1], val, ans);
    }
}

// Rank of val: 1 + the number of stored keys (with multiplicity) smaller than val.
// Works whether or not val is present.
int find(node* &x, int val) {
    splay(x, val);
    int rank = x->ch[0]->size + 1;
    // val absent and the root is its predecessor: the root's copies are also < val.
    if (x != nil && x->val < val) rank += x->num;
    return rank;
}

// Value of the k-th smallest element (1-based, counting multiplicity).
// Assumes 1 <= k <= x->size; otherwise the returned value is meaningless.
int kth(node* &x, int k) {
    splaykth(x, k);
    return x->val;
}

// Split x by val: afterwards x holds every key <= val and the returned tree holds
// every key > val. If val is present, its node is the root of x.
node* spilt(node* &x, int val) {
    if (x == nil) return nil;
    splay(x, val);
    node* t1;
    node* t2;
    if (x->val <= val) {
        t1 = x;
        t2 = x->ch[1];
        t1->ch[1] = nil;
    } else {
        t2 = x;
        t1 = x->ch[0];
        t2->ch[0] = nil;
    }
    x->sum();
    x = t1;
    return t2;
}

// Join t2 onto t1, requiring every key in t1 to be smaller than every key in t2.
// The result is left in t1 and t2 is reset to nil.
void merge(node* &t1, node* &t2) {
    if (t1 == nil) std::swap(t1, t2);
    splay(t1, inf);  // t1's maximum becomes the root, so its right child is free
    t1->ch[1] = t2;
    t2 = nil;
    t1->sum();
}

// Add one copy of val to tree x.
void insert(node* &x, int val) {
    node* t2 = spilt(x, val);
    if (x != nil && x->val == val) {
        x->num += 1;  // already present: bump its multiplicity
        x->sum();
    } else {
        node* nw = new node(val);
        merge(x, nw);
    }
    merge(x, t2);
}

// Remove one copy of val from tree x; a no-op if val is not present.
void erase(node* &x, int val) {
    node* t2 = spilt(x, val);
    // Only the root of the <= val half can hold val; without this check an absent
    // val would decrement some other node (or the shared nil sentinel).
    if (x != nil && x->val == val) {
        x->num -= 1;
        if (x->num == 0) {
            node* t3 = x;
            x = x->ch[0];  // the root has no right child after spilt
            delete t3;
        }
    }
    merge(x, t2);  // merge() also refreshes the root's size
}

// Read one (optionally negative) decimal integer from stdin; returns 0 at EOF.
int read() {
    int s = 0, f = 1;
    int in = getchar();  // int, not char, so EOF is distinguishable
    while (in < '0' || in > '9') {
        if (in == EOF) return 0;  // avoid spinning forever on exhausted input
        if (in == '-') f = -1;
        in = getchar();
    }
    while (in >= '0' && in <= '9') {
        s = s * 10 + (in - '0');
        in = getchar();
    }
    return s * f;
}

int main() {
    int n = read();
    int a, b;
    int ans;
    nil = new node(0);
    // The constructor ran while nil was still null, so patch the sentinel's links.
    root = nil->ch[0] = nil->ch[1] = nil;
    nil->size = nil->num = 0;
    for (int i = 1; i <= n; i++) {
        a = read();
        b = read();
        switch (a) {
        case 1:
            insert(root, b);
            break;
        case 2:
            erase(root, b);
            break;
        case 3:
            printf("%d\n", find(root, b));
            break;
        case 4:
            printf("%d\n", kth(root, b));
            break;
        case 5:
            // Temporarily insert b (splaying the tree around it), walk for the
            // predecessor, then remove the temporary copy. -0x7fffffff is printed
            // if nothing is smaller.
            insert(root, b);
            ans = -0x7fffffff;
            pre(root, b, ans);
            printf("%d\n", ans);
            erase(root, b);
            break;
        case 6:
            // Same pattern for the successor; 0x7fffffff if nothing is larger.
            insert(root, b);
            ans = 0x7fffffff;
            nxt(root, b, ans);
            printf("%d\n", ans);
            erase(root, b);
            break;
        default:
            break;
        }
    }
    return 0;
}