Skip to content

Commit

Permalink
docs: rename to reactant compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 10, 2024
1 parent 55fc1fe commit 58b32f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 27 deletions.
2 changes: 1 addition & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pages = [
"NOMAD" => "models/nomad.md"
],
"Tutorials" => [
"XLA Compilation" => "tutorials/xla_compilation.md",
"XLA Compilation" => "tutorials/reactant.md",
"Burgers Equation" => "tutorials/burgers.md"
],
"API Reference" => "api.md"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Compiling NeuralOperators.jl using Reactant.jl

```@example xla_compilation
```@example reactant
using NeuralOperators, Lux, Random, Enzyme, Reactant
function sumabs2first(model, ps, st, x)
Expand All @@ -13,12 +13,12 @@ dev = reactant_device()

## Compiling DeepONet

```@example xla_compilation
```@example reactant
deeponet = DeepONet()
ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev;
u = rand(Float32, 64, 1024) |> dev;
y = rand(Float32, 1, 128, 1024) |> dev;
u = rand(Float32, 64, 32) |> dev;
y = rand(Float32, 1, 128, 32) |> dev;
nothing # hide
deeponet_compiled = @compile deeponet((u, y), ps, st)
Expand All @@ -27,18 +27,11 @@ deeponet_compiled((u, y), ps, st)[1]

Computing the gradient of the DeepONet model.

```@example xla_compilation
```@example reactant
function ∇deeponet(model, ps, st, (u, y))
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(
Enzyme.Reverse,
sumabs2first,
Const(model),
Duplicated(ps, dps),
Const(st),
Const((u, y))
return Enzyme.gradient(
Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const((u, y))
)
return dps
end
∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y))
Expand All @@ -47,30 +40,23 @@ end

## Compiling FourierNeuralOperator

```@example xla_compilation
```@example reactant
fno = FourierNeuralOperator()
ps, st = Lux.setup(Random.default_rng(), fno) |> dev;
x = rand(Float32, 2, 1024, 5) |> dev;
x = rand(Float32, 2, 32, 5) |> dev;
fno_compiled = @compile fno(x, ps, st)
fno_compiled(x, ps, st)[1]
```

Computing the gradient of the FourierNeuralOperator model.

```@example xla_compilation
```@example reactant
function ∇fno(model, ps, st, x)
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(
Enzyme.Reverse,
sumabs2first,
Const(model),
Duplicated(ps, dps),
Const(st),
Const(x)
return Enzyme.gradient(
Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(x)
)
return dps
end
∇fno_compiled = @compile ∇fno(fno, ps, st, x)
Expand Down

0 comments on commit 58b32f2

Please sign in to comment.