题意简述
一副 \(n+m\) 张牌的扑克牌,\(m\) 张 joker。初始牌堆里有这样一副牌。随机抽一张牌拿走,如果是 joker,将所有牌放回牌堆并打乱。问你抽到过所有 \(n\) 张正常牌的期望抽牌次数是多少?对 \(M = 19260817\) 取模。
\(n \leq 10^8\),\(m \leq 10^{18}\)。
题目分析
概率期望类题目,考虑 DP,并且期望 DP 套路是从后往前递推。
显然应该可以状压 DP,但是其非常不利于后续优化。所以尝试使用线性 DP。
DP 记录值显然是当前状态到终态需要的期望抽牌次数。状态有哪些呢?牌堆、抽到过哪些牌了。如果朴素记录就是状压了,但是我们发现,\(n\) 张普通牌和 \(m\) 张 joker 并没有本质差别,是等价的,无非前者需要区分有没有被抽到过。
不妨使用 \(f_{i,j}\) 表示当前已经抽到过 \(i\) 张牌,牌堆里有 \(j\) 张牌,到终态的期望抽牌次数。我们需要明确的是,哪些状态是合法的,显然需要 \(j\) 中包含 \(m\) 张牌,和 \(n-i\) 张 \(i\) 中没有的普通牌,即 \(j \geq n+m-i\)。对于 \(i+j \geq n+m\) 的情况表示 \(i+j-n-m\) 张牌已经抽到过了,但后来被重新加入牌堆中。
明确好状态,就可以转移了。我们有 \(\frac{n-i}{j}\) 的概率,抽到一张全新的牌,转移到 \(f_{i+1,j-1}\);有 \(\frac{i+j-n-m}{j}\) 的概率,抽到一张抽到过的牌,转移到 \(f_{i,j-1}\);有 \(\frac{m}{j}\) 的概率,抽到 joker,转移到 \(f_{i,n+m}\)。验证一下,\(\frac{n-i}{j}+\frac{i+j-n-m}{j}+\frac{m}{j}=1\),没有问题。
边界 \(f_{n,j}=0\),答案 \(f_{0,n+m}\)。这不好递推,怎么办呢?
我们可以把它看做二维平面内的随机游走,向左下、左、行末行走。这个往行末行走就很经典。我们可以设 \(f_{i,j}=k_{i,j}\cdot f_{i,n+m}+b_{i,j}\),从 \(j=n+m-i\) 推到 \(j=n+m\),就是一个方程,方程解出来,\(f_{i}\) 就解出来了。具体可以见文末代码。
上述 DP 时空复杂度 \(\Theta(n^2)\),需要优化。经过打表发现,\(f_{i}\) 对 \(j\) 为等差数列。
【404 not found】
作者太菜了,还不会证。
我们设 \(f_{i,j}=\lambda_i+\mu_i\cdot(n+m-j)\),我们只需要任意两项 \(j\),就能确定 \(\lambda_i, \mu_i\),也就确定了 \(f_{i}\),为了方便起见,取末两项解方程。
于是可以 \(\mathcal{O}(n \log M)\),若 \(m = \mathcal{O}(n)\),则可以完全线性 \(\mathcal{O}(n)\)。边界 \(\lambda_n=\mu_n=0\),答案 \(\lambda_0\)。
代码
取模板子
namespace Mod_Int_Class {template <typename T, typename _Tp>constexpr bool in_range(_Tp val) {return std::numeric_limits<T>::min() <= val && val <= std::numeric_limits<T>::max();}template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>static constexpr inline bool is_prime(_Tp val) {if (val < 2) return false;for (_Tp i = 2; i * i <= val; ++i)if (val % i == 0)return false;return true;}template <auto _mod = 19260817, typename T = int, typename S = long long>class Mod_Int {static_assert(in_range<T>(_mod), "mod must in the range of type T.");static_assert(std::is_integral<T>::value, "type T must be an integer.");static_assert(std::is_integral<S>::value, "type S must be an integer.");public:constexpr Mod_Int() noexcept = default;template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>constexpr Mod_Int(_Tp v) noexcept: val(0) {if (0 <= S(v) && S(v) < mod) val = v;else val = (S(v) % mod + mod) % mod;}constexpr T const& raw() const {return this -> val;}static constexpr T mod = _mod;template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>constexpr friend Mod_Int pow(Mod_Int a, _Tp p) {return a ^ p;}constexpr friend Mod_Int sub(Mod_Int a, Mod_Int b) {return a - b;}constexpr friend Mod_Int& tosub(Mod_Int& a, Mod_Int b) {return a -= b;}constexpr friend Mod_Int add(Mod_Int a) { return a; }template <typename... args_t>constexpr friend Mod_Int add(Mod_Int a, args_t... args) {return a + add(args...);}constexpr friend Mod_Int mul(Mod_Int a) { return a; }template <typename... args_t>constexpr friend Mod_Int mul(Mod_Int a, args_t... args) {return a * mul(args...);}template <typename... args_t>constexpr friend Mod_Int& toadd(Mod_Int& a, args_t... b) {return a = add(a, b...);}template <typename... args_t>constexpr friend Mod_Int& tomul(Mod_Int& a, args_t... b) {return a = mul(a, b...);}template <T __mod = mod, typename = std::enable_if_t<is_prime(__mod)>>static constexpr inline T inv(T a) {assert(a != 0);return _pow(a, mod - 2);}constexpr Mod_Int& operator + () const {return *this;}constexpr Mod_Int operator - () const {return _sub(0, val);}constexpr Mod_Int inv() const {return inv(val);}constexpr friend inline Mod_Int operator + (Mod_Int a, Mod_Int b) {return _add(a.val, b.val);}constexpr friend inline Mod_Int operator - (Mod_Int a, Mod_Int b) {return _sub(a.val, b.val);}constexpr friend inline Mod_Int operator * (Mod_Int a, Mod_Int b) {return _mul(a.val, b.val);}constexpr friend inline Mod_Int operator / (Mod_Int a, Mod_Int b) {return _mul(a.val, inv(b.val));}template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>constexpr friend inline Mod_Int operator ^ (Mod_Int a, _Tp p) {return _pow(a.val, p);}constexpr friend inline Mod_Int& operator += (Mod_Int& a, Mod_Int b) {return a = _add(a.val, b.val);}constexpr friend inline Mod_Int& operator -= (Mod_Int& a, Mod_Int b) {return a = _sub(a.val, b.val);}constexpr friend inline Mod_Int& operator *= (Mod_Int& a, Mod_Int b) {return a = _mul(a.val, b.val);}constexpr friend inline Mod_Int& operator /= (Mod_Int& a, Mod_Int b) {return a = _mul(a.val, inv(b.val));}template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>constexpr friend inline Mod_Int& operator ^= (Mod_Int& a, _Tp p) {return a = _pow(a.val, p);}constexpr friend inline bool operator == (Mod_Int a, Mod_Int b) {return a.val == b.val;}constexpr friend inline bool operator != (Mod_Int a, Mod_Int b) {return a.val != b.val;}constexpr Mod_Int& operator ++ () {this -> val + 1 == mod ? this -> val = 0 : ++this -> val;return *this;}constexpr Mod_Int& operator -- () {this -> val == 0 ? this -> val = mod - 1 : --this -> val;return *this;}constexpr Mod_Int operator ++ (int) {Mod_Int res = *this;this -> val + 1 == mod ? this -> val = 0 : ++this -> val;return res;}constexpr Mod_Int operator -- (int) {Mod_Int res = *this;this -> val == 0 ? this -> val = mod - 1 : --this -> val;return res;}friend std::istream& operator >> (std::istream& is, Mod_Int<mod, T, S>& x) {T ipt;return is >> ipt, x = ipt, is;}friend std::ostream& operator << (std::ostream& os, Mod_Int<mod, T, S> x) {return os << x.val;}protected:T val;static constexpr inline T _add(T a, T b) {return a >= mod - b ? a + b - mod : a + b;}static constexpr inline T _sub(T a, T b) {return a < b ? a - b + mod : a - b;}static constexpr inline T _mul(T a, T b) {return static_cast<S>(a) * b % mod;}template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>static constexpr inline T _pow(T a, _Tp p) {T res = 1;for (; p; p >>= 1, a = _mul(a, a))if (p & 1) res = _mul(res, a);return res;}};using mint = Mod_Int<>;using mod_t = mint;constexpr mint operator ""_m (unsigned long long x) {return mint(x);}constexpr mint operator ""_mod (unsigned long long x) {return mint(x);}
}using namespace Mod_Int_Class;
$\mathcal{O}(n^2)$ 部分分 & 打表
#include <cstdio>
#include <iostream>
#include <limits>
#include <cassert>
#include <vector>
using namespace std;int n, m;namespace $1 {bool check() {return n <= 1000;}void solve() {vector<vector<mint>> f(n + 1, vector<mint>(n + 1));for (int i = n - 1; i >= 0; --i) {vector<mint> k(i + 1), b(i + 1);k[0] = 1_mod * m / (n + m - i);b[0] = 1_mod * (n - i) / (n + m - i) * f[i + 1][0] + 1;for (int j = 1; j <= i; ++j) {k[j] = 1_mod * m / (n + m - i + j)+ 1_mod * j / (n + m - i + j) * k[j - 1];b[j] = 1_mod * (n - i) / (n + m - i + j) * f[i + 1][j]+ 1_mod * j / (n + m - i + j) * b[j - 1] + 1;}f[i][i] = b[i] / (1 - k[i]);for (int j = 0; j < i; ++j)f[i][j] = k[j] * f[i][i] + b[j];}printf("%d\n", f[0][0].raw());vector<vector<double>> g(n + 1, vector<double>(n + 1));for (int i = n - 1; i >= 0; --i) {vector<double> k(i + 1), b(i + 1);k[0] = 1. * m / (n + m - i);b[0] = 1. * (n - i) / (n + m - i) * g[i + 1][0] + 1;for (int j = 1; j <= i; ++j) {k[j] = 1. * m / (n + m - i + j)+ 1. * j / (n + m - i + j) * k[j - 1];b[j] = 1. * (n - i) / (n + m - i + j) * g[i + 1][j]+ 1. * j / (n + m - i + j) * b[j - 1] + 1;}g[i][i] = b[i] / (1 - k[i]);for (int j = 0; j < i; ++j)g[i][j] = k[j] * g[i][i] + b[j];for (int j = 0; j <= i; ++j)printf("%.10lf ", g[i][j]);puts("");// for (int j = 1; j <= i; ++j)// printf("%.10lf ", g[i][j] - g[i][j - 1]);// puts("");}}
}signed main() {#ifndef XuYuemingfreopen("toad.in", "r", stdin);freopen("toad.out", "w", stdout);#endifscanf("%d%d", &n, &m);if ($1::check()) return $1::solve(), 0;$yzh::solve();return 0;
}
$\mathcal{O}(n \log M)$ 正解
namespace $yzh {const int N = 1000010;mint lambda[N], mu[N];void solve() {lambda[n] = mu[n] = 0;for (int i = n - 1; i >= 0; --i) {mu[i] = ((n - i) * mu[i + 1] - 1) / (n + m - i + 1);lambda[i] = i * mu[i] / (n - i) + lambda[i + 1] + mu[i + 1] + 1_mod * (n + m) / (n - i);}printf("%d", lambda[0].raw());}
}
卡常后
#pragma GCC optimize("Ofast", "inline", "fast-math", "unroll-loops")
#include <cstdio>const int N = 1000010, mod = 19260817;int n, m, lambda, mu, Inv[N << 1];
inline int add(int a, int b) { return a >= mod - b ? a + b - mod : a + b; }signed main() {freopen("toad.in", "r", stdin);freopen("toad.out", "w", stdout);scanf("%d%d", &n, &m), Inv[1] = 1;for (register int i = 2, *I = Inv + 2; i <= n + m + 1; ++i, ++I)*I = 1ll * (mod - Inv[mod % i]) * (mod / i) % mod;for (register int i = 1; i <= n; ++i) {int t = mu;mu = 1ll * add(1ll * i * mu % mod, mod - 1) * Inv[m + i + 1] % mod;lambda = add(1ll * add(n + m, 1ll * (n - i) * mu % mod) * Inv[i] % mod, add(lambda, t));}printf("%d", lambda);return 0;
}