来源:
参考:
- 树形dp -(树上背包类) 的选课问题一节
题目描述
在大学里每个学生,为了达到一定的学分,必须从很多课程里选择一些课程来学习,在课程里有些课程必须在某些课程之前学习,如高等数学总是在其它课程之前学习。现在有 门功课,每门课有个学分,每门课有一门或没有直接先修课(若课程 是课程 的先修课即只有学完了课程 ,才能学习课程 )。一个学生要从这些课程里选择 门课程学习 1,问他能获得的最大学分是多少?
输入格式
第一行有两个整数 , 用空格隔开 , 。
接下来的 行,第 行包含两个整数 和 , 表示第 门课的直接先修课, 表示第 门课的学分。若 表示没有直接先修课 ,。
输出格式
只有一行,选 门课程的最大学分。
输入输出样例
输入 #1
7 4
2 2
0 1
0 4
2 1
7 1
7 6
2 2
输出 #1
13
思路
典型的树上背包问题.
有几点需要注意的是:
- 题目中 ” 一个学生要从这些课程里选择 门课程学习 ” 这句话隐藏了 的条件, 所以不用考虑 M>N 的数据的情况。
解法
朴素解法
// 解法1: 分组背包
// u - 当前节点
void dfs(int u, int M, const vector<Node>& nodes, vector<vector<int>>& dp) {
// 先将当前节点的学分放入背包
for (int i = 1; i <= M; ++i) {
dp[u][i] = nodes[u].score;
}
// 此方法中, 我们认为 dp[u][j] 表示在以 u 为根的子树中,必须选择 u,总共选择不超过 j 门课所能获得的最大学分
// 因此在这里我们初始化了 dp[u][1~M] 为 nodes[u].score
for (int s : nodes[u].sons) {
dfs(s, M, nodes, dp); // 先遍历子节点树, 确保 dp[s] ready
for (int j = M; j > 1; --j) { // 这里必须倒序, 因为是0-1背包
for (int k = 1; j - k >= 1; ++k) {
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[s][k]);
}
}
}
}
// M 表示选课程数量, k[i] 表示第 i 门课的直接先修课的编号, s[i] 表示第 i 门课的学分
// 编号从 1 开始, 对应课程数组中的第 i-1 个元素 0 表示没有先修课
int selectCourses_grouped_knapsack(int M, const vector<int>& k, const vector<int>& s) {
// 因为给的数据是一片森林(多个树), 不好操作, 先把所有无先修课的课程都设置一个公共的先修课--节点0
int N = k.size() + 1;
M++; // 因为额外加了一门虚拟节点0, 所以这里 M +1
vector<Node> nodes;
nodes.reserve(N);
nodes.emplace_back(0, -1); // 虚拟根节点
for (int i = 0; i < int(k.size()); ++i) {
nodes.emplace_back(s[i], k[i]); // 这里 k[i] 直接是 nodes 中的序号
}
for (int i = 1; i < (int)nodes.size(); ++i) {
nodes[nodes[i].parent].sons.emplace_back(i);
}
// dp数组,
vector<vector<int>> dp(N, vector<int>(M + 1, 0));
dfs(0, M, nodes, dp);
return dp[0][M];
}优化上下边界
在解法一中,我们为了处理某些节点树中节点数小于 M 的情况, 定义了 dp[u][j] 表示在以 u 为根的子树中,必须选择 u,总共选择不超过 j 门课所能获得的最大学分。
如果设节点树 u 的节点数为 n_u, dp[u][n_u+1~] 是等于 dp[u][n_u] , 最后的方案中肯定是没有选择超过 n_u 门课的情况, 所以 dp[u][n_u+1~] 实际是无意义的。
我们可以修改 dp[u][j] 的定义为在以 u 为根的子树中,必须选择 u,总共选择 j 门课所能获得的最大学分, 对于 j > n_u 的 dp[u][j], 其值是无意义的, 设为 0 即可。
基于此定义, 我们考虑单独的节点 u 时, 只需赋值 dp[u][1] = nodes[u].score 即可, dp[u][2~M] 保持为 0, 再按 0-1 背包的方式逐一加入 u 的所有子节点树, 每个子节点树只有 d[s][1~n_s] 的值是有意义的, 每次合并时也只会考虑 dp[s][1~n_s] 的值, 这样最终得到也只有 d[u][1~n_u] 是有意义的。
同时, 我们也可以缩小 j 和 k 的范围
// u - 当前节点
int dfs2(int u, int M, const vector<Node>& nodes, vector<vector<int>>& dp) {
int nu = 1;
// 先将当前节点放入背包
dp[u][1] = nodes[u].score;
for (int s : nodes[u].sons) {
int ns = dfs2(s, M, nodes, dp); // 先遍历子节点树, 确保 dp[s] ready
// 优化边界, 让 1<=j-k<=nu, 1<=k<=ns, j<= nu+ns 且 j <= M
for (int j = min(nu + ns, M); j > 1; --j) { // 这里必须倒序, 因为是0-1背包
int lim = min(ns, j - 1);
for (int k = max(1, j - nu); k <= lim; ++k) {
dp[u][j] = max(dp[u][j], dp[u][j - k] + dp[s][k]);
}
}
nu += ns;
}
return nu; // 返回当前树的节点数
}
// M 表示选课程数量, k[i] 表示第 i 门课的直接先修课的编号, s[i] 表示第 i 门课的学分
// 编号从 1 开始, 对应课程数组中的第 i-1 个元素 0 表示没有先修课
int selectCourses_grouped_knapsack2(int M, const vector<int>& k, const vector<int>& s) {
// 因为给的数据是一片森林(多个树), 不好操作, 先把所有无先修课的课程都设置一个公共的先修课--节点0
int N = k.size() + 1;
M++; // 因为额外加了一门虚拟节点0, 所以这里 M +1
vector<Node> nodes;
nodes.reserve(N);
nodes.emplace_back(0, -1); // 虚拟根节点
for (int i = 0; i < int(k.size()); ++i) {
nodes.emplace_back(s[i], k[i]); // 这里 k[i] 直接是 nodes 中的序号
}
for (int i = 1; i < (int)nodes.size(); ++i) {
nodes[nodes[i].parent].sons.emplace_back(i);
}
// dp数组,
vector<vector<int>> dp(N, vector<int>(M + 1, 0));
dfs2(0, M, nodes, dp);
return dp[0][M];
}dfs 序解法
// 解法3: dfs序解法
// 获取dfs后序遍历编号以及节点树大小
void dfs3(int u, const vector<Node>& nodes, vector<int>& awa, vector<int> siz) {
siz[u] = 1;
for (int s : nodes[u].sons) {
dfs3(s, nodes, awa, siz);
siz[u] += siz[s];
}
awa.emplace_back(u);
}
int selectCourses_dfs(int M, const vector<int>& k, const vector<int>& s) {
int N = k.size() + 1;
M++;
vector<Node> nodes;
nodes.reserve(N);
nodes.emplace_back(0, -1); // 虚拟根节点
for (int i = 0; i < int(k.size()); ++i) {
nodes.emplace_back(s[i], k[i]); // 这里 k[i] 直接是 nodes 中的序号
}
for (int i = 1; i < (int)nodes.size(); ++i) {
nodes[nodes[i].parent].sons.emplace_back(i);
}
vector<int> siz(N); // 树大小
vector<int> awa; // 后序遍历序号
awa.reserve(N);
dfs3(0, nodes, awa, siz);
// dp数组, f[i][j] 表示前i个节点中选j个的最大值
vector<vector<int>> f(N + 1, vector<int>(M + 1, 0));
for (int i = 1; i <= N; ++i) {
for (int j = M; j > 0; --j) {
f[i][j] = max(f[i - 1][j - 1] + nodes[awa[i - 1]].score, f[i - siz[awa[i - 1]]][j]);
}
}
// 注: awa索引从0开始,所以这里需i-1
return f[N][M];
}多叉转二叉树解法
// 解法4: 多叉转二叉树解法
void conv(int u, vector<Node>& nodes) {
int pre = -1;
for (int s : nodes[u].sons) {
if (pre == -1) {
nodes[u].left = s;
} else {
nodes[pre].right = s;
}
pre = s;
conv(s, nodes);
}
}
void cntsiz(int u, vector<Node>& nodes) {
if (nodes[u].left != -1) {
cntsiz(nodes[u].left, nodes);
nodes[u].siz += nodes[nodes[u].left].siz;
}
if (nodes[u].right != -1) {
cntsiz(nodes[u].right, nodes);
nodes[u].siz += nodes[nodes[u].right].siz;
}
}
// 二叉树遍历
void dfs4(int u, int M, const vector<Node>& nodes, vector<vector<int>>& dp) {
int l = nodes[u].left, r = nodes[u].right;
int lsiz = 0, rsiz = 0;
if (l != -1) {
lsiz = nodes[l].siz;
dfs4(l, M, nodes, dp);
}
if (r != -1) {
rsiz = nodes[r].siz;
dfs4(r, M, nodes, dp);
}
int ilim = min(M, nodes[u].siz);
for (int i = 1; i <= ilim; ++i) {
if (l != -1 && r != -1) { // 两个子树都存在
dp[u][i] = dp[r][i]; //不选当前节点的情况
// 注:不需要 dp[r][min(i, rsiz)],因为如果 i > rsiz,说明右子树已经分配满了,
// 后续的 max 会处理,这句话就相当于没用了
// 选当前节点的情况,i-1个节点从左右子树中挑
// 这两个边界的核心是保证 dp[l][j] 的 j 在 [0, lsiz] 范围内
// 和 dp[r][i-1-j] 的 i-1-j 在 [0, rsiz] 范围内
// 超出范围的情况是无意义的
int lj = max(0, i - 1 - rsiz);
int rj = min(i - 1, lsiz);
for (int j = lj; j <= rj; ++j) {
// 左边 j个,右边 i-1-j个
dp[u][i] = max(dp[u][i], dp[l][j] + dp[r][i - 1 - j] + nodes[u].score);
}
} else if (l != -1) { // 只有左子树
dp[u][i] = dp[l][i - 1] + nodes[u].score; // 必须选当前节点才能选左子树
} else if (r != -1) { // 只有右子树
dp[u][i] = max(dp[r][i - 1] + nodes[u].score, dp[r][i]);
} else {
dp[u][i] = nodes[u].score; // 只有当前节点
}
}
}
int selectCourses_binary_tree(int M, const vector<int>& k, const vector<int>& s) {
int N = k.size() + 1;
M++;
vector<Node> nodes;
nodes.reserve(N);
nodes.emplace_back(0, -1); // 虚拟根节点
for (int i = 0; i < int(k.size()); ++i) {
nodes.emplace_back(s[i], k[i]); // 这里 k[i] 直接是 nodes 中的序号
}
for (int i = 1; i < (int)nodes.size(); ++i) {
nodes[nodes[i].parent].sons.emplace_back(i);
}
conv(0, nodes);
cntsiz(0, nodes);
vector<vector<int>> dp(N, vector<int>(M + 1, 0));
dfs4(0, M, nodes, dp);
return dp[0][M];
}Error
写时犯了一个错误,企图 1 次 dfs 完成树的转换和大小计算:
void conv(int pos) { siz[pos] = 1; int prei = -1; for (auto i : G[pos]) { if (prei == -1) { lc[pos] = i; } else { rc[prei] = i; } prei = i; conv(i); } // 企图1次dfs完成树的转换和大小计算 if (lc[pos] != NO) { siz[pos] += siz[lc[pos]]; } if (rc[pos] != NO) { // 但 conv 到 pos 时, rc[pos] 还没被赋值 siz[pos] += siz[rc[pos]]; } }这种写法是为了一次 dfs 就完成二叉树的转换以及树大小的计算,但因为在转换过程中 rc[pos] 还没有被赋值,所以会导致 siz[pos] 的计算错误。所以只能先转换树,再计算大小。
泛化物品求并解法
详见 2.4 泛化物品求并解法
// 泛化物品求并
// dep 表示从虚拟根节点到当前节点的路径上真实课程的数量
void dfs5(int u, int M, int dep, const vector<Node>& nodes, vector<vector<int>>& dp) {
for (int s : nodes[u].sons) {
// 遍历子节点 s
for (int j = dep + 1; j <= M; ++j) {
dp[s][j] = dp[u][j - 1] + nodes[s].score;
}
// 递归遍历 s 子树的所有节点
dfs5(s, M, dep + 1, nodes, dp);
// 此时 dp[s] 表示选s及u的方案, dp[u] 表示选u不选s的方案,两者仅可取其一
for (int j = dep + 1; j <= M; ++j) {
dp[u][j] = max(dp[u][j], dp[s][j]);
}
}
}
int selectCourses_generalized_items(int M, const vector<int>& k, const vector<int>& s) {
int N = k.size() + 1;
vector<Node> nodes;
nodes.reserve(N);
nodes.emplace_back(0, -1); // 虚拟根节点
for (int i = 0; i < int(k.size()); ++i) {
nodes.emplace_back(s[i], k[i]); // 这里 k[i] 直接是 nodes 中的序号
}
for (int i = 1; i < (int)nodes.size(); ++i) {
nodes[nodes[i].parent].sons.emplace_back(i);
}
// dp[u][j] 表示遍历到 u 节点时,选择 j 真实门课且必选 u 的最大学分
vector<vector<int>> dp(N, vector<int>(M + 1, 0));
dfs5(0, M, 0, nodes, dp);
return dp[0][M]; // 注意这里的 M 没有 +1, 因为 dp[u][j] 里的 j 表示的是选的真实课程数
}结果对比
Select Courses Problem
TestCase: (M, k, s, Expected)
---
case 1/2: (4, [2, 0, 0, 2, 7, 7, 2], [2, 1, 4, 1, 1, 6, 2], 13)
selectCourses_grouped_knapsack: 13 (Time: 0.0038 ms) ✔
selectCourses_grouped_knapsack2: 13 (Time: 0.0046 ms) ✔
selectCourses_dfs: 13 (Time: 0.0039 ms) ✔
selectCourses_binary_tree: 13 (Time: 0.0044 ms) ✔
selectCourses_generalized_items: 13 (Time: 0.0026 ms) ✔
---
case 2/2: (4, [0, 1, 2, 3, 4, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8], 22)
selectCourses_grouped_knapsack: 22 (Time: 0.0036 ms) ✔
selectCourses_grouped_knapsack2: 22 (Time: 0.0037 ms) ✔
selectCourses_dfs: 22 (Time: 0.0041 ms) ✔
selectCourses_binary_tree: 22 (Time: 0.0036 ms) ✔
selectCourses_generalized_items: 22 (Time: 0.0031 ms) ✔
测试用例数据量太小, 体现不出差距.
Footnotes
-
这里隐含了 M≤N 的条件 ↩