From 4a068550bf61554448cbd812a01cf00acc83cf07 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 3 Nov 2024 10:25:34 -0500
Subject: [PATCH 1/9] feat: compile neural operators using Reactant

---
 docs/Project.toml                     |  2 +
 docs/pages.jl                         |  1 +
 docs/src/tutorials/xla_compilation.md | 82 +++++++++++++++++++++++++++
 src/layers.jl                         |  2 +-
 4 files changed, 86 insertions(+), 1 deletion(-)
 create mode 100644 docs/src/tutorials/xla_compilation.md

diff --git a/docs/Project.toml b/docs/Project.toml
index 29b4d3c..c4ccdf2 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -3,6 +3,7 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
@@ -11,6 +12,7 @@ NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
diff --git a/docs/pages.jl b/docs/pages.jl
index 2c9c8a4..87a4d3c 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -6,6 +6,7 @@ pages = [
         "NOMAD" => "models/nomad.md"
     ],
     "Tutorials" => [
+        "XLA Compilation" => "tutorials/xla_compilation.md",
         "Burgers Equation" => "tutorials/burgers.md"
     ],
     "API Reference" => "api.md"
diff --git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/xla_compilation.md
new file mode 100644
index 0000000..b79b05b
--- /dev/null
+++ b/docs/src/tutorials/xla_compilation.md
@@ -0,0 +1,82 @@
+# Compiling NeuralOperators.jl using Reactant.jl
+
+```@example xla_compilation
+using NeuralOperators, Lux, Random, Enzyme, Reactant
+
+function sumabs2first(model, ps, st, (u, y))
+    z, _ = model((u, y), ps, st)
+    return sum(abs2, z)
+end
+
+if "gpu" in keys(Reactant.XLA.backends)
+    Reactant.set_default_backend("gpu")
+end
+
+dev = xla_device()
+```
+
+## Compiling DeepONet
+
+```@example xla_compilation
+deeponet = DeepONet()
+ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev
+
+u = rand(Float32, 64, 1024) |> dev
+y = rand(Float32, 1, 128, 1024) |> dev
+nothing # hide
+
+deeponet_compiled = @compile deeponet((u, y), ps, st)
+deeponet_compiled((u, y), ps, st)
+```
+
+Computing the gradient of the DeepONet model.
+
+```@example xla_compilation
+function ∇deeponet(model, ps, st, (u, y))
+    dps = Enzyme.make_zero(ps)
+    Enzyme.autodiff(
+        Enzyme.Reverse,
+        sumabs2first,
+        Const(model),
+        Duplicated(ps, dps),
+        Const(st),
+        Const((u, y))
+    )
+    return dps
+end
+
+∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y))
+∇deeponet_compiled(deeponet, ps, st, (u, y))
+```
+
+## Compiling FourierNeuralOperator
+
+```@example xla_compilation
+fno = FourierNeuralOperator()
+ps, st = Lux.setup(Random.default_rng(), fno) |> dev
+
+x = rand(Float32, 2, 1024, 5) |> dev
+
+fno_compiled = @compile fno(x, ps, st)
+fno_compiled(x, ps, st)
+```
+
+Computing the gradient of the FourierNeuralOperator model.
+ +```@example xla_compilation +function ∇fno(model, ps, st, x) + dps = Enzyme.make_zero(ps) + Enzyme.autodiff( + Enzyme.Reverse, + sumabs2first, + Const(model), + Duplicated(ps, dps), + Const(st), + Const(x) + ) + return dps +end + +∇fno_compiled = @compile ∇fno(fno, ps, st, x) +∇fno_compiled(fno, ps, st, x) +``` diff --git a/src/layers.jl b/src/layers.jl index dff7b13..1f0cbd0 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -76,7 +76,7 @@ function operator_conv(x, tform::AbstractTransform, weights) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_constant(x_p, expand_pad_dims(pad_dims), false; + x_padded = NNlib.pad_zeros(x_p, expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p) return inverse(tform, x_padded, size(x)) From 1610a7e72a91d730471fd669dd6ac017267fd191 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Nov 2024 14:37:00 -0500 Subject: [PATCH 2/9] fix: update versions --- Project.toml | 8 ++++---- docs/Project.toml | 2 +- docs/src/tutorials/xla_compilation.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index e1406d5..15d269a 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,10 @@ ArgCheck = "2.3" ChainRulesCore = "1.24" ConcreteStructs = "0.2.3" FFTW = "1.8" -Lux = "1" -LuxCore = "1" -LuxLib = "1.2" -MLDataDevices = "1.2.0" +Lux = "1.2.1" +LuxCore = "1.1" +LuxLib = "1.3.7" +MLDataDevices = "1.5" NNlib = "0.9.21" Random = "1.10" Static = "1.1.1" diff --git a/docs/Project.toml b/docs/Project.toml index c4ccdf2..7aaae9b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -20,7 +20,7 @@ CairoMakie = "0.12.11" CondaPkg = "0.2.23" DataDeps = "0.7.13" Documenter = "1.7.0" -Lux = "1" +Lux = "1.2.1" LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" diff --git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/xla_compilation.md index b79b05b..7080d05 100644 --- a/docs/src/tutorials/xla_compilation.md +++ b/docs/src/tutorials/xla_compilation.md @@ -12,7 +12,7 @@ if "gpu" in keys(Reactant.XLA.backends) Reactant.set_default_backend("gpu") end -dev = xla_device() +dev = reactant_device() ``` ## Compiling DeepONet From f92ebb24bcd993b800a6bfdb050b41bd01dde712 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Nov 2024 12:22:31 -0500 Subject: [PATCH 3/9] chore: remove selection code --- docs/src/tutorials/xla_compilation.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/xla_compilation.md index 7080d05..ed3b7b8 100644 --- a/docs/src/tutorials/xla_compilation.md +++ b/docs/src/tutorials/xla_compilation.md @@ -8,10 +8,6 @@ function sumabs2first(model, ps, st, (u, y)) return sum(abs2, z) end -if "gpu" in keys(Reactant.XLA.backends) - Reactant.set_default_backend("gpu") -end - dev = reactant_device() ``` From 758b8abcf1ad9d2e2c8edc4b4fc41f1e433f313b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 8 Nov 2024 17:55:57 -0500 Subject: [PATCH 4/9] chore: minor updates --- docs/Project.toml | 2 +- docs/src/tutorials/xla_compilation.md | 18 +++++++++--------- test/Project.toml | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 7aaae9b..204864c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -25,7 +25,7 @@ LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" NeuralOperators = "0.5" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Printf = "1.10" PythonCall = "0.9.23" Zygote = "0.6.71" diff 
--git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/xla_compilation.md index ed3b7b8..57b3620 100644 --- a/docs/src/tutorials/xla_compilation.md +++ b/docs/src/tutorials/xla_compilation.md @@ -3,8 +3,8 @@ ```@example xla_compilation using NeuralOperators, Lux, Random, Enzyme, Reactant -function sumabs2first(model, ps, st, (u, y)) - z, _ = model((u, y), ps, st) +function sumabs2first(model, ps, st, x) + z, _ = model(x, ps, st) return sum(abs2, z) end @@ -15,14 +15,14 @@ dev = reactant_device() ```@example xla_compilation deeponet = DeepONet() -ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev +ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev; -u = rand(Float32, 64, 1024) |> dev -y = rand(Float32, 1, 128, 1024) |> dev +u = rand(Float32, 64, 1024) |> dev; +y = rand(Float32, 1, 128, 1024) |> dev; nothing # hide deeponet_compiled = @compile deeponet((u, y), ps, st) -deeponet_compiled((u, y), ps, st) +deeponet_compiled((u, y), ps, st)[1] ``` Computing the gradient of the DeepONet model. @@ -49,12 +49,12 @@ end ```@example xla_compilation fno = FourierNeuralOperator() -ps, st = Lux.setup(Random.default_rng(), fno) |> dev +ps, st = Lux.setup(Random.default_rng(), fno) |> dev; -x = rand(Float32, 2, 1024, 5) |> dev +x = rand(Float32, 2, 1024, 5) |> dev; fno_compiled = @compile fno(x, ps, st) -fno_compiled(x, ps, st) +fno_compiled(x, ps, st)[1] ``` Computing the gradient of the FourierNeuralOperator model. diff --git a/test/Project.toml b/test/Project.toml index 8890976..b863b98 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ LuxCore = "1" LuxLib = "1.2" LuxTestUtils = "1.1.2" MLDataDevices = "1" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Pkg = "1.10" Preferences = "1" Random = "1.10" From 03811371004a85765e4d348fe704ce91fe17a617 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 21:16:41 -0500 Subject: [PATCH 5/9] docs: rename to reactant compilation --- .github/workflows/CI.yml | 2 - Project.toml | 2 +- docs/pages.jl | 2 +- .../{xla_compilation.md => reactant.md} | 38 ++++++------------- 4 files changed, 14 insertions(+), 30 deletions(-) rename docs/src/tutorials/{xla_compilation.md => reactant.md} (61%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1b306d2..68b36fe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,6 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -60,7 +59,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/Project.toml b/Project.toml index 15d269a..195fd47 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Lux = "1.2.1" LuxCore = "1.1" LuxLib = "1.3.7" MLDataDevices = "1.5" -NNlib = "0.9.21" +NNlib = "0.9.24" Random = "1.10" Static = "1.1.1" WeightInitializers = "1" diff --git a/docs/pages.jl b/docs/pages.jl index 87a4d3c..e0fbea1 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -6,7 +6,7 @@ pages = [ "NOMAD" => "models/nomad.md" ], "Tutorials" => [ - "XLA Compilation" => "tutorials/xla_compilation.md", + "XLA Compilation" => "tutorials/reactant.md", "Burgers Equation" => "tutorials/burgers.md" ], "API Reference" => "api.md" diff --git a/docs/src/tutorials/xla_compilation.md 
b/docs/src/tutorials/reactant.md
similarity index 61%
rename from docs/src/tutorials/xla_compilation.md
rename to docs/src/tutorials/reactant.md
index 57b3620..a90836b 100644
--- a/docs/src/tutorials/xla_compilation.md
+++ b/docs/src/tutorials/reactant.md
@@ -1,6 +1,6 @@
 # Compiling NeuralOperators.jl using Reactant.jl
 
-```@example xla_compilation
+```@example reactant
 using NeuralOperators, Lux, Random, Enzyme, Reactant
 
 function sumabs2first(model, ps, st, x)
@@ -13,12 +13,12 @@ dev = reactant_device()
 
 ## Compiling DeepONet
 
-```@example xla_compilation
+```@example reactant
 deeponet = DeepONet()
 ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev;
 
-u = rand(Float32, 64, 1024) |> dev;
-y = rand(Float32, 1, 128, 1024) |> dev;
+u = rand(Float32, 64, 32) |> dev;
+y = rand(Float32, 1, 128, 32) |> dev;
 nothing # hide
 
 deeponet_compiled = @compile deeponet((u, y), ps, st)
@@ -27,18 +27,11 @@ deeponet_compiled((u, y), ps, st)[1]
 
 Computing the gradient of the DeepONet model.
 
-```@example xla_compilation
+```@example reactant
 function ∇deeponet(model, ps, st, (u, y))
-    dps = Enzyme.make_zero(ps)
-    Enzyme.autodiff(
-        Enzyme.Reverse,
-        sumabs2first,
-        Const(model),
-        Duplicated(ps, dps),
-        Const(st),
-        Const((u, y))
+    return Enzyme.gradient(
+        Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const((u, y))
     )
-    return dps
 end
 
 ∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y))
@@ -47,11 +40,11 @@ end
 
 ## Compiling FourierNeuralOperator
 
-```@example xla_compilation
+```@example reactant
 fno = FourierNeuralOperator()
 ps, st = Lux.setup(Random.default_rng(), fno) |> dev;
 
-x = rand(Float32, 2, 1024, 5) |> dev;
+x = rand(Float32, 2, 32, 5) |> dev;
 
 fno_compiled = @compile fno(x, ps, st)
 fno_compiled(x, ps, st)[1]
@@ -59,18 +52,11 @@ fno_compiled(x, ps, st)[1]
 
 Computing the gradient of the FourierNeuralOperator model.
-```@example xla_compilation +```@example reactant function ∇fno(model, ps, st, x) - dps = Enzyme.make_zero(ps) - Enzyme.autodiff( - Enzyme.Reverse, - sumabs2first, - Const(model), - Duplicated(ps, dps), - Const(st), - Const(x) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(x) ) - return dps end ∇fno_compiled = @compile ∇fno(fno, ps, st, x) From 29cffd1ba2b58a2ab4368ddc1c05250187e6adbc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Nov 2024 14:36:30 -0500 Subject: [PATCH 6/9] feat: bypass rfft and irfft issue with Reactant --- .buildkite/documentation.yml | 3 +++ .buildkite/testing.yml | 6 ++++++ .github/workflows/CI.yml | 4 ++++ Project.toml | 7 +++++++ ext/NeuralOperatorsReactantExt.jl | 18 ++++++++++++++++++ 5 files changed, 38 insertions(+) create mode 100644 ext/NeuralOperatorsReactantExt.jl diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index f20d9c9..2ff12ae 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -5,6 +5,9 @@ steps: version: "1.10" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext command: | julia --project -e ' println("--- :julia: Instantiating project") diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 7979b22..9c14db2 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -7,6 +7,9 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext agents: queue: "juliagpu" cuda: "*" @@ -27,6 +30,9 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true + dirs: + - src + - ext env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 68b36fe..9860756 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,6 +50,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v5 with: files: lcov.info @@ -73,6 +75,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext - uses: codecov/codecov-action@v5 with: files: lcov.info diff --git a/Project.toml b/Project.toml index 195fd47..4671de8 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +[weakdeps] +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + +[extensions] +NeuralOperatorsReactantExt = "Reactant" + [compat] ArgCheck = "2.3" ChainRulesCore = "1.24" @@ -28,6 +34,7 @@ LuxLib = "1.3.7" MLDataDevices = "1.5" NNlib = "0.9.24" Random = "1.10" +Reactant = "0.2.5" Static = "1.1.1" WeightInitializers = "1" julia = "1.10" diff --git a/ext/NeuralOperatorsReactantExt.jl b/ext/NeuralOperatorsReactantExt.jl new file mode 100644 index 0000000..19bc0ef --- /dev/null +++ b/ext/NeuralOperatorsReactantExt.jl @@ -0,0 +1,18 @@ +module NeuralOperatorsReactantExt + +using FFTW: FFTW +using NeuralOperators: NeuralOperators, FourierTransform +using Reactant: Reactant, TracedRArray + +# XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed +function NeuralOperators.transform(ft::FourierTransform, x::TracedRArray{T, N}) where {T, N} + x_c = Reactant.promote_to(TracedRArray{Complex{T}, 
N}, x) + return FFTW.fft(x_c, 1:ndims(ft)) +end + +function NeuralOperators.inverse( + ft::FourierTransform, x::TracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N} + return real(FFTW.ifft(x, 1:ndims(ft))) +end + +end From eeff7234458f042a042c7bfc61921bdc693e7a12 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Nov 2024 22:52:04 -0500 Subject: [PATCH 7/9] refactor: use `@jit` --- docs/src/tutorials/reactant.md | 12 ++++-------- src/layers.jl | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/src/tutorials/reactant.md b/docs/src/tutorials/reactant.md index a90836b..8077c22 100644 --- a/docs/src/tutorials/reactant.md +++ b/docs/src/tutorials/reactant.md @@ -21,8 +21,7 @@ u = rand(Float32, 64, 32) |> dev; y = rand(Float32, 1, 128, 32) |> dev; nothing # hide -deeponet_compiled = @compile deeponet((u, y), ps, st) -deeponet_compiled((u, y), ps, st)[1] +@jit deeponet((u, y), ps, st) ``` Computing the gradient of the DeepONet model. @@ -34,8 +33,7 @@ function ∇deeponet(model, ps, st, (u, y)) ) end -∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y)) -∇deeponet_compiled(deeponet, ps, st, (u, y)) +@jit ∇deeponet(deeponet, ps, st, (u, y)) ``` ## Compiling FourierNeuralOperator @@ -46,8 +44,7 @@ ps, st = Lux.setup(Random.default_rng(), fno) |> dev; x = rand(Float32, 2, 32, 5) |> dev; -fno_compiled = @compile fno(x, ps, st) -fno_compiled(x, ps, st)[1] +@jit fno(x, ps, st) ``` Computing the gradient of the FourierNeuralOperator model. @@ -59,6 +56,5 @@ function ∇fno(model, ps, st, x) ) end -∇fno_compiled = @compile ∇fno(fno, ps, st, x) -∇fno_compiled(fno, ps, st, x) +@jit ∇fno(fno, ps, st, x) ``` diff --git a/src/layers.jl b/src/layers.jl index 1f0cbd0..00bc5b3 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -77,7 +77,7 @@ function operator_conv(x, tform::AbstractTransform, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] x_padded = NNlib.pad_zeros(x_p, expand_pad_dims(pad_dims); - dims=ntuple(identity, ndims(x_p) - 2))::typeof(x_p) + dims=ntuple(identity, ndims(x_p) - 2)) return inverse(tform, x_padded, size(x)) end From 2c8449b460fae202ac9d9f2d6db55b121018a2cf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 09:36:14 +0530 Subject: [PATCH 8/9] fix: don't type assert reactant --- ext/NeuralOperatorsReactantExt.jl | 6 ++++++ src/layers.jl | 3 +-- src/utils.jl | 5 +++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ext/NeuralOperatorsReactantExt.jl b/ext/NeuralOperatorsReactantExt.jl index 19bc0ef..a04a290 100644 --- a/ext/NeuralOperatorsReactantExt.jl +++ b/ext/NeuralOperatorsReactantExt.jl @@ -2,6 +2,7 @@ module NeuralOperatorsReactantExt using FFTW: FFTW using NeuralOperators: NeuralOperators, FourierTransform +using NNlib: NNlib using Reactant: Reactant, TracedRArray # XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed @@ -15,4 +16,9 @@ function NeuralOperators.inverse( return real(FFTW.ifft(x, 1:ndims(ft))) end +function NeuralOperators.fast_pad_zeros(x::TracedRArray, pad_dims) + return NNlib.pad_zeros( + x, NeuralOperators.expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2)) +end + end diff --git a/src/layers.jl b/src/layers.jl index 00bc5b3..9a38630 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -76,8 +76,7 @@ function operator_conv(x, tform::AbstractTransform, weights) x_p = apply_pattern(x_tr, weights) pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)] - x_padded = NNlib.pad_zeros(x_p, expand_pad_dims(pad_dims); - dims=ntuple(identity, ndims(x_p) - 
2)) + x_padded = fast_pad_zeros(x_p, pad_dims) return inverse(tform, x_padded, size(x)) end diff --git a/src/utils.jl b/src/utils.jl index 459a1b4..2cca4b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -51,3 +51,8 @@ function ∇safe_batched_adjoint( ::Type{<:AbstractGPUDevice}, Δ::AbstractArray{T, 3}) where {T} return NoTangent(), stack(adjoint, eachslice(Δ; dims=3)) end + +function fast_pad_zeros(x, pad_dims)::typeof(x) + return NNlib.pad_zeros( + x, expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2)) +end From 6cc113d54f9a094f82d065b5720dcc1677a5c55f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 20 Dec 2024 09:55:02 +0530 Subject: [PATCH 9/9] fix: update to latest Reactant changes --- docs/Project.toml | 2 ++ ext/NeuralOperatorsReactantExt.jl | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 204864c..aee1e1e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -20,6 +20,7 @@ CairoMakie = "0.12.11" CondaPkg = "0.2.23" DataDeps = "0.7.13" Documenter = "1.7.0" +Enzyme = "0.13.24" Lux = "1.2.1" LuxCUDA = "0.3.3" MAT = "0.10.7" @@ -28,4 +29,5 @@ NeuralOperators = "0.5" Optimisers = "0.3.3, 0.4" Printf = "1.10" PythonCall = "0.9.23" +Reactant = "0.2.11" Zygote = "0.6.71" diff --git a/ext/NeuralOperatorsReactantExt.jl b/ext/NeuralOperatorsReactantExt.jl index a04a290..d3665aa 100644 --- a/ext/NeuralOperatorsReactantExt.jl +++ b/ext/NeuralOperatorsReactantExt.jl @@ -3,22 +3,29 @@ module NeuralOperatorsReactantExt using FFTW: FFTW using NeuralOperators: NeuralOperators, FourierTransform using NNlib: NNlib -using Reactant: Reactant, TracedRArray +using Reactant: Reactant, TracedRArray, AnyTracedRArray # XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed -function NeuralOperators.transform(ft::FourierTransform, x::TracedRArray{T, N}) where {T, N} - x_c = Reactant.promote_to(TracedRArray{Complex{T}, N}, x) +function NeuralOperators.transform( + ft::FourierTransform, x::AnyTracedRArray{T, N}) where {T, N} + x_c = Reactant.TracedUtils.promote_to( + TracedRArray{Complex{T}, N}, + Reactant.TracedUtils.materialize_traced_array(x) + ) return FFTW.fft(x_c, 1:ndims(ft)) end function NeuralOperators.inverse( - ft::FourierTransform, x::TracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N} + ft::FourierTransform, x::AnyTracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N} return real(FFTW.ifft(x, 1:ndims(ft))) end -function NeuralOperators.fast_pad_zeros(x::TracedRArray, pad_dims) +function NeuralOperators.fast_pad_zeros(x::AnyTracedRArray, pad_dims) return NNlib.pad_zeros( - x, NeuralOperators.expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2)) + Reactant.TracedUtils.materialize_traced_array(x), + NeuralOperators.expand_pad_dims(pad_dims); + dims=ntuple(identity, ndims(x) - 2) + ) end end
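
The half-width FFT bypass in patches 6 and 9 (`NeuralOperatorsReactantExt`) leans on a standard identity: for real input, promoting to complex and running a full-width `fft` followed by `real ∘ ifft` over the same dimensions reconstructs the input just as the `rfft`/`irfft` pair used on non-traced arrays does. A minimal sketch of that round trip with plain FFTW — the array shape and `dims` below are illustrative, not taken from the patches:

```julia
using FFTW

x = rand(Float32, 16, 8, 4)  # (spatial, channel, batch), shaped like an FNO input
dims = 1:1                   # transform only the leading spatial dimension

# Full-width path, mirroring the extension's transform/inverse for traced
# arrays: promote to complex, fft, then drop the imaginary part after ifft.
x_c = Complex{Float32}.(x)
full = real(FFTW.ifft(FFTW.fft(x_c, dims), dims))

# Half-width path used on ordinary (non-traced) arrays.
half = FFTW.irfft(FFTW.rfft(x, dims), size(x, 1), dims)

@assert full ≈ half ≈ x
```

This demonstrates only the unweighted round trip; inside `operator_conv` the learned weights act on the retained modes, and `fast_pad_zeros` zero-fills the dropped ones before the inverse transform.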
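
Patch 8's `fast_pad_zeros` indirection lets the generic method in `src/utils.jl` keep its `::typeof(x)` return assertion while the Reactant method omits it, since tracing does not have to preserve the concrete array type. A sketch of the padding behaviour being wrapped, with an illustrative flat `(lo, hi)` pad spec standing in for the output of `expand_pad_dims` and made-up shapes:

```julia
using NNlib

# After apply_pattern keeps only the first 5 modes, grow the mode axis back
# to its original length of 8 by zero-filling at the high end. For a 3-D
# array, dims=ntuple(identity, ndims(x) - 2) reduces to (1,): only the mode
# axis is padded, never the channel or batch axes.
x_p = rand(Float32, 5, 8, 4)  # (modes, channels, batch)
pad = (0, 3)                  # (lo, hi) zero padding along the mode axis

x_padded = NNlib.pad_zeros(x_p, pad; dims=(1,))

@assert size(x_padded) == (8, 8, 4)
@assert all(iszero, x_padded[6:end, :, :])
```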