前言:
(虚假的想象学竞赛,实际的数学竞赛)
题意:
给出一个长度为 \(n\) 的二进制序列,我们对于每一个分割点(可以看做在元素与元素之间),其贡献为分割点右边的 \(cnt_1 - cnt_0\) 与左边的 \(cnt_1 - cnt_0\) 乘积,并且定义这个序列的得分为这个序列所有分割点贡献的最大值。
现在希望你求出给定的长度为 \(n\) 的二进制序列的所有子序列的贡献之和(子序列可以不连续,也就是类似于子集的定义)。同时我们有 \(q\) 次询问,每一次都会修改这个二进制序列上的某一个值(使其异或上 \(1\)),对于每一次修改之后都要回答上述的贡献之和。
思路:
显然的,对于一个固定的序列,我们可以 \(O(len)\) 暴力去做,但是很明显非常不优秀,那我们考虑一步一步优化。首先需要明确的是,对于一个固定的序列,其 \(cnt_1 - cnt_0\) 总为一个定值,我们记为 \(p\)。假设任意分割点左边的贡献为 \(a\),右边的为 \(b = (p-a)\)。那么贡献即为:
我们需要得到 \(f(a)_{\max},a\in[0,p]\),其实这个地方的定义域是不严谨的,只是这么写方便理解。那么一般来说最值在 \(a=\frac{p}{2}\) 的时候取到,显然这是在定义域之内的。又因为 \(a\) 是整数,所以我们某一个固定序列的得分即为:
现在考虑如何形式化的表示所有子序列的得分。令 \(x\) 为这个长度为 \(n\) 的序列的 \(cnt_1\) ,\(y\) 为 \(cnt_0\) ,显然 \(x+y=n\) ,给出式子:
其中后面中括号包起来很奇怪的那一坨的意义是判断 \(i-j\) 是否为奇数。想等式比较难,但是我觉得等式都列出来了应该不难理解,至于 \(\frac{1}{4}\) 的由来是两个 \(\frac{i-j}{2}\) 提出去的。
现在好像复杂度还是不够优秀,没事我们有数学牢大的帮助。先不考虑后面判断奇偶的问题,只考虑前面的式子。考虑把完全平方展开,对于 \(i^2\) 和 \(j^2\) 发现形式一样,可以只考虑一种。在给出最终推到之前,需要给出一个重要的等式:
考虑组合意义(双射)证明,右边式子可以看做现在 \(x\) 个里面选择 \(i\) 个,然后在 \(i\) 个里面选择 \(1\) 个。左边可以看做现在 \(x\) 里面选择 \(1\) 个,然后在剩下的里面选择 \(i-1\) 个,可以证明的是,这是双射的。
运用相同的思想,可以得到:
唯一需要的技巧就是把 \(i(i-1)+i\) 分开算。
那么开始暴力简化式子之后就可以得到:
这个式子的推导过程是有趣的,但是太长了,请根据上述前置等式自行推导(对于第一次接触组合推导的组合小白来说,这是具有启发性的)。
然后还需处理一个问题,就是那个判断奇偶性的部分,怎么快速处理。我们这样考虑,假设外层的 \(i\) 一直枚举,很显然,要满足 \(i+j\) 是一个奇数,那么必然有内层的 \(j\) 是一个偶数,反之同理,也就是内层存在的贡献应为:
这个两个式子应该是非常著名的,他们的贡献都为 \(2^{y-1}\),又因为外层的总贡献为 \(2^x\) 所以总贡献就为 \(2^{n-1}\) ,最终答案减去他就好啦,所以:
然后最后就只剩下单个位置修改的问题了,很简单,修改了之后改变一下 \(cnt_{0/1}\) 的值就好啦。
Code:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define LL long long
inline int read(){char c=getchar();bool f=0;int x=0;while(c > '9' || c < '0') f|=c=='-',c=getchar();while(c >= '0'&&c <= '9') x=(x<<1)+(x<<3)+(c^48),c=getchar();if(f) x=-x;return x;
}
const int N = 2e5 + 10;
const int MOD = 998244353;
char s[N];
int inv2 = (MOD + 1) / 2,p2[N],n,q;
LL cnt[2];LL calc(LL x)
{if(x == 0) return 0;else if(x == 1) return x * p2[n - x] % MOD * inv2 % MOD * inv2 % MOD;else return (x * p2[x - 1] + x * (x - 1) % MOD * p2[x - 2]) % MOD * p2[n - x] % MOD * inv2 % MOD * inv2 % MOD;
}
void solve()
{cnt[0] = cnt[1] = 0;n = read(),q = read();scanf("%s",s + 1);for(int i = 1;i <= n;++i)++cnt[s[i] - '0'];while(q--){int id = read();--cnt[s[id] - '0'];s[id] = '1' - (s[id] - '0');++cnt[s[id] - '0'];LL ans = 0;ans = (calc(cnt[0]) + calc(cnt[1])) % MOD;ans = ans - (cnt[0] * cnt[1] % MOD * p2[n - 1] % MOD + p2[n - 1] + MOD) * inv2 % MOD * inv2 % MOD;ans = (ans % MOD + MOD) % MOD;std::cout << ans << '\n'; }
}int main()
{int T;T = read();p2[0] = 1;for(int i = 1;i <= N - 10;++i) p2[i] = (p2[i - 1] << 1) % MOD;while(T--) solve();return 0;
}