diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index f20d9c9..2ff12ae 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -5,6 +5,9 @@ steps: version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext command: | julia --project -e ' println("--- :julia: Instantiating project") diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 7979b22..9c14db2 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -7,6 +7,9 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext agents: queue: "juliagpu" cuda: "*" @@ -27,6 +30,9 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1b306d2..9860756 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,6 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -51,6 +50,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v5 with: files: lcov.info @@ -60,7 +61,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -75,6 +75,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v5 with: files: lcov.info diff --git a/Project.toml b/Project.toml index e1406d5..4671de8 100644 --- a/Project.toml +++ b/Project.toml @@ -17,17 +17,24 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +[weakdeps] +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + +[extensions] +NeuralOperatorsReactantExt = "Reactant" + [compat] ArgCheck = "2.3" ChainRulesCore = "1.24" ConcreteStructs = "0.2.3" FFTW = "1.8" -Lux = "1" -LuxCore = "1" -LuxLib = "1.2" -MLDataDevices = "1.2.0" -NNlib = "0.9.21" +Lux = "1.2.1" +LuxCore = "1.1" +LuxLib = "1.3.7" +MLDataDevices = "1.5" +NNlib = "0.9.24" Random = "1.10" +Reactant = "0.2.5" Static = "1.1.1" WeightInitializers = "1" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 29b4d3c..aee1e1e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" @@ -11,6 +12,7 @@ NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -18,12 +20,14 @@ CairoMakie = "0.12.11" CondaPkg = "0.2.23" DataDeps = "0.7.13" Documenter = "1.7.0" -Lux = "1" +Enzyme = "0.13.24" +Lux = "1.2.1" LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" NeuralOperators = "0.5" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Printf = "1.10" PythonCall = "0.9.23" +Reactant = "0.2.11" Zygote = "0.6.71" diff --git a/docs/pages.jl b/docs/pages.jl index 2c9c8a4..e0fbea1 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -6,6 +6,7 @@ pages = [ "NOMAD" => "models/nomad.md" ], "Tutorials" => [ + "XLA Compilation" => "tutorials/reactant.md", "Burgers Equation" => "tutorials/burgers.md" ], "API Reference" => "api.md" diff --git a/docs/src/tutorials/reactant.md b/docs/src/tutorials/reactant.md new file mode 100644 index 0000000..8077c22 --- /dev/null +++ b/docs/src/tutorials/reactant.md @@ -0,0 +1,60 @@ +# Compiling NeuralOperators.jl using Reactant.jl + +```@example reactant +using NeuralOperators, Lux, Random, Enzyme, Reactant + +function sumabs2first(model, ps, st, x) + z, _ = model(x, ps, st) + return sum(abs2, z) +end + +dev = reactant_device() +``` + +## Compiling DeepONet + +```@example reactant +deeponet = DeepONet() +ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev; + +u = rand(Float32, 64, 32) |> dev; +y = rand(Float32, 1, 128, 32) |> dev; +nothing # hide + +@jit deeponet((u, y), ps, st) +``` + +Computing the gradient of the DeepONet model. + +```@example reactant +function ∇deeponet(model, ps, st, (u, y)) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const((u, y)) + ) +end + +@jit ∇deeponet(deeponet, ps, st, (u, y)) +``` + +## Compiling FourierNeuralOperator + +```@example reactant +fno = FourierNeuralOperator() +ps, st = Lux.setup(Random.default_rng(), fno) |> dev; + +x = rand(Float32, 2, 32, 5) |> dev; + +@jit fno(x, ps, st) +``` + +Computing the gradient of the FourierNeuralOperator model. + +```@example reactant +function ∇fno(model, ps, st, x) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(x) + ) +end + +@jit ∇fno(fno, ps, st, x) +``` diff --git a/ext/NeuralOperatorsReactantExt.jl b/ext/NeuralOperatorsReactantExt.jl new file mode 100644 index 0000000..d3665aa --- /dev/null +++ b/ext/NeuralOperatorsReactantExt.jl @@ -0,0 +1,31 @@ +module NeuralOperatorsReactantExt + +using FFTW: FFTW +using NeuralOperators: NeuralOperators, FourierTransform +using NNlib: NNlib +using Reactant: Reactant, TracedRArray, AnyTracedRArray + +# XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed +function NeuralOperators.transform( + ft::FourierTransform, x::AnyTracedRArray{T, N}) where {T, N} + x_c = Reactant.TracedUtils.promote_to( + TracedRArray{Complex{T}, N}, + Reactant.TracedUtils.materialize_traced_array(x) + ) + return FFTW.fft(x_c, 1:ndims(ft)) +end + +function NeuralOperators.inverse( + ft::FourierTransform, x::AnyTracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N} + return real(FFTW.ifft(x, 1:ndims(ft))) +end + +function NeuralOperators.fast_pad_zeros(x::AnyTracedRArray, pad_dims) + return NNlib.pad_zeros( + Reactant.TracedUtils.materialize_traced_array(x), + NeuralOperators.expand_pad_dims(pad_dims); + dims=ntuple(identity, ndims(x) - 2) + ) +end + +end diff --git a/src/layers.jl b/src/layers.jl index dff7b13..9a38630 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -76,8 +76,7 @@ function operator_conv(x, tform::AbstractTransform, weights) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false; - dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p) + x_padded = fast_pad_zeros(x_p, pad_dims) return inverse(tform, x_padded, size(x)) end diff --git a/src/utils.jl b/src/utils.jl index 459a1b4..2cca4b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -51,3 +51,8 @@ function ∇safe_batched_adjoint( ::Type{<:AbstractGPUDevice}, Δ::AbstractArray{T, 3}) where {T} return NoTangent(), stack(adjoint, eachslice(Δ; dims=3)) end + +function fast_pad_zeros(x, pad_dims)::typeof(x) + return NNlib.pad_zeros( + x, expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2)) +end diff --git a/test/Project.toml b/test/Project.toml index 8890976..b863b98 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ LuxCore = "1" LuxLib = "1.2" LuxTestUtils = "1.1.2" MLDataDevices = "1" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Pkg = "1.10" Preferences = "1" Random = "1.10"