@@ -17,6 +17,9 @@ using Zygote: Zygote
17
17
@inline sumabs2 (model, x, p, st) = sum (abs2, first (Lux. apply (model, x, p, st)))
18
18
@inline sumabs2 (model, x) = sum (abs2, model (x))
19
19
20
+ @inline sumcos (x) = sum (cos, x)
21
+ @inline ∇sumcos (x) = Enzyme. gradient (Reverse, sumcos, x)
22
+
20
23
function benchmark_group_to_backend (benchmark_group:: String )
21
24
benchmark_group == " CPU" && return CPUDevice ()
22
25
benchmark_group == " CUDA" && return CUDADevice ()
34
37
function setup_benchmarks! (suite:: BenchmarkGroup , backend:: String )
35
38
dev = benchmark_group_to_backend (backend)
36
39
40
+ # Simple Benchmarks
41
+ setup_simple_benchmark! (suite, backend)
42
+
43
+ # Add Lux Benchmarks
37
44
setup_vit_benchmark! (suite, backend, dev)
45
+ setup_vgg_benchmark! (suite, backend, dev)
46
+
47
+ return nothing
48
+ end
49
+
50
+ # Some Simple Benchmarks
51
+ function setup_simple_benchmark! (suite:: BenchmarkGroup , backend)
52
+ for opt_pass in (:all , :only_enzyme , :after_enzyme , :before_enzyme )
53
+ tag = opt_pass == :all ? " Reactant" : " Reactant (optimize = $(Meta. quot (opt_pass)) )"
54
+
55
+ suite[" (Basics) 2D sum (2 x 10)" ][" forward (compilation)" ][backend][tag] = @benchmarkable begin
56
+ @compile optimize = $ (opt_pass) sum (x)
57
+ end setup = begin
58
+ x = Reactant. ConcreteRArray (ones (2 , 10 ))
59
+ end
60
+
61
+ suite[" (Basics) sum(cos, x) (2 x 10)" ][" forward (compilation)" ][backend][tag] = @benchmarkable begin
62
+ @compile optimize = $ (opt_pass) sumcos (x)
63
+ end setup = begin
64
+ x = Reactant. ConcreteRArray (ones (2 , 10 ))
65
+ end
66
+ end
67
+
68
+ suite[" Basics ∇sumcos (2 x 10)" ][" forward (compilation)" ][backend][" Reactant" ] = @benchmarkable begin
69
+ @compile optimize = :all ∇sumcos (x)
70
+ end setup = begin
71
+ x = Reactant. ConcreteRArray (ones (2 , 10 ))
72
+ end
38
73
39
74
return nothing
40
75
end
@@ -50,6 +85,20 @@ function setup_vit_benchmark!(suite::BenchmarkGroup, backend, dev::AbstractDevic
50
85
end
51
86
end
52
87
88
+ function setup_vgg_benchmark! (suite:: BenchmarkGroup , backend, dev:: AbstractDevice )
89
+ for depth in (11 , 13 , 16 , 19 ), bsize in (4 , 16 , 32 ), batchnorm in (false , true )
90
+ benchmark_name = " VGG$(depth) bn=$(batchnorm) (224 x 224 x 3 x $(bsize) )"
91
+ setup_lux_forward_pass_benchmark! (
92
+ suite,
93
+ benchmark_name,
94
+ backend,
95
+ Vision. VGG (depth; pretrained= false , batchnorm),
96
+ (224 , 224 , 3 , bsize),
97
+ dev,
98
+ )
99
+ end
100
+ end
101
+
53
102
function setup_lux_forward_pass_benchmark! (
54
103
suite:: BenchmarkGroup ,
55
104
benchmark_name:: String ,
@@ -89,6 +138,20 @@ function setup_lux_forward_pass_benchmark!(
89
138
GC. gc ()
90
139
reclaim ($ dev)
91
140
end
141
+
142
+ suite[benchmark_name][" forward (compilation)" ][backend][tag] = @benchmarkable begin
143
+ @compile optimize = $ (opt_pass) Lux. apply ($ model, x_ra, ps_ra, st_test_ra)
144
+ end setup = begin
145
+ GC. gc ()
146
+ reclaim ($ dev)
147
+ x, ps, st = general_lux_setup ($ model, $ x_dims)
148
+ st_test = Lux. testmode (st)
149
+ x_ra = Reactant. to_rarray (x)
150
+ ps_ra = Reactant. to_rarray (ps)
151
+ st_test_ra = Reactant. to_rarray (st_test)
152
+ GC. gc ()
153
+ reclaim ($ dev)
154
+ end
92
155
end
93
156
94
157
return nothing
0 commit comments