Skip to content

Commit 73a2e2c

Browse files
committed
add another test
1 parent 021ba44 commit 73a2e2c

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#define PROBLEM \
2+
"https://judge.yosupo.jp/problem/tree_path_composite_sum"
3+
#include "../template.hpp"
4+
#include "../../../library/trees/shallowest_decomp_tree.hpp"
5+
#include "../../../library/math/mod_int.hpp"
6+
using line = array<mint, 2>;
7+
// returns f(g(x)) = f[0]*(g[0]*x+g[1]) + f[1]
8+
line compose(line f, line g) {
9+
return {f[0] * g[0], f[1] + f[0] * g[1]};
10+
}
11+
int main() {
12+
cin.tie(0)->sync_with_stdio(0);
13+
int n;
14+
cin >> n;
15+
vector<mint> a(n);
16+
for (int i = 0; i < n; i++) cin >> a[i].x;
17+
vector<vector<int>> adj(n);
18+
vector<vector<line>> weight(n);
19+
for (int i = 0; i < n - 1; i++) {
20+
int u, v, b, c;
21+
cin >> u >> v >> b >> c;
22+
adj[u].push_back(v);
23+
adj[v].push_back(u);
24+
weight[u].push_back({b, c});
25+
weight[v].push_back({b, c});
26+
}
27+
vector<mint> res(n);
28+
for (int i = 0; i < n; i++) res[i] = a[i];
29+
shallowest(adj, [&](int cent) {
30+
assert(ssize(adj[cent]) == ssize(weight[cent]));
31+
mint total_sum_evaluated = 0;
32+
int total_cnt_nodes = 0;
33+
mint curr_sum_evaluated;
34+
int curr_cnt_nodes;
35+
auto dfs = [&](auto&& self, int v, int p,
36+
line downwards, line upwards,
37+
bool forwards) -> void {
38+
// f(x) + f(y) + f(z) = b*x+c + b*y+c + b*z+c =
39+
// b*(x+y+z) + c*3
40+
res[v] = res[v] + upwards[0] * total_sum_evaluated +
41+
upwards[1] * total_cnt_nodes;
42+
if (forwards) {
43+
res[v] =
44+
res[v] + upwards[0] * a[cent] + upwards[1];
45+
res[cent] =
46+
res[cent] + downwards[0] * a[v] + downwards[1];
47+
}
48+
curr_cnt_nodes++;
49+
curr_sum_evaluated = curr_sum_evaluated +
50+
downwards[0] * a[v] + downwards[1];
51+
for (int i = 0; i < ssize(adj[v]); i++) {
52+
int u = adj[v][i];
53+
line curr_line = weight[v][i];
54+
if (u != p) {
55+
self(self, u, v, compose(downwards, curr_line),
56+
compose(curr_line, upwards), forwards);
57+
}
58+
}
59+
};
60+
for (int i = 0; i < ssize(adj[cent]); i++) {
61+
curr_sum_evaluated = 0;
62+
curr_cnt_nodes = 0;
63+
dfs(dfs, adj[cent][i], cent, weight[cent][i],
64+
weight[cent][i], 1);
65+
total_sum_evaluated =
66+
total_sum_evaluated + curr_sum_evaluated;
67+
total_cnt_nodes += curr_cnt_nodes;
68+
}
69+
total_sum_evaluated = 0;
70+
total_cnt_nodes = 0;
71+
for (int i = ssize(adj[cent]) - 1; i >= 0; i--) {
72+
curr_sum_evaluated = 0;
73+
curr_cnt_nodes = 0;
74+
dfs(dfs, adj[cent][i], cent, weight[cent][i],
75+
weight[cent][i], 0);
76+
total_sum_evaluated =
77+
total_sum_evaluated + curr_sum_evaluated;
78+
total_cnt_nodes += curr_cnt_nodes;
79+
}
80+
for (int v : adj[cent]) {
81+
for (int i = 0; i < ssize(adj[v]); i++) {
82+
if (adj[v][i] == cent) {
83+
swap(weight[v][i], weight[v].back());
84+
weight[v].pop_back();
85+
break;
86+
}
87+
}
88+
}
89+
});
90+
for (int i = 0; i < n; i++) cout << res[i].x << ' ';
91+
cout << '\n';
92+
}

0 commit comments

Comments
 (0)