比如,假设要解决的问题是,要从 n 个物品中选取 m 个,并最优化某个较复杂的目标函数。如果设从前 i 个物品中选取 j 个,目标函数的最优值为 f(i,j),那么原问题的答案就是 f(n,m)。这类问题中,状态转移方程通常是二维的。直接实现该状态转移方程,时间复杂度是 O(nm) 的,难以接受。
进一步假设,没有数量限制的最优化问题容易解决。但是,选取到的最优数量未必满足原问题的数量限制。假设选取的物品过多。那么,就可以考虑在选取物品时,为每个选取到的物品都附加一个固定大小的惩罚 k(即「带权二分」中的「权」),仍然解没有数量限制的最优化问题。根据 k 的取值不同,选取到的最优数量也会有所不同;而且,随着 k 的变化,选取到的最优数量也是单调变化的。所以,可以通过二分,找到 k 使得选取到的最优数量恰为 m。假设此时目标函数的最优值为 fk(n),那么,只要消除额外附加的惩罚造成的价值损失,就能得到原问题的答案 f(n,m)=fk(n)+km。假设单次求解附加惩罚的问题的复杂度是 O(T(n)) 的,那么,算法的整体复杂度也就降低到了 O(T(n)logL),其中,O(logL) 是二分 k 需要的次数。
这就是 WQS 二分的基本想法。但是,这一想法能够行得通,前提是 f(n,m) 关于 m 是凸的。否则,可能不存在使得最优数量恰为 m 的附加惩罚 k。这也是这种 DP 优化方法常常称为「凸优化 DP」或「凸完全单调性 DP」的原因。
传统方法
设非空集合 X 为(有限的)决策空间,f:X→R 为目标函数,且另有函数 g:X→Rd 用于施加限制。需要求解的问题,可以看作是计算如下最优化问题的价值函数 v(y) 在某处的取值:
v(y)=x∈Xminsubject to f(x)g(x)=y.
比如,对于前文提到的限制数量的问题,X 可以理解为所有物品集合的子集族,x∈X 是单个子集,f(x) 是单个子集的价值函数,g(x) 是子集 x 中的元素个数。当然,g(x) 并非只能是数量限制,后文提供了更为广泛的限制条件的例子。
#include <algorithm>#include <cstring>#include <iostream>#include <tuple>#include <vector>int main() { int n, m; std::cin >> n >> m; std::vector<int> a(n + 1); for (int i = 1; i <= n; ++i) std::cin >> a[i]; // Calculate h(k) = max_x f(x) - k * g(x). // Meanwhile, obtain the maximum value g(x) of the optimizer x. auto calc = [&](int k) -> std::pair<long long, int> { long long dp[2] = {0, -0x3f3f3f3f3f3f3f3f}; int opt[2] = {0, 0}; for (int i = 1; i <= n; ++i) { long long tmp_dp[2]; int tmp_opt[2]; if (dp[0] > dp[1]) { tmp_dp[0] = dp[0]; tmp_opt[0] = opt[0]; } else if (dp[1] > dp[0]) { tmp_dp[0] = dp[1]; tmp_opt[0] = opt[1]; } else { tmp_dp[0] = dp[0]; tmp_opt[0] = std::max(opt[0], opt[1]); } tmp_dp[1] = dp[0] + a[i] - k; tmp_opt[1] = opt[0] + 1; std::memcpy(dp, tmp_dp, sizeof(dp)); std::memcpy(opt, tmp_opt, sizeof(opt)); } long long val; int opt_m; if (dp[0] > dp[1]) { val = dp[0]; opt_m = opt[0]; } else if (dp[1] > dp[0]) { val = dp[1]; opt_m = opt[1]; } else { val = dp[0]; opt_m = std::max(opt[0], opt[1]); } return {val, opt_m}; }; // WQS binary search. long long val, tar_val; int opt_m, tar_k; std::tie(val, opt_m) = calc(0); if (opt_m <= m) { // Have already reached the peak. tar_k = 0; tar_val = val; } else { // Find the maximum k such that g(x) >= m. int ll = 0, rr = 1000000; while (ll <= rr) { int mm = (ll + rr) / 2; std::tie(val, opt_m) = calc(mm); if (opt_m >= m) { tar_k = mm; tar_val = val; ll = mm + 1; } else { rr = mm - 1; } } } long long res = tar_val + (long long)tar_k * m; std::cout << res << std::endl; return 0;}
对偶方法
#include <algorithm>#include <iostream>#include <tuple>#include <type_traits>#include <vector>// Golden section search on integer domain (unimodal function)template <typename T, typename F>typename std::enable_if< std::is_integral<T>::value, std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::typegolden_section_search(T ll, T rr, F func) { constexpr long double phi = 0.618033988749894848204586834L; T ml = ll + static_cast<T>((rr - ll) * (1 - phi)); T mr = ll + static_cast<T>((rr - ll) * phi); auto fl = func(ml), fr = func(mr); while (ml < mr) { if (fl > fr) { rr = mr; mr = ml; fr = fl; ml = ll + static_cast<T>((rr - ll) * (1 - phi)); fl = func(ml); } else { ll = ml; ml = mr; fl = fr; mr = ll + static_cast<T>((rr - ll) * phi); fr = func(mr); } } T best_x = ll; auto best_val = func(ll); for (T i = ll + 1; i <= rr; ++i) { auto val = func(i); if (val > best_val) { best_val = val; best_x = i; } } return {best_x, best_val};}int main() { int n, m; std::cin >> n >> m; std::vector<int> a(n + 1); for (int i = 1; i <= n; ++i) std::cin >> a[i]; // Calculate h(k) = max_x f(x) + k * g(x). auto calc = [&](int k) -> long long { long long dp[2] = {0, -0x3f3f3f3f3f3f3f3f}; for (int i = 1; i <= n; ++i) { std::tie(dp[0], dp[1]) = std::make_pair(std::max(dp[0], dp[1]), dp[0] + a[i] + k); } return std::max(dp[0], dp[1]); }; // Solve the dual problem to find v(m). // Implemented as a minimization problem by adding negative signs. // Only consider tangent lines of negative slopes to ignore the part // of the curve after the peak. auto res = -golden_section_search(-1000000, 0, [&](int k) -> long long { return -calc(k) + (long long)k * m; }).second; std::cout << res << std::endl; return 0;}
引理 S 和 T 是无向连通图 G=(V,E) 的两个生成树。对于任意 e∈S∖T,都存在至少一条边 f∈T∖S,使得 S−e+f 和 T−f+e 都是图 G 的生成树。
设
证明 e=(u,v),且 P 是树 T 中连接 u 和 v 的唯一一条路径。因为 P+e 是图 T+e 中唯一的环路,所以删掉 P 中的任何一条边 f 都可以使得 T−f+e 是一棵生成树。与此同时,图 S−e 是有两个连通分量的森林,它们的顶点集分别记作 V1 和 V2,所以,只要选择边 f∈P 使得 f 连通了 V1 和 V2,就能保证 S−e+f 是一棵生成树。这样的边 f 总是存在的,因为 u 和 v 分别属于 V1 和 V2,而 P 连接了 u 和 v。而且,f∈/S,因为图 S−e 中,V1 和 V2 并不是连通的。这就完成了证明。
#include <algorithm>#include <iostream>#include <tuple>#include <type_traits>#include <vector>// Golden section search on integer domain (unimodal function)template <typename T, typename F>typename std::enable_if< std::is_integral<T>::value, std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::typegolden_section_search(T ll, T rr, F func) { constexpr long double phi = 0.618033988749894848204586834L; T ml = ll + static_cast<T>((rr - ll) * (1 - phi)); T mr = ll + static_cast<T>((rr - ll) * phi); auto fl = func(ml), fr = func(mr); while (ml < mr) { if (fl > fr) { rr = mr; mr = ml; fr = fl; ml = ll + static_cast<T>((rr - ll) * (1 - phi)); fl = func(ml); } else { ll = ml; ml = mr; fl = fr; mr = ll + static_cast<T>((rr - ll) * phi); fr = func(mr); } } T best_x = ll; auto best_val = func(ll); for (T i = ll + 1; i <= rr; ++i) { auto val = func(i); if (val > best_val) { best_val = val; best_x = i; } } return {best_x, best_val};}int main() { int n; std::cin >> n; std::vector<int> a(n + 1); for (int i = 1; i <= n; ++i) std::cin >> a[i]; for (int i = n; i >= 1; --i) a[i] -= a[i - 1]; long long v; std::cin >> v; // Cost of adding M more teleporters to a segment of length LEN. auto f = [&](int len, int m) -> long long { long long rem = len % (m + 1); int q = len / (m + 1); return (m + 1 - rem) * q * q + rem * (q + 1) * (q + 1); }; // Calculate h(k) = min_x f(x) - k * g(x). auto calc = [&](long long k) -> long long { long long res = 0; for (int i = 1; i <= n; ++i) { res += -golden_section_search(0, a[i], [&](int m) -> long long { return -f(a[i], m) + m * k; }).second; } return res; }; // Find the smallest k such that h(k) + k * m <= v. long long ll = -(1LL << 30), rr = 0, ti = 0; while (ll <= rr) { auto mm = ll + (rr - ll) / 2; auto fi = calc(mm); auto ub = fi - calc(mm + 1); if (fi + ub * mm <= v) { ti = mm; rr = mm - 1; } else { ll = mm + 1; } } std::cout << (int)((calc(ti) - v - 1 - ti) / (-ti)) << std::endl; return 0;}
方法二
代码仅做示意,由于浮点数精度问题无法通过原题数据范围。
#include <algorithm>#include <cmath>#include <iostream>#include <tuple>#include <type_traits>#include <vector>// Golden section search on integer domain (unimodal function)template <typename T, typename F>typename std::enable_if< std::is_integral<T>::value, std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::typegolden_section_search(T ll, T rr, F func) { constexpr long double phi = 0.618033988749894848204586834L; T ml = ll + static_cast<T>((rr - ll) * (1 - phi)); T mr = ll + static_cast<T>((rr - ll) * phi); auto fl = func(ml), fr = func(mr); while (ml < mr) { if (fl > fr) { rr = mr; mr = ml; fr = fl; ml = ll + static_cast<T>((rr - ll) * (1 - phi)); fl = func(ml); } else { ll = ml; ml = mr; fl = fr; mr = ll + static_cast<T>((rr - ll) * phi); fr = func(mr); } } T best_x = ll; auto best_val = func(ll); for (T i = ll + 1; i <= rr; ++i) { auto val = func(i); if (val > best_val) { best_val = val; best_x = i; } } return {best_x, best_val};}// Golden section search on floating-point domain (unimodal function)template <typename T, typename F>typename std::enable_if< std::is_floating_point<T>::value, std::pair<T, decltype(std::declval<F>()(std::declval<T>()))>>::typegolden_section_search(T ll, T rr, F func, T eps) { constexpr long double phi = 0.618033988749894848204586834L; T ml = ll + (rr - ll) * (1 - phi); T mr = ll + (rr - ll) * phi; auto fl = func(ml), fr = func(mr); while ((rr - ll) > eps) { if (fl > fr) { rr = mr; mr = ml; fr = fl; ml = ll + (rr - ll) * (1 - phi); fl = func(ml); } else { ll = ml; ml = mr; fl = fr; mr = ll + (rr - ll) * phi; fr = func(mr); } } T mid = (ll + rr) / 2; return {mid, func(mid)};}int main() { int n; std::cin >> n; std::vector<int> a(n + 1); for (int i = 1; i <= n; ++i) std::cin >> a[i]; for (int i = n; i >= 1; --i) a[i] -= a[i - 1]; long long v; std::cin >> v; // Cost of adding M more teleporters to a segment of length LEN. auto f = [&](int len, int m) -> long long { long long rem = len % (m + 1); int q = len / (m + 1); return (m + 1 - rem) * q * q + rem * (q + 1) * (q + 1); }; // Calculate h(k) = min_x f(x) - k * g(x). auto calc = [&](long double k) -> long double { long double res = 0; for (int i = 1; i <= n; ++i) { res += -golden_section_search(0, a[i], [&](int m) -> long double { return -m + k * f(a[i], m); }).second; } return res; }; // Solve the dual problem. auto res = golden_section_search( -1.0l, 0.0l, [&](long double k) -> long double { return calc(k) + k * v; }, 1e-12l) .second; std::cout << (int)ceill(res) << std::endl; return 0;}
定理中提供的条件看似比凸函数更强一些,但是,对于算法竞赛能够遇到的情形,特别是 X 为有限集合时,仅强调凸函数就已经足够。由离散集合上的正常凸函数 v 延拓而来的函数 v~ 必然是下半连续的凸函数,因为有限多个点的凸包必然是闭凸包,而所谓下半连续的凸函数,就等价于它的上境图是闭凸包。至于正常凸函数中的「正常」一词,只要 v(y) 在至少一个点处取得有限值且是凸函数,就可以保证。 ↩