因为要让方差越大越好,所以要让序列 \(c\) 尽可能的不稳定。也就是说,要让小的数尽可能小,大的数尽可能大。
所以显然,\(c_i\) 一定在 \(a_i\) 和 \(b_i\) 之中。
由于题目中的数组都是排序过的,所以一定存在一个分割点 \(i\) 满足对于所有 \(1 \le j \le i\),\(c_j\) 都等于 \(a_j\),所有 \(i+1 \le j \le n\),\(c_j\) 都等于 \(b_j\)。
所以我们枚举分割点,求出序列的方差然后取 \(max\) 就好了。
接下来的问题就是怎么求方差乘上 \(n^2\) 的结果。
\[\begin{aligned}n^2 \cdot \sum_{i=1}^n{(a_i-\overline{a})^2} &= n \cdot \sum_{i=1}^n{(a_i^2-2 \cdot a_i \cdot \overline{a}+\overline{a}^2)} \\ &= n \cdot \sum_{i=1}^n{a_i^2}-2n\overline{a}\sum_{i=1}^na_i+n^2\cdot \overline{a}^2 \\ &=n \cdot \sum_{i=1}^n{a_i^2}-2\cdot (\sum_{i=1}^n{a_i})^2+(\sum_{i=1}^n{a_i^2}) \\ &= n \cdot \sum_{i=1}^n{a_i^2}-(\sum_{i=1}^n{a_i})^2\end{aligned}
\]
所以只要用前缀和后缀和维护一下就好了。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 1e9 + 7;
const int N = 1000005;
const int INF = 0x3f3f3f3f;
ll a[N], b[N];
void print(__int128 x) {if (x < 0) {putchar('-');x = -x;}if (x < 10) {putchar(x + 48);return;}print(x / 10);putchar(x % 10 + 48);
}
__int128 sa1[N], sa2[N];
__int128 sb1[N], sb2[N];
int main() {int n;scanf("%d", &n);for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);for (int i = 1; i <= n; i++) scanf("%lld", &b[i]);for (int i = 1; i <= n; i++) {sa1[i] = sa1[i - 1] + (__int128)a[i];sa2[i] = sa2[i - 1] + (__int128)a[i] * a[i];}for (int i = n; i >= 1; i--) {sb1[i] = sb1[i + 1] + (__int128)b[i];sb2[i] = sb2[i + 1] + (__int128)b[i] * b[i];}__int128 ans = 0;for (int i = 1; i <= n; i++) ans = max(ans, n * (sa2[i] + sb2[i + 1]) - (sa1[i] + sb1[i + 1]) * (sa1[i] + sb1[i + 1]));print(ans);putchar('\n');return 0;
}