Skip to content

Commit 427581b

Browse files
authored
Merge branch 'main' into mhauru/varnamedvector-speed
2 parents 3b8b4a8 + 90b591b commit 427581b

File tree

2 files changed

+34
-16
lines changed

2 files changed

+34
-16
lines changed

.github/workflows/Benchmarking.yml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,19 @@ jobs:
3333
echo "$version_info" >> $GITHUB_ENV
3434
echo "EOF" >> $GITHUB_ENV
3535
36-
# Capture benchmark output into a variable
36+
# Capture benchmark output into a variable. The sed and tail calls cut out anything but the
37+
# final block of results.
3738
echo "Running Benchmarks..."
38-
benchmark_output=$(julia --project=benchmarks benchmarks/benchmarks.jl)
39-
39+
benchmark_output=$(\
40+
julia --project=benchmarks benchmarks/benchmarks.jl \
41+
| sed -n '/Final results:/,$p' \
42+
| tail -n +2\
43+
)
44+
4045
# Print benchmark results directly to the workflow log
4146
echo "Benchmark Results:"
4247
echo "$benchmark_output"
43-
48+
4449
# Set the benchmark output as an env var for later steps
4550
echo "BENCHMARK_OUTPUT<<EOF" >> $GITHUB_ENV
4651
echo "$benchmark_output" >> $GITHUB_ENV

benchmarks/benchmarks.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,27 @@ using StableRNGs: StableRNG
77

88
rng = StableRNG(23)
99

10+
function print_results(results_table)
11+
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
12+
header = [
13+
"Model",
14+
"Dim",
15+
"AD Backend",
16+
"VarInfo",
17+
"Linked",
18+
"t(eval)/t(ref)",
19+
"t(grad)/t(eval)",
20+
]
21+
return pretty_table(
22+
table_matrix;
23+
column_labels=header,
24+
backend=:text,
25+
formatters=[fmt__printf("%.1f", [6, 7])],
26+
fit_table_in_display_horizontally=false,
27+
fit_table_in_display_vertically=false,
28+
)
29+
end
30+
1031
# Create DynamicPPL.Model instances to run benchmarks on.
1132
smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))
1233
loop_univariate1k, multivariate1k = begin
@@ -84,17 +105,9 @@ for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinati
84105
relative_ad_eval_time,
85106
),
86107
)
108+
println("Results so far:")
109+
print_results(results_table)
87110
end
88111

89-
table_matrix = hcat(Iterators.map(collect, zip(results_table...))...)
90-
header = [
91-
"Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"
92-
]
93-
pretty_table(
94-
table_matrix;
95-
column_labels=header,
96-
backend=:text,
97-
formatters=[fmt__printf("%.1f", [6, 7])],
98-
fit_table_in_display_horizontally=false,
99-
fit_table_in_display_vertically=false,
100-
)
112+
println("Final results:")
113+
print_results(results_table)

0 commit comments

Comments
 (0)