谨以此纪念这个废物逝去的一天。
别看它是一道黑题但是它不配。
首先它长得很像分数规划,直接二分答案,这样就把每条边的边权看成了 \(V(e)-\text{mid}\),然后你希望求经过边数在 \([L,U]\) 之间的最长路径,判断它是否 \(\ge 0\)。
考虑一个暴力 \(dp_{i,j}\) 表示 \(i\) 作为链顶,其子树长度为 \(j\) 的链最大的边权和是多少。
暴力转移是非常容易的:\(dp_{i,j}=\max_{(i,v,w)\in E}dp_{v,j-1}+w-\text{mid}\)。
求解也是比较轻松的,考虑在转移 \(dp_{v,j-1}\to dp_{i,j}\) 之前,枚举 \(dp_{i,j},dp_{v,k}\) 在满足 \(j+k+1\in [L,U]\) 的情况下,\(\text{ans}=\max \{dp_{i,j}+dp_{v,k}\}\),本质上就是合并两颗子树的答案。
将求解稍微优化一下,由于二分,此时复杂度应该是 \(\mathcal{O}(n^2\log V)\)。
考虑如何进一步的优化。观察到这个 \(dp\) 的定义和转移都非常符合我们长剖优化 DP 的形式。
直接套用我们在引例中使用长剖进行转移的形式,直接继承重儿子的 \(dp_{i,j}\),注意一下下标偏移即可,我们应该是可以非常容易的将 \(dp\) 的转移做到时空复杂度均为 \(\mathcal{O}(n)\) 的。
但是在转移完之后求解答案的过程却不那么好通过知道所有的 \(dp\) 值之后快速解决。
观察到在求解的 \(dp_{i,j}+dp_{v,k}\) 中,在知晓 \(j\) 的情况下,合法的 \(k\) 是连续的一段区间,故我们考虑直接将 \(dp\) 直接搬到线段树上去。
这样应该是比较好求解的,求解是在转移之前,通过查询区间最值,和长剖本身的复杂度,这一部分应该是 \(\mathcal{O}(n\log n)\)。
但是新的问题出现了,线段树的下标偏移在继承重儿子的时候并不是直接左移那么简单的(疑似可以直接线段树合并,我不知道。)。故我们直接很暴力的,将 \(i\) 对应的在线段树上的 \(dp\) 值,储存在其在链顶处的线段树上所对应的下标。然后稍微注意一下因为我们所有储存的信息都可以看做是链顶的 \(dp\) 值(因为你的下标已经对应在了链顶处,但是他在 \(i\) 处的 \(dp\) 下标和在链顶处的下标是不同的),故转移和求解还有一些细节需要处理。可能这里讲的不太清楚,具体可以看代码。
至于空间问题,可以直接开指针,每条长链公用一颗线段树即可。
空间复杂度 \(\mathcal{O}(n)\),时间复杂度 \(\mathcal{O}(n\times \log n\times \log V)\)。
注意,线段树需要四倍空间,导致你使用指针的时候需要非常精细且小心,我因为没计算好,调了一天。
#include <bits/stdc++.h>
using namespace std;
#define maxn 200005int n, l, r;
int fst[maxn], cnt;
struct node
{int tar, nxt, num;
}arr[maxn << 1];
void adds(int x, int y, int z)
{arr[++cnt].tar = y, arr[cnt].nxt = fst[x], fst[x] = cnt, arr[cnt].num = z;
}
int son[maxn], maxdep[maxn] = {-1};
void dfs(int x, int last)
{int mxid = 0;for (int i = fst[x]; i; i = arr[i].nxt){int j = arr[i].tar;if(j == last) continue;dfs(j, x);if(maxdep[j] > maxdep[mxid]) maxdep[x] = maxdep[j] + 1, mxid = j;}son[x] = mxid;
}
int top[maxn], dep[maxn];
void dfs2(int x, int last, int topf)
{top[x] = topf, dep[x] = dep[last] + 1;if(son[x]) dfs2(son[x], x, topf);for (int i = fst[x]; i; i = arr[i].nxt){int j = arr[i].tar;if(j == son[x] || j == last) continue;dfs2(j, x, j); }
}
double Dep[maxn];
double *data[maxn << 2], buf[maxn << 2];
int cur = 1;
double mid, ans = 0;
void init(int x, int last)
{for (int i = fst[x]; i; i = arr[i].nxt){int j = arr[i].tar;if(j == last) continue;Dep[j] = Dep[x] + arr[i].num - mid;init(j, x);}
}
void change(int op, int p, int L, int R, int x, double y)
{if(L > x || R < x) return;if(L == R) return data[op][p] = y, void(0);int mid = (L + R) >> 1;change(op, p << 1, L, mid, x, y), change(op, p << 1 | 1, mid + 1, R, x, y);data[op][p] = max(data[op][p << 1], data[op][p << 1 | 1]);
}
double query(int op, int p, int L, int R, int l, int r)
{if(l > r) return -1e17;if(L > r || R < l) return -1e17;if(L >= l && R <= r) return data[op][p];int mid = (L + R) >> 1;return max(query(op, p << 1, L, mid, l, r), query(op, p << 1 | 1, mid + 1, R, l, r));
}
void get_ans(int x, int last)
{if(son[x]){get_ans(son[x], x);ans = max(ans, query(top[x], 1, 0, maxdep[top[x]], max(l + dep[x] - dep[top[x]], 0), min(r + dep[x] - dep[top[x]], maxdep[top[x]])) - Dep[x] + Dep[top[x]]);}change(top[x], 1, 0, maxdep[top[x]], dep[x] - dep[top[x]], Dep[x] - Dep[top[x]]);for (int i = fst[x]; i; i = arr[i].nxt){int j = arr[i].tar;if(j == last || j == son[x]) continue;double k = arr[i].num - mid;data[j] = buf + cur, cur += 4 * (maxdep[j] + 3);get_ans(j, x);for (int i = 0; i <= maxdep[j]; ++i){double now = query(j, 1, 0, maxdep[j], i, i);double Now = query(top[x], 1, 0, maxdep[top[x]], max(l + dep[x] - dep[top[x]] - i - 1, 0), min(r + dep[x] - dep[top[x]] - i - 1, maxdep[top[x]]));ans = max(ans, now + k + Now - Dep[x] + Dep[top[x]]);}for (int i = 0; i <= maxdep[j]; ++i){double now = query(j, 1, 0, maxdep[j], i, i);double Now = query(top[x], 1, 0, maxdep[top[x]], i + dep[x] - dep[top[x]] + 1, i + dep[x] - dep[top[x]] + 1);change(top[x], 1, 0, maxdep[top[x]], i + dep[x] - dep[top[x]] + 1, max(now + Dep[x] - Dep[top[x]] + k, Now));}}
}
int main()
{scanf("%d %d %d", &n, &l, &r);for (int i = 1; i < n; ++i){int x, y, z;scanf("%d %d %d", &x, &y, &z);adds(x, y, z), adds(y, x, z);}dfs(1, 0), dfs2(1, 0, 1);double l = 0, r = 1e6, ans = 0;while(r - l > 1e-4){mid = (l + r) / 2; ::ans = -1, cur = 1;for (int i = 0; i <= 8 * n + 4; ++i) buf[i] = -1e17;data[1] = buf, cur += 4 * (maxdep[1] + 3);init(1, 0);get_ans(1, 0);if(::ans >= 0) ans = mid, l = mid;else r = mid;}printf("%.3lf\n", ans);return 0;
}