P9055 [集训队互测 2021] 数列重排
部分分其实可以给出很多的启发。
首先 \(f(0)\) 显然任何区间都能满足条件,答案应该是 \(\frac{n(n - 1)}{2}\)。
然后考虑 \(f(m)\),一种构造方式是先来 \(X\) 组 \(0、1、 ...、 m - 1\),此时所有长度 \(\geq m\) 的区间都满足条件。但我们还会有一些多出来的数,处理方式是将这些数放到序列开头,并且调整每一组使得开头的这一段数为每一组数的后缀。结构大概类似于 \(0、 2、 5、 ...、 0、 2、 5、 ...、 0、 2、 5\),可以发现此时依旧满足所有 \(\geq m\) 的区间都满足条件。
由 \(f(m)\) 的启发考虑 \(f(i)\),我们可以先将 \(0 \sim i- 1\) 都按照 \(f(m)\) 的方法排好,然后我们考虑如何插入剩下的数。
注意到我们的构造方式已经使得任意长度 \(\geq i\) 的区间都满足条件,所以接下来插入的数如果插到某一个块(这里把上文的一组数叫做“块”)内部一定不优,换言之我们一定要把它插到块与块交界的地带或者整个序列的左侧和右侧。
我们分别考虑插到块与块交界地带会产生的不合法区间。
首先块与块之间应该有 \(X - 1\) 个空,然后如果在某个空插一个数会增加 \(2(i - 1) + r + 1\) 个不合法区间,其中 \(r\) 这个空已经插入的数量。因为它可以和本来就不合法的块内区间(即长度 \(< i\) 的区间)组成不合法区间,还可以和空内的组合以及自己也是一个不合法的区间。
然后如果插入到两侧则会产生 \(i - 1 + r + 1\) 个不合法区间,分析同上。
我们显然要让两侧和中间尽可能平衡,所以我们先在两侧插 \(2(i - 1)\) 个数,然后平均的对于两侧和中间插即可。
接下来我们考虑实现。我们可以容斥,减去所有不合法的区间。为了方便叙述我们令 \(calc(l, r)\) 表示首项与末项分别为 \(l、r\) 的公差为 \(1\) 的等差数列求和。
以下记 \(pre\) 为小于 \(i\) 的数量,\(suf\) 为大于等于 \(i\) 的数量。
先让 \(f(i) = calc(1, n)\),然后减去所有按照 \(f(m)\) 的方法排好后的不合法区间,即长度小于 \(i\) 的区间。显然长度为 \(l\) 的区间有 \(pre - l + 1\) 个,总个数就是 \(calc(pre - (i - 1) + 1, pre)\)。
然后我们要在两侧插数,注意我们插入的数个数为 \(\min(2(i - 1), suf)\)。此时记在左侧插入 \(lx\) 个数,右侧插入 \(rx\) 个数。此时增加的不合法区间(对于左侧)一定是左端点为这 \(lx\) 个数,右端点为这 \(lx\) 个数加上块内的 \(i - 1\) 个数,那么总和就是 \(calc(i, lx + i -1)\) 和 \(calc(i, rx + i - 1)\)。
接着是平均插数,具体情况之前分析过了。设完整地插了 \(z\) 轮,然后剩下 \(y\) 个多出来的数。那么增加的不合法区间个数就是 \(calc(2i - 1, 2i - 1 + z - 1) \times (j + 1) + y \times (2i - 1 + z)\)。其中 \(j + 1\) 是所有空的数量。
洛谷有点卡常,少用 long long
。值得注意的是 LOJ 完全不卡。
#include <bits/stdc++.h>
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define initrand srand((unsigned)time(0))
#define random(x) ((LL)rand() * rand() % (x))
#define eb emplace_back
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
inline int read() {int x = 0, f = 1;char ch = getchar();while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }return x * f;
}const int N = 10000010, Mod = 998244353;
int m, l, r;
int f[N], a[N], s[N], X, n;
char op[N]; inline LL calc(int l, int r) {return 1ll * (l + r) * (r - l + 1) / 2 % Mod;
}inline int min(int x, int y) {return x < y ? x : y;
}inline void sub(int &x, int y) {x = (x < y ? x + Mod - y : x - y);
}int main() {m = read(), l = read(), r = read(), X = read();rep(i, 0, m - 1) {scanf(" %c", &op[i]);s[i] = a[i] = X + (op[i] == '1');n += a[i];} fro(i, m - 1, 0) s[i] += s[i + 1];f[0] = calc(1, n);int j = a[0]; rep(i, 1, m) {int suf = s[i], pre = n - s[i];f[i] = calc(1, n);sub(f[i], calc(pre - i + 2, pre));int tmp = min(2ll * (i - 1), suf);int lx = tmp / 2, rx = tmp - lx;sub(f[i], calc(i, i + lx - 1));sub(f[i], calc(i, i + rx - 1));// f[i] = ((f[i] - calc(i, i + lx - 1)) % Mod + Mod) % Mod;// f[i] = ((f[i] - calc(i, i + rx - 1)) % Mod + Mod) % Mod;int lst = suf - tmp, z = lst / (j + 1), y = lst % (j + 1);sub(f[i], 1ll * calc(2 * i - 1, 2 * i - 1 + z - 1) * (j + 1) % Mod);sub(f[i], 1ll * y * (2 * i - 1 + z) % Mod);// f[i] = ((f[i] - calc(2 * i - 1, 2 * i - 1 + z - 1) * (j + 1) % Mod) % Mod + Mod) % Mod;// f[i] = ((f[i] - y * (2 * i - 1 + z)) % Mod + Mod) % Mod;j = min(j, a[i]);}LL pw = 1, ans = 0;rep(i, 0, m) {if (l <= i && i <= r) ans = ans ^ (pw * f[i] % Mod);pw = pw * 233 % Mod; } printf("%lld\n", ans);return 0;
}