Skip to content

Commit

Permalink
perf: add HNN (2nd order AD) benchmarks
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
avik-pal committed Jan 11, 2025
1 parent 599ab8d commit f0101da
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
15 changes: 15 additions & 0 deletions perf/HNN/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Packages available to the HNN benchmark environment (name = package UUID).
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

# Reactant is taken from a local path instead of a registry; "../.." is two
# directories above this file (perf/HNN/Project.toml).
[sources]
Reactant = {path = "../.."}
167 changes: 167 additions & 0 deletions perf/HNN/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
using Lux,
Random,
Reactant,
Enzyme,
Zygote,
BenchmarkTools,
LuxCUDA,
DataFrames,
OrderedCollections,
CSV,
Comonicon

"""
    HamiltonianNN{E}(model)

Wrapper layer around `model` (stored in the `model` field, which is also the
field named in `AbstractLuxWrapperLayer{:model}`).

The type parameter `E` is not stored as data; it only selects which call method
is dispatched. In this file `E == true` runs the inner gradient through Enzyme
and `E == false` runs it through Zygote. `M` is the concrete type of `model`
and is inferred by the constructor.
"""
struct HamiltonianNN{E,M} <: AbstractLuxWrapperLayer{:model}
    model::M

    # Only `E` has to be spelled out by the caller; `M` comes from the argument.
    function HamiltonianNN{E}(model::M) where {E,M}
        return new{E,M}(model)
    end
end

# Forward pass of the Zygote-backed variant (`E == false`).
#
# Differentiates the summed output of the wrapped model with respect to `x`
# using Zygote, then swaps the two halves of that gradient along the
# second-to-last dimension of `x`.
#
# Arguments:
#   x  - input array; its second-to-last dimension is split into two halves of
#        size `n` (integer division, so an odd trailing element is dropped)
#   ps - parameters of the wrapped model
#   st - states of the wrapped model
#
# Returns `(y, st)`, where `y` holds the gradient with its two halves swapped
# and `st` is the state carried by the stateful wrapper after the call.
function (hnn::HamiltonianNN{false})(x::AbstractArray, ps, st)
    model = StatefulLuxLayer{true}(hnn.model, ps, st)
    # `sum ∘ model` composes the model with `sum`, giving the scalar-valued
    # function that `Zygote.gradient` requires.
    ∂x = only(Zygote.gradient(sum ∘ model, x))
    n = size(x, ndims(x) - 1) ÷ 2
    # Second half of the gradient goes first, first half goes second.
    y = cat(
        selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
        selectdim(∂x, ndims(∂x) - 1, 1:n);
        dims=Val(ndims(∂x) - 1),
    )
    return y, model.st
end

# Forward pass of the Enzyme-backed variant (`E == true`).
#
# Differentiates the summed output of the wrapped model with respect to `x`
# using reverse-mode Enzyme, then swaps the two halves of that gradient along
# the second-to-last dimension of `x`.
#
# Arguments:
#   x  - input array; its second-to-last dimension is split into two halves of
#        size `n` (integer division, so an odd trailing element is dropped)
#   ps - parameters of the wrapped model
#   st - states of the wrapped model
#
# Returns `(y, st)`, where `y` holds the gradient with its two halves swapped
# and `st` is the state carried by the stateful wrapper after the call.
function (hnn::HamiltonianNN{true})(x::AbstractArray, ps, st)
    # Reverse-mode Enzyme accumulates into the shadow of a `Duplicated`
    # argument, so the shadow must start out as zeros, not as uninitialised
    # memory (which is what `similar` returns for ordinary arrays).
    ∂x = Enzyme.make_zero(x)
    model = StatefulLuxLayer{true}(hnn.model, ps, st)
    # `sum ∘ model` composes the model with `sum`, giving a scalar-valued
    # function to differentiate.
    Enzyme.autodiff(Reverse, Const(sum ∘ model), Duplicated(x, ∂x))
    n = size(x, ndims(x) - 1) ÷ 2
    # Second half of the gradient goes first, first half goes second.
    y = cat(
        selectdim(∂x, ndims(∂x) - 1, (n + 1):(2n)),
        selectdim(∂x, ndims(∂x) - 1, 1:n);
        dims=Val(ndims(∂x) - 1),
    )
    return y, model.st
end

# Mean-squared-error loss between the model output for `x` and the target `y`.
#
# `model(x, ps, st)` returns a pair; only its first element (the prediction) is
# used here, the second element is discarded.
function loss_fn(model, ps, st, x, y)
    prediction = first(model(x, ps, st))
    criterion = MSELoss()
    return criterion(prediction, y)
end

# Gradient of `loss_fn` computed with Zygote.
#
# Zygote returns one gradient per positional argument of `loss_fn`
# (model, ps, st, x, y); only the entries for `ps` and `x` are returned.
function ∇zygote_loss_fn(model, ps, st, x, y)
    grads = Zygote.gradient(loss_fn, model, ps, st, x, y)
    ∂ps = grads[2]
    ∂x = grads[4]
    return ∂ps, ∂x
end

# Gradient of `loss_fn` computed with reverse-mode Enzyme.
#
# `model`, `st` and `y` are wrapped in `Const`, so they are not differentiated;
# only the entries for `ps` (position 2) and `x` (position 4) are returned.
function ∇enzyme_loss_fn(model, ps, st, x, y)
    grads = Enzyme.gradient(
        Reverse, loss_fn, Const(model), ps, Const(st), x, Const(y)
    )
    ∂ps = grads[2]
    ∂x = grads[4]
    return ∂ps, ∂x
end

# Free memory between benchmark samples.
#
# `CUDA.reclaim()` is called only for the non-Reactant run on the "gpu"
# backend; a full garbage collection is always triggered afterwards.
# Always returns `nothing`.
function reclaim_fn(backend, reactant)
    (backend == "gpu" && !reactant) && CUDA.reclaim()
    GC.gc(true)
    return nothing
end

# Command-line entry point for the HNN benchmark.
#
# Times the forward loss and its gradient for a Hamiltonian-style network in
# two configurations:
#   * "vanilla": Zygote-backed model on the selected Lux device
#   * Reactant:  Enzyme-backed model compiled with `@compile`
# The timings and their ratios are displayed and written to
# `results_<backend>.csv` next to this file.
#
# Keyword `backend` must be "cpu" or "gpu"; any other value throws an
# `ArgumentError`.
Comonicon.@main function main(; backend::String="gpu")
    # Explicit check rather than `@assert`: assertions may be disabled, and
    # this validates user-supplied input.
    backend in ("cpu", "gpu") || throw(
        ArgumentError("backend must be \"cpu\" or \"gpu\", got $(repr(backend))")
    )

    Reactant.set_default_backend(backend)
    filename = joinpath(@__DIR__, "results_$(backend).csv")

    @info "Using backend" backend

    cdev = cpu_device()
    gdev = backend == "gpu" ? gpu_device(; force=true) : cdev
    xdev = reactant_device(; force=true)

    # Typed, empty columns so the table does not fall back to `Vector{Any}`.
    df = DataFrame(
        OrderedDict(
            "Kind" => String[],
            "Fwd Vanilla" => Float64[],
            "Fwd Reactant" => Float64[],
            "Fwd Reactant SpeedUp" => Float64[],
            "Bwd Zygote" => Float64[],
            "Bwd Reactant" => Float64[],
            "Bwd Reactant SpeedUp" => Float64[],
        ),
    )

    mlp = Chain(
        Dense(32, 128, gelu),
        Dense(128, 128, gelu),
        Dense(128, 128, gelu),
        Dense(128, 128, gelu),
        Dense(128, 1),
    )

    # Both wrappers share the same MLP; they differ only in the AD backend
    # used for the inner gradient.
    model_enz = HamiltonianNN{true}(mlp)
    model_zyg = HamiltonianNN{false}(mlp)

    ps, st = Lux.setup(Random.default_rng(), model_enz)

    x = randn(Float32, 32, 1024)
    y = randn(Float32, 32, 1024)

    x_gdev = gdev(x)
    y_gdev = gdev(y)
    x_xdev = xdev(x)
    y_xdev = xdev(y)

    ps_gdev, st_gdev = gdev((ps, st))
    ps_xdev, st_xdev = xdev((ps, st))

    @info "Compiling Forward Functions"
    lfn_compiled = @compile sync = true loss_fn(
        model_enz, ps_xdev, st_xdev, x_xdev, y_xdev
    )

    @info "Running Forward Benchmarks"

    # NOTE(review): `CUDA.@sync` is used even when backend == "cpu" (where
    # `gdev` is the CPU device) — confirm this is fine without a usable GPU.
    t_gdev = @belapsed CUDA.@sync(
        loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
    ) setup = (reclaim_fn($backend, false))

    t_xdev = @belapsed $lfn_compiled(
        $model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
    ) setup = (reclaim_fn($backend, true))

    @info "Forward Benchmarks" t_gdev t_xdev

    @info "Compiling Backward Functions"
    grad_fn_compiled = @compile sync = true ∇enzyme_loss_fn(
        model_enz, ps_xdev, st_xdev, x_xdev, y_xdev
    )

    @info "Running Backward Benchmarks"

    t_rev_gdev = @belapsed CUDA.@sync(
        ∇zygote_loss_fn($model_zyg, $ps_gdev, $st_gdev, $x_gdev, $y_gdev)
    ) setup = (reclaim_fn($backend, false))

    t_rev_xdev = @belapsed $grad_fn_compiled(
        $model_enz, $ps_xdev, $st_xdev, $x_xdev, $y_xdev
    ) setup = (reclaim_fn($backend, true))

    @info "Backward Benchmarks" t_rev_gdev t_rev_xdev

    # Speed-up columns are (vanilla time) / (Reactant time).
    push!(
        df,
        [
            "HNN",
            t_gdev,
            t_xdev,
            t_gdev / t_xdev,
            t_rev_gdev,
            t_rev_xdev,
            t_rev_gdev / t_rev_xdev,
        ],
    )

    display(df)
    CSV.write(filename, df)

    @info "Results saved to $filename"
    return nothing
end
2 changes: 2 additions & 0 deletions perf/HNN/results_cpu.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
HNN,0.012209751,0.002101077,5.811186834180757,0.173089096,0.004597676,37.64708430955117
2 changes: 2 additions & 0 deletions perf/HNN/results_gpu.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Kind,Fwd Vanilla,Fwd Reactant,Fwd Reactant SpeedUp,Bwd Zygote,Bwd Reactant,Bwd Reactant SpeedUp
HNN,0.000681027,8.4721e-5,8.038467440186022,0.003330234,0.00012123,27.470378619153674
1 change: 1 addition & 0 deletions perf/KAN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ function reclaim_fn(backend, reactant)
CUDA.reclaim()
end
GC.gc(true)
return nothing
end

Comonicon.@main function main(; backend::String="gpu")
Expand Down

0 comments on commit f0101da

Please sign in to comment.