diff --git a/.verify-helper/timestamps.remote.json b/.verify-helper/timestamps.remote.json index cf01aa5f..d69f62e4 100644 --- a/.verify-helper/timestamps.remote.json +++ b/.verify-helper/timestamps.remote.json @@ -144,6 +144,7 @@ "tests/library_checker_aizu_tests/trees/kth_path_tree_lift.test.cpp": "2025-08-14 10:27:46 -0600", "tests/library_checker_aizu_tests/trees/lca_all_methods_aizu.test.cpp": "2025-08-21 12:17:27 -0600", "tests/library_checker_aizu_tests/trees/lca_all_methods_lib_checker.test.cpp": "2025-08-21 12:17:27 -0600", -"tests/library_checker_aizu_tests/trees/shallowest_aizu_tree_height.test.cpp": "2025-08-30 00:58:07 -0600", +"tests/library_checker_aizu_tests/trees/shallowest_aizu_tree_height.test.cpp": "2025-09-03 10:28:33 -0600", +"tests/library_checker_aizu_tests/trees/shallowest_lib_checker_tree_path_composite.test.cpp": "2025-09-03 10:28:33 -0600", "tests/library_checker_aizu_tests/trees/subtree_isomorphism.test.cpp": "2025-08-14 10:27:46 -0600" } \ No newline at end of file diff --git a/library/trees/shallowest_decomp_tree.hpp b/library/trees/shallowest_decomp_tree.hpp index 7509491f..8256982c 100644 --- a/library/trees/shallowest_decomp_tree.hpp +++ b/library/trees/shallowest_decomp_tree.hpp @@ -21,7 +21,7 @@ void shallowest(auto& adj, auto f) { return dp; }; dfs(dfs, 0, 0); - for (vi vec : order | views::reverse) + for (const vi& vec : order | views::reverse) for (int v : vec) { f(v); for (int u : adj[v]) diff --git a/tests/library_checker_aizu_tests/trees/shallowest_lib_checker_tree_path_composite.test.cpp b/tests/library_checker_aizu_tests/trees/shallowest_lib_checker_tree_path_composite.test.cpp new file mode 100644 index 00000000..6c115054 --- /dev/null +++ b/tests/library_checker_aizu_tests/trees/shallowest_lib_checker_tree_path_composite.test.cpp @@ -0,0 +1,93 @@ +#define PROBLEM \ + "https://judge.yosupo.jp/problem/tree_path_composite_sum" +#undef _GLIBCXX_DEBUG +#include "../template.hpp" +#include "../../../library/trees/shallowest_decomp_tree.hpp" +#include "../../../library/math/mod_int.hpp" +using line = array; +// returns f(g(x)) = f[0]*(g[0]*x+g[1]) + f[1] +line compose(line f, line g) { + return {f[0] * g[0], f[1] + f[0] * g[1]}; +} +int main() { + cin.tie(0)->sync_with_stdio(0); + int n; + cin >> n; + vector a(n); + for (int i = 0; i < n; i++) cin >> a[i].x; + vector> adj(n); + vector> weight(n); + for (int i = 0; i < n - 1; i++) { + int u, v, b, c; + cin >> u >> v >> b >> c; + adj[u].push_back(v); + adj[v].push_back(u); + weight[u].push_back({b, c}); + weight[v].push_back({b, c}); + } + vector res(n); + for (int i = 0; i < n; i++) res[i] = a[i]; + shallowest(adj, [&](int cent) { + assert(ssize(adj[cent]) == ssize(weight[cent])); + mint total_sum_evaluated = 0; + int total_cnt_nodes = 0; + mint curr_sum_evaluated = 0; + int curr_cnt_nodes = 0; + auto dfs = [&](auto&& self, int v, int p, + line downwards, line upwards, + bool forwards) -> void { + // f(x) + f(y) + f(z) = b*x+c + b*y+c + b*z+c = + // b*(x+y+z) + c*3 + res[v] = res[v] + upwards[0] * total_sum_evaluated + + upwards[1] * total_cnt_nodes; + if (forwards) { + res[v] = + res[v] + upwards[0] * a[cent] + upwards[1]; + res[cent] = + res[cent] + downwards[0] * a[v] + downwards[1]; + } + curr_cnt_nodes++; + curr_sum_evaluated = curr_sum_evaluated + + downwards[0] * a[v] + downwards[1]; + for (int i = 0; i < ssize(adj[v]); i++) { + int u = adj[v][i]; + line curr_line = weight[v][i]; + if (u != p) { + self(self, u, v, compose(downwards, curr_line), + compose(curr_line, upwards), forwards); + } + } + }; + for (int i = 0; i < ssize(adj[cent]); i++) { + curr_sum_evaluated = 0; + curr_cnt_nodes = 0; + dfs(dfs, adj[cent][i], cent, weight[cent][i], + weight[cent][i], 1); + total_sum_evaluated = + total_sum_evaluated + curr_sum_evaluated; + total_cnt_nodes += curr_cnt_nodes; + } + total_sum_evaluated = 0; + total_cnt_nodes = 0; + for (int i = ssize(adj[cent]) - 1; i >= 0; i--) { + curr_sum_evaluated = 0; + curr_cnt_nodes = 0; + dfs(dfs, adj[cent][i], cent, weight[cent][i], + weight[cent][i], 0); + total_sum_evaluated = + total_sum_evaluated + curr_sum_evaluated; + total_cnt_nodes += curr_cnt_nodes; + } + for (int v : adj[cent]) { + for (int i = 0; i < ssize(adj[v]); i++) { + if (adj[v][i] == cent) { + swap(weight[v][i], weight[v].back()); + weight[v].pop_back(); + break; + } + } + } + }); + for (int i = 0; i < n; i++) cout << res[i].x << ' '; + cout << '\n'; +}