Skip to content

Commit

Permalink
fix: don't type assert reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 14, 2024
1 parent 2e53eb8 commit c655138
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
6 changes: 6 additions & 0 deletions ext/NeuralOperatorsReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module NeuralOperatorsReactantExt

using FFTW: FFTW
using NeuralOperators: NeuralOperators, FourierTransform
using NNlib: NNlib
using Reactant: Reactant, TracedRArray

# XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed
Expand All @@ -15,4 +16,9 @@ function NeuralOperators.inverse(
return real(FFTW.ifft(x, 1:ndims(ft)))
end

function NeuralOperators.fast_pad_zeros(x::TracedRArray, pad_dims)
return NNlib.pad_zeros(
x, NeuralOperators.expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2))
end

end
3 changes: 1 addition & 2 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ function operator_conv(x, tform::AbstractTransform, weights)
x_p = apply_pattern(x_tr, weights)

pad_dims = size(x_t)[1:(end - 2)] .- size(x_p)[1:(end - 2)]
x_padded = NNlib.pad_zeros(x_p, expand_pad_dims(pad_dims);
dims=ntuple(identity, ndims(x_p) - 2))
x_padded = fast_pad_zeros(x_p, pad_dims)

return inverse(tform, x_padded, size(x))
end
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ function ∇safe_batched_adjoint(
::Type{<:AbstractGPUDevice}, Δ::AbstractArray{T, 3}) where {T}
return NoTangent(), stack(adjoint, eachslice(Δ; dims=3))
end

function fast_pad_zeros(x, pad_dims)::typeof(x)
return NNlib.pad_zeros(
x, expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2))
end

0 comments on commit c655138

Please sign in to comment.