题目链接:https://www.luogu.com.cn/problem/P5169
暴力解法就是:线性基,然后对于每次询问的 \(x\),枚举 \((u,v)\) 对,跑一遍线性基就是否存在连接 \(u - v\) 的异或和为 \(x\) 的路径。时间复杂度 \(O(q \cdot n^2 \cdot 18)\)。很明显会超时。
按时很明显这样会超时。
以下是两种 AC 的解法。
- 解法1:线性基 + FWT + 并查集 (这是我一开始的解法)
- 解法2:线性基 + 2次FWT(更帅!)
解法1:线性基 + FWT + 并查集
- 路径性质转换:利用异或和的性质,任意两点间路径的异或和可表示为“树上唯一路径异或和”再异或上若干个“环的异或和”。
- 提取基底:通过 DFS 找出所有环,用线性基 a[] 维护这些环的异或贡献。
- 统计初步点对:利用 FWT 计算卷积 c[] * c[],求出仅通过树边时,异或和为 \(x\) 的点对数量(即满足 \(dis[u] \oplus dis[v] = x\) 的 \((u, v)\) 数量)。
- 并查集合并:重点在于这步。由于线性基中的环可以任意组合,如果点对 \((u, v)\) 的基础异或和为 \(x\),那么它可以通过异或环得到 \(x \oplus \text{linear\_basis\_elements}\)。你使用并查集将能互相转化的异或值合并,并将对应的点对数量累加到根节点。
- 查询:直接输出对应并查集根节点的统计值。
示例程序:
#include <bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
const int maxn = 1e5 + 5;int n, m, Q, a[18], dis[maxn];
long long c[(1<<18) + 5];
bool vis[maxn];
struct Edge {int v, w;
};
vector<Edge> g[maxn];void add(int x) {for (int i = 17; i >= 0; i--) {if ((x >> i) & 1) {if (!a[i]) {a[i] = x;return;}x ^= a[i];}}
}void dfs(int u, int sum) {vis[u] = true;dis[u] = sum;for (auto &[v, w] : g[u]) {if (!vis[v])dfs(v, sum ^ w);else {int tmp = dis[u] ^ dis[v] ^ w;if (tmp) {add(tmp);}}}
}int cal(int x) {for (int i = 17; i >= 0; i--) {if (a[i])x = min(x, x ^ a[i]);}return x;
}void fwt(long long a[], int step) { // 1:顺变换; 2:逆变换for (int x = 2; x <= (1<<18); x <<= 1) {int k = x >> 1;for (int i = 0; i < (1<<18); i += x) {for (int j = 0; j < k; j++) {long long s1 = a[i+j], s2 = a[i+j+k];a[i+j] = s1 + s2;a[i+j+k] = s1 - s2;if (step == 2) {a[i+j] /= 2;a[i+j+k] /= 2;}}}}
}struct DSU {int f[(1<<18)+5];long long sz[(1<<18)+5];void init() {for (int i = 0; i < (1<<18); i++) {f[i] = -1;sz[i] = c[i];}}int find(int x) {return (f[x] == -1) ? x : f[x] = find(f[x]);}void funion(int x, int y) {int a = find(x), b = find(y);if (a != b) {f[b] = a;sz[a] += sz[b];}}} dsu;void solve() {for (int i = 1; i <= n; i++)c[ dis[i] ]++;// 0特殊处理long long zero = 0;for (int i = 0; i < (1<<18); i++)zero += c[i] * (c[i] - 1);fwt(c, 1);for (int i = 0; i < (1<<18); i++)c[i] *= c[i];fwt(c, 2);c[0] = zero + n;dsu.init();for (int x = 0; x < (1<<18); x++) {for (int i = 17; i >= 0; i--) {dsu.funion(x, x ^ a[i]);}}
}int main() {scanf("%d%d%d", &n, &m, &Q);for (int i = 0, u, v, w; i < m; i++) {scanf("%d%d%d", &u, &v, &w);g[u].push_back({v, w});g[v].push_back({u, w});}dfs(1, 0);solve();for (int i = 0, x; i < Q; i++) {scanf("%d", &x);long long ans = dsu.sz[ dsu.find(x) ] % mod;printf("%lld\n", ans);}return 0;
}
解法2:线性基 + FWT(不需要并查集)
第二种解法能炫酷,直接两次 FWT 解决。
核心思路来自 Memory_of_winter大佬的博客
这段代码更“帅”,因为它将原本需要并查集维护的集合逻辑,巧妙地转化为了两次 FWT 卷积。
简要分析如下:
二次卷积逻辑:
- 核心原理:一个点对 \((u, v)\) 关于 \(x\) 是巧妙的,当且仅当存在一个线性基能表出的值 \(k\),使得 \(dist(u, v) \oplus k = x\)。这等价于 \(dist(u, v) \oplus x = k\)。
- 实现:将第一步得到的路径分布 c[] 与线性基指示函数 d[] 再次进行 异或卷积。卷积后的 c[x] 即为所有满足 \(dist(u, v) \oplus k = x\) 的组合数。
#include <bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
const int maxn = 1e5 + 5;int n, m, Q, a[18], dis[maxn];
long long c[(1<<18) + 5], d[(1<<18) + 5];
bool vis[maxn];
struct Edge {int v, w;
};
vector<Edge> g[maxn];void add(int x) {for (int i = 17; i >= 0; i--) {if ((x >> i) & 1) {if (!a[i]) {a[i] = x;return;}x ^= a[i];}}
}void dfs(int u, int sum) {vis[u] = true;dis[u] = sum;for (auto &[v, w] : g[u]) {if (!vis[v])dfs(v, sum ^ w);else {int tmp = dis[u] ^ dis[v] ^ w;if (tmp) {add(tmp);}}}
}int cal(int x) {for (int i = 17; i >= 0; i--) {if (a[i])x = min(x, x ^ a[i]);}return x;
}void fwt(long long a[], int step) { // 1:顺变换; 2:逆变换for (int x = 2; x <= (1<<18); x <<= 1) {int k = x >> 1;for (int i = 0; i < (1<<18); i += x) {for (int j = 0; j < k; j++) {long long s1 = a[i+j], s2 = a[i+j+k];a[i+j] = s1 + s2;a[i+j+k] = s1 - s2;if (step == 2) {a[i+j] /= 2;a[i+j+k] /= 2;}}}}
}void solve() {for (int i = 1; i <= n; i++)c[ dis[i] ]++;// 0特殊处理long long zero = 0;for (int i = 0; i < (1<<18); i++)zero += c[i] * (c[i] - 1);fwt(c, 1);for (int i = 0; i < (1<<18); i++)c[i] *= c[i];fwt(c, 2);c[0] = zero + n;for (int x = 0; x < (1<<18); x++) {int y = x;for (int i = 17; i >= 0; i--)y = min(y, y ^ a[i]);d[x] = !y;}fwt(c, 1);fwt(d, 1);for (int i = 0; i < (1<<18); i++)c[i] *= d[i];fwt(c, 2);
}int main() {scanf("%d%d%d", &n, &m, &Q);for (int i = 0, u, v, w; i < m; i++) {scanf("%d%d%d", &u, &v, &w);g[u].push_back({v, w});g[v].push_back({u, w});}dfs(1, 0);solve();for (int i = 0, x; i < Q; i++) {scanf("%d", &x);long long ans = c[x] % mod;printf("%lld\n", ans);}return 0;
}
