|
| 1 | +#define PROBLEM \ |
| 2 | + "https://judge.yosupo.jp/problem/tree_path_composite_sum" |
| 3 | +#include "../template.hpp" |
| 4 | +#include "../../../library/trees/shallowest_decomp_tree.hpp" |
| 5 | +#include "../../../library/math/mod_int.hpp" |
| 6 | +using line = array<mint, 2>; |
| 7 | +// returns f(g(x)) = f[0]*(g[0]*x+g[1]) + f[1] |
| 8 | +line compose(line f, line g) { |
| 9 | + return {f[0] * g[0], f[1] + f[0] * g[1]}; |
| 10 | +} |
| 11 | +int main() { |
| 12 | + cin.tie(0)->sync_with_stdio(0); |
| 13 | + int n; |
| 14 | + cin >> n; |
| 15 | + vector<mint> a(n); |
| 16 | + for (int i = 0; i < n; i++) cin >> a[i].x; |
| 17 | + vector<vector<int>> adj(n); |
| 18 | + vector<vector<line>> weight(n); |
| 19 | + for (int i = 0; i < n - 1; i++) { |
| 20 | + int u, v, b, c; |
| 21 | + cin >> u >> v >> b >> c; |
| 22 | + adj[u].push_back(v); |
| 23 | + adj[v].push_back(u); |
| 24 | + weight[u].push_back({b, c}); |
| 25 | + weight[v].push_back({b, c}); |
| 26 | + } |
| 27 | + vector<mint> res(n); |
| 28 | + for (int i = 0; i < n; i++) res[i] = a[i]; |
| 29 | + shallowest(adj, [&](int cent) { |
| 30 | + assert(ssize(adj[cent]) == ssize(weight[cent])); |
| 31 | + mint total_sum_evaluated = 0; |
| 32 | + int total_cnt_nodes = 0; |
| 33 | + mint curr_sum_evaluated; |
| 34 | + int curr_cnt_nodes; |
| 35 | + auto dfs = [&](auto&& self, int v, int p, |
| 36 | + line downwards, line upwards, |
| 37 | + bool forwards) -> void { |
| 38 | + // f(x) + f(y) + f(z) = b*x+c + b*y+c + b*z+c = |
| 39 | + // b*(x+y+z) + c*3 |
| 40 | + res[v] = res[v] + upwards[0] * total_sum_evaluated + |
| 41 | + upwards[1] * total_cnt_nodes; |
| 42 | + if (forwards) { |
| 43 | + res[v] = |
| 44 | + res[v] + upwards[0] * a[cent] + upwards[1]; |
| 45 | + res[cent] = |
| 46 | + res[cent] + downwards[0] * a[v] + downwards[1]; |
| 47 | + } |
| 48 | + curr_cnt_nodes++; |
| 49 | + curr_sum_evaluated = curr_sum_evaluated + |
| 50 | + downwards[0] * a[v] + downwards[1]; |
| 51 | + for (int i = 0; i < ssize(adj[v]); i++) { |
| 52 | + int u = adj[v][i]; |
| 53 | + line curr_line = weight[v][i]; |
| 54 | + if (u != p) { |
| 55 | + self(self, u, v, compose(downwards, curr_line), |
| 56 | + compose(curr_line, upwards), forwards); |
| 57 | + } |
| 58 | + } |
| 59 | + }; |
| 60 | + for (int i = 0; i < ssize(adj[cent]); i++) { |
| 61 | + curr_sum_evaluated = 0; |
| 62 | + curr_cnt_nodes = 0; |
| 63 | + dfs(dfs, adj[cent][i], cent, weight[cent][i], |
| 64 | + weight[cent][i], 1); |
| 65 | + total_sum_evaluated = |
| 66 | + total_sum_evaluated + curr_sum_evaluated; |
| 67 | + total_cnt_nodes += curr_cnt_nodes; |
| 68 | + } |
| 69 | + total_sum_evaluated = 0; |
| 70 | + total_cnt_nodes = 0; |
| 71 | + for (int i = ssize(adj[cent]) - 1; i >= 0; i--) { |
| 72 | + curr_sum_evaluated = 0; |
| 73 | + curr_cnt_nodes = 0; |
| 74 | + dfs(dfs, adj[cent][i], cent, weight[cent][i], |
| 75 | + weight[cent][i], 0); |
| 76 | + total_sum_evaluated = |
| 77 | + total_sum_evaluated + curr_sum_evaluated; |
| 78 | + total_cnt_nodes += curr_cnt_nodes; |
| 79 | + } |
| 80 | + for (int v : adj[cent]) { |
| 81 | + for (int i = 0; i < ssize(adj[v]); i++) { |
| 82 | + if (adj[v][i] == cent) { |
| 83 | + swap(weight[v][i], weight[v].back()); |
| 84 | + weight[v].pop_back(); |
| 85 | + break; |
| 86 | + } |
| 87 | + } |
| 88 | + } |
| 89 | + }); |
| 90 | + for (int i = 0; i < n; i++) cout << res[i].x << ' '; |
| 91 | + cout << '\n'; |
| 92 | +} |
0 commit comments