Last dance。
最后一篇文章,就写我两年前就看过但不敢尝试的题目吧。
首先,两数异或 \(\le x\) 的条件看起来是好维护的,显然可以 Trie 树上跑一跑,但我们发现当 \(x\) 某一位是 \(1\) 的时候非常难受,情况变得非常复杂。此时我进行了一些尝试,尝试直接刻画合法的 \(S\) 的结构,未果。
把思路调回 Trie 树,继续分类讨论。一个比较暴力的想法是直接设 \(g(p,q)\) 表示在 \(p,q\) 子树内选,并且限制在 \(p,q\) 之间(不考虑 \(p,q\) 自己内部的限制)的方案数。
仔细地分类讨论可以发现,如果 \(x\) 的这一位是 \(1\),\(p\) 内选了 \(0\) 的和 \(q\) 内选了 \(0\) 的不会产生限制,\(p\) 内选了 \(1\) 的不会和 \(q\) 内选了 \(1\) 的产生限制。那么方案数就是 \(g(\mathrm{trie}[p][0], \mathrm{trie}[q][1])\times g(\mathrm{trie}[p][1], \mathrm{trie}[q][0])\)。具体实现需要去除空集。时间复杂度的话,考虑一个点只会以一种路径被遍历到,所以复杂度 \(\mathcal O(n\log v)\)。
#include <bits/stdc++.h>
#define pb emplace_back
#define fir first
#define sec secondusing i64 = long long;
using pii = std::pair<int, int>;constexpr int maxn = 3e5 + 5, mod = 998244353;void add(int& x, int y) { if ((x += y) >= mod) x -= mod; return; }
void sub(int& x, int y) { if ((x -= y) < 0) x += mod; return; }
int inc(int x, int y) { return (x + y) >= mod ? (x + y - mod) : (x + y); }
int dec(int x, int y) { return (x < y) ? (x - y + mod) : (x - y); }
int n, lmt, trie[maxn * 30][2], siz[maxn * 30], sz, pw[maxn];void insert(int x) {int u = 1;for (int i = 29; ~i; --i) {int c = x >> i & 1;if (!trie[u][c]) trie[u][c] = ++sz;u = trie[u][c];++siz[u];}return;
}int calc(int x, int y, int bit) {if (!x || !y) return pw[siz[x] + siz[y]];if (x == y) {if (bit == -1) return pw[siz[x]];int c = lmt >> bit & 1;if (c) {return calc(trie[x][0], trie[x][1], bit - 1);} else {return dec(inc(calc(trie[x][0], trie[x][0], bit - 1), calc(trie[x][1], trie[x][1], bit - 1)), 1);}} else {if (bit == -1) return pw[siz[x] + siz[y]];int c = lmt >> bit & 1;if (c) {return 1ll * calc(trie[x][0], trie[y][1], bit - 1) * calc(trie[x][1], trie[y][0], bit - 1) % mod;} else {int rem = dec(inc(calc(trie[x][0], trie[y][0], bit - 1), calc(trie[x][1], trie[y][1], bit - 1)), 1);add(rem, 1ll * (pw[siz[trie[x][0]]] - 1) * (pw[siz[trie[x][1]]] - 1) % mod);add(rem, 1ll * (pw[siz[trie[y][0]]] - 1) * (pw[siz[trie[y][1]]] - 1) % mod);return rem;}}
}int main() {std::cin.tie(nullptr)->sync_with_stdio(false);std::cin >> n >> lmt, sz = 1;for (int i = 1; i <= n; ++i) {int y;std::cin >> y;insert(y);}for (int i = pw[0] = 1; i <= n; ++i) {pw[i] = 2 * pw[i - 1] % mod;}std::cout << (calc(1, 1, 29) - 1 + mod) % mod << '\n';return 0;
}
那么,在此结束吧。