别样的线段树
D. Points
原题链接:https://codeforces.com/problemset/problem/19/D
开始思路:
看到题目后有一个想法,先将所有坐标进行离散化,在横坐标方向上建立线段树,每个节点维护一个 \(set\) 即对应区间 \(l\) ~ \(r\) 上 \(y\) 轴上的坐标,然后每次增删都可以在 \(O(log^2(n))\) 内完成,然后查询时,对区间进行直接二分,然后每次将对应区间的集合合并后取出,每次有效性检验检查是否存在大于当前查询的 \(y\),直接二分时间复杂度为 \(O(log(n))\),每次取出时间复杂度最坏情况下可 \(O(n)\),每次 \(upper\) _ \(bound\) 查询为 \(O(log(n))\),三者相乘,不出意外直接 \(tle\)了
\(tle\)代码:
#include<bits/stdc++.h>
using namespace std; typedef long long ll;
typedef pair<int,int> PII;const int N=2e5+10,mod=998244353;int n,m;vector<int> all,query;
vector<PII> points;struct Node{int l,r;set<int> ys;
}tr[4*N];void build(int u,int l,int r){if(l==r) tr[u]={l,r};else{tr[u]={l,r};int mid=l+r>>1;build(u<<1,l,mid),build(u<<1|1,mid+1,r);}
}void add(int u,int x,int y){if(tr[u].l==x&&tr[u].r==x) tr[u].ys.insert(y);else{int mid=tr[u].l+tr[u].r>>1;if(x<=mid) add(u<<1,x,y);else add(u<<1|1,x,y);tr[u].ys.insert(y);}
}void rme(int u,int x,int y){if(tr[u].l==x&&tr[u].r==x) tr[u].ys.erase(y);else{int mid=tr[u].l+tr[u].r>>1;if(x<=mid) rme(u<<1,x,y);else rme(u<<1|1,x,y);tr[u].ys.erase(y);}
}set<int> get(int u,int l,int r){if(tr[u].l>=l&&tr[u].r<=r) return tr[u].ys;else{int mid=tr[u].l+tr[u].r>>1;if(r<=mid) return get(u<<1,l,r);else if(l>mid) return get(u<<1|1,l,r);else{set<int> left,right;left=get(u<<1,l,r),right=get(u<<1|1,l,r);if(left.size()<=right.size()) swap(left,right);for(auto y:right) left.insert(y);return left;}}
}inline bool check(int L,int R,int mi){set<int> temp=get(1,L,R);return temp.upper_bound(mi)!=temp.end();
}PII find(int x,int y){int l=x+1,r=all.size();while(l<r){int mid=l+r>>1;if(check(l,mid,y))r=mid;else l=mid+1; }if(l>r||!check(l,l,y)) return {-1,-1};set<int> temp=get(1,l,l);int tx=l,ty=*temp.upper_bound(y);return {all[tx-1],all[ty-1]};
}void solve(){cin>>n;query=vector<int>(n);points=vector<PII>(n);string s;int x,y;for(int i=0;i<n;i++){cin>>s>>x>>y;query[i]=(s=="add")?1:(s=="remove")?2:3; points[i]={x,y};all.push_back(x),all.push_back(y);}sort(all.begin(),all.end());all.erase(unique(all.begin(),all.end()),all.end());build(1,1,all.size());for(int i=0;i<n;i++){auto [x,y]=points[i];x=lower_bound(all.begin(),all.end(),x)-all.begin()+1;y=lower_bound(all.begin(),all.end(),y)-all.begin()+1;if(query[i]==1) add(1,x,y);else if(query[i]==2) rme(1,x,y);else{PII t=find(x,y);if(t.first!=-1) cout<<t.first<<' '<<t.second<<'\n';else cout<<-1<<endl;}}
}int main() {cin.tie(0)->sync_with_stdio(false);cout.tie(0);int t=1;// cin>>t;while(t--)solve();
}
正确做法:线段树上二分
这题确实应该用 \(set\),但应该是直接存储对应 \(x\) 坐标上的 \(y\)坐标,而线段树应该维护对应区间下 \(y\) 方向上坐标的最大值。每次插入,先将对应 \(y\) 坐标插入到对应 \(x\) 坐标下的 \(set\) 中,然后再去看对应坐标下存储的最大值是否改变,然后再去对线段树进行修改;对于删除也是上面的思路。而查询则要用线段树二分,固定查询区间 \(l\)~\(r\)找到是否存在第一个存储大于 \(y\) 的横坐标,每次只去找与查询区间有交集的位置,并且在左边符合条件的情况下,优先查询左边。
代码:
#include<bits/stdc++.h>
using namespace std; typedef long long ll;
typedef pair<int,int> PII;const int N=4e5+10,mod=998244353;int n,m;vector<int> all,query;
vector<PII> points;struct Node{int l,r;int lmy;
}tr[4*N];set<int> colx[N];void build(int u,int l,int r){if(l==r) tr[u]={l,r,-1};else{tr[u]={l,r,-1};int mid=l+r>>1;build(u<<1,l,mid),build(u<<1|1,mid+1,r);}
}void add(int u,int x,int y){if(tr[u].l==x&&tr[u].r==x) tr[u].lmy=y;else{int mid=tr[u].l+tr[u].r>>1;if(x<=mid) add(u<<1,x,y);else add(u<<1|1,x,y);tr[u].lmy=max(tr[u<<1].lmy,tr[u<<1|1].lmy);}
}int find(int u,int l,int low){ // 线段树上二分查询if(l>all.size()) return -1;if(tr[u].l==tr[u].r){if(tr[u].lmy>low) return tr[u].l;return -1;}int mid=tr[u].l+tr[u].r>>1,res=-1;if(l<=mid&&tr[u<<1].lmy>low) res=find(u<<1,l,low);if(res!=-1) return res;if(tr[u<<1|1].lmy>low) return find(u<<1|1,l,low);return -1;
} void solve(){cin>>n;query=vector<int>(n);points=vector<PII>(n);string s;int x,y;for(int i=0;i<n;i++){cin>>s>>x>>y;query[i]=(s=="add")?1:(s=="remove")?2:3; points[i]={x,y};all.push_back(x),all.push_back(y);}sort(all.begin(),all.end());all.erase(unique(all.begin(),all.end()),all.end());build(1,1,all.size());unordered_map<int,int> mp;for(int i=0;i<all.size();i++) mp[all[i]]=i+1;for(int i=0;i<n;i++){auto [x,y]=points[i];x=mp[x], y=mp[y];if(query[i]==1){if(colx[x].empty()||*(--colx[x].end())<y) add(1,x,y);colx[x].insert(y);}if(query[i]==2){colx[x].erase(y);if(colx[x].empty()) add(1,x,-1);else add(1,x,*(--colx[x].end()));}if(query[i]==3){int t=find(1,x+1,y);if(t==-1) cout<<-1<<endl;else{int tx=t, ty=*colx[tx].upper_bound(y);cout<<all[tx-1]<<' '<<all[ty-1]<<endl;}}}
}int main() {cin.tie(0)->sync_with_stdio(false);cout.tie(0);int t=1;// cin>>t;while(t--)solve();
}