题目
题目链接
解题思路
给定 \(n\) 个整数和一个整数 \(m\),要求计算所有可能的两两异或结果中大于 \(m\) 的数量。可以使用字典树(Trie)来高效地存储和查询异或结果。
- 字典树构建:将每个数字的二进制表示插入字典树中。
- 查询异或结果:对于每个数字,查找已有数字中与其异或大于 \(m\) 的数量。
- 统计结果:在查找过程中,统计符合条件的异或对的数量。
关键点
- 使用字典树存储二进制位,便于快速查找。
- 通过位运算判断异或结果是否大于 \(m\)。
- 采用贪心策略,优先选择能使异或结果更大的路径。
代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;const int N = 1e5 + 1;
int idx = 0;
int son[31 * N][2]; // 字典树节点
int num[31 * N]; // 节点计数int find(int x, int m) {int p = 0;int ans = 0;for (int i = 25; i >= 0; i--) {int bit = (x >> i) & 1;int mbit = (m >> i) & 1;if (mbit == 0) {// m的当前位是0,可以取更大的值if (son[p][1 - bit]) {ans += num[son[p][1 - bit]];}if (son[p][bit]) p = son[p][bit];else break;} else {// m的当前位是1,必须取更大的值if (son[p][1 - bit]) p = son[p][1 - bit];else break;}}return ans;
}void insert(int x) {int p = 0;for (int i = 25; i >= 0; i--) {int u = (x >> i & 1);if (!son[p][u]) {son[p][u] = ++idx;}p = son[p][u];num[p]++;}
}void solve() {int n, m;cin >> n >> m;i64 ans = 0;for (int i = 0; i < n; i++) {int x;cin >> x;ans += find(x, m); // 先查找insert(x); // 再插入}cout << ans;
}int main() {ios::sync_with_stdio(false);cin.tie(nullptr);solve();return 0;
}
import java.util.*;public class Main {static final int N = 100001;static int idx = 0;static int[][] son = new int[31 * N][2];static int[] num = new int[31 * N];static int find(int x, int m) {int p = 0;int ans = 0;for (int i = 25; i >= 0; i--) {int bit = (x >> i) & 1;int mbit = (m >> i) & 1;if (mbit == 0) {// m的当前位是0,可以取更大的值if (son[p][1 - bit] != 0) {ans += num[son[p][1 - bit]];}if (son[p][bit] != 0) p = son[p][bit];else break;} else {// m的当前位是1,必须取更大的值if (son[p][1 - bit] != 0) p = son[p][1 - bit];else break;}}return ans;}static void insert(int x) {int p = 0;for (int i = 25; i >= 0; i--) {int u = (x >> i & 1);if (son[p][u] == 0) {son[p][u] = ++idx;}p = son[p][u];num[p]++;}}public static void main(String[] args) {Scanner sc = new Scanner(System.in);int n = sc.nextInt();int m = sc.nextInt();long ans = 0;for (int i = 0; i < n; i++) {int x = sc.nextInt();ans += find(x, m);insert(x);}System.out.println(ans);sc.close();}
}
class TrieNode:def __init__(self):self.children = [None, None] # 0和1的子节点self.count = 0 # 该节点的计数class Trie:def __init__(self):self.root = TrieNode()self.idx = 0 # 节点索引def insert(self, x):p = self.rootfor i in range(25, -1, -1): # 从高位到低位bit = (x >> i) & 1if p.children[bit] is None:p.children[bit] = TrieNode()p = p.children[bit]p.count += 1 # 更新节点计数def find(self, x, m):p = self.rootans = 0for i in range(25, -1, -1): # 从高位到低位bit = (x >> i) & 1mbit = (m >> i) & 1if mbit == 0:# m的当前位是0,可以取更大的值if p.children[1 - bit] is not None:ans += p.children[1 - bit].countif p.children[bit] is not None:p = p.children[bit]else:breakelse:# m的当前位是1,必须取更大的值if p.children[1 - bit] is not None:p = p.children[1 - bit]else:breakreturn ansdef main():import sysinput = sys.stdin.readdata = input().split()n = int(data[0])m = int(data[1])trie = Trie()ans = 0for i in range(n):x = int(data[i + 2]) # 从第三个元素开始读取数字ans += trie.find(x, m) # 先查找trie.insert(x) # 再插入print(ans)if __name__ == "__main__":main()
算法及复杂度
- 算法:字典树(Trie)+ 位运算
- 时间复杂度:\(\mathcal{O}(n * 26)\),其中 \(n\) 是数组长度
- 空间复杂度:\(\mathcal{O}(31 * n)\),用于存储字典树节点