Matvey's Birthday
题目链接。
Problem
给定一个仅包含 a
~h
的字符串(八个字符)。
有一个 \(n\) 个结点的无向图,编号为 \(0\) 到 \(n−1\)。结点 \(i\) 与结点 \(j\) 间有边相连当且仅当 \(|i-j|=1\) 或 \(S_i=S_j\)。
求这个无向图的直径和有多少对点间的最短距离与直径相同。
数据范围:\(2 \le n \le 10^5\)。
Sol
不难发现,直径一定不会超过 \(15\)。因为我可以通过传送后走一步来切换字符。然后这个东西不是很好贪心,\(n\) 又特别小,不难想到 DP。
不妨定义 \(f_{i, j}\) 表示点 \(i\) 走到第 \(j\) 个颜色的最短距离。这个东西是很好处理的,用类似于 BFS 的东西即可得到。
现在讨论直径,由于直径是 \(\max\limits_{i, j \in [1, n] \cap \mathbb Z}\min(|i - j|, \min\limits_c f_{i, c} + f_{j, c} + 1)\)。经过我们的转化,时间复杂度由 \(\mathcal{O}(n^2)\) 变为了 \(\mathcal{O}(n^2)\)!。
里面的 \(\min\) 是很不好的,由于答案一定不会超过 \(15\),可以把 \(|i - j| \le 15\) 的二元组 \((i, j)\) 单独拉出来跑。然后现在就只需要求 \(\max\limits_{i, j \in [1, n] \cap \mathbb Z} \min\limits_c f_{i, c} + f_{j, c} + 1\) 了。这个东西乍一看并不好做,但是发现传送只对颜色有要求,所以对于 \(a_i\) 相同的点,\(f_{i, c}\) 至多只会有两个值,因为如果不行的话,一定可以传送一次到达。如果记 \(g_{i, j}\) 表示颜色 \(i\) 走到颜色 \(j\) 的最小距离(\(g\) 可以在求 \(f\) 时一起得到),则显然有 \(g_{a_i, j} \le f_{i,j} \le g_{a_i, j} + 1\)。里面的 \(\min\) 不好拆掉,于是考虑枚举 \(i\) 算 \(j\) 的答案。然后不妨令 \(h_{i, j} = g_{a_i, j} - f_{i, j}\)。然后由于 \(h\) 的值域很小,可以把后面一维压起来,变为 \(h_{i}\)。发现相同的 \((a_i, h_i)\) 的答案,无论 \(c, j\) 是多少,答案一定是一样的。于是可以枚举 \(a_i, h_i, c, j\) 来进行统计。\(h_i\) 的值有 \(\mathcal{O}(2^c)\) 种,所以时间复杂度为 \(\mathcal{O}(nc^22^c)\)。空间 \(\mathcal{O}(nc + c2^c)\),这是因为我需要知道之前 \((a_i, h_i)\) 出现了多少次,这需要开一个桶。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, ll> pii;
#define fi first
#define se second
mt19937_64 eng(time(0) ^ clock());
template <typename T>
T rnd(T l, T r) { return eng() % (r - l + 1) + l; }
int n;
int a[100005];
int f[10][100005], g[10][10], buc[10][256];
vector<int> col[8];
int vsc[10];
pii operator + (pii a, pii b) {if (a.fi < b.fi) swap(a, b);return make_pair(a.fi, a.se + (a.fi == b.fi) * b.se);
}
int main() {scanf("%d", &n);string str;cin >> str;for (int i = 1; i <= n; i++)col[a[i] = str[i - 1] - 'a'].emplace_back(i);for (int i = 0; i < 8; i++) {queue<int> que;for (int j = 1; j <= n; j++)f[i][j] = n + 1;for (int j = 0; j < 8; j++)vsc[j] = 0, g[i][j] = n + 1;for (int j : col[i])que.emplace(j), f[i][j] = 0;g[i][i] = 0;vsc[i] = 1;while (!que.empty()) {int u = que.front();que.pop();if (u > 1 && f[i][u - 1] == n + 1)f[i][u - 1] = f[i][u] + 1, que.emplace(u - 1);if (u < n && f[i][u + 1] == n + 1)f[i][u + 1] = f[i][u] + 1, que.emplace(u + 1);if (vsc[a[u]])continue;vsc[a[u]] = 1;for (int j : col[a[u]])if (f[i][j] == n + 1)f[i][j] = f[i][u] + 1, que.emplace(j);}}for (int i = 0; i < 8; i++)for (int j = 1; j <= n; j++)g[i][a[j]] = min(g[i][a[j]], f[i][j]);pii ans = make_pair(0, 0);for (int i = 2; i <= n; i++)for (int j = max(1, i - 15); j < i; j++) {int v = i - j;for (int k = 0; k < 8; k++)v = min(v, f[k][i] + f[k][j] + 1);ans = ans + make_pair(v, 1);}for (int i = 17; i <= n; i++) {int w = 0;for (int j = 0; j < 8; j++)if (f[j][i - 16] - g[a[i - 16]][j] <= 1)w += (f[j][i - 16] - g[a[i - 16]][j]) << j;buc[a[i - 16]][w]++;for (int j = 0; j < 8; j++)for (int k = 0; k < (1 << 8); k++) {if (!buc[j][k]) continue;w = n + 1;for (int l = 0; l < 8; l++)w = min(w, f[l][i] + g[j][l] + (k >> l & 1) + 1);ans = ans + make_pair(w, buc[j][k]);}}printf("%d %lld\n", ans.fi, ans.se);return 0;
}