Skip to content

Commit

Permalink
fix: bypass segfault with fill complex
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 8, 2024
1 parent f34ec8f commit 36e5bed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
)
end
if isa(rhs, Number)
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
attr = MLIR.IR.DenseElementsAttribute(fill(T(rhs)))
return TracedRNumber{T}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
)
Expand Down
5 changes: 5 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,9 @@ end

pad_fn2 = Base.Fix2(NNlib.pad_constant, (1, 0, 1, 3))
@test @jit(∇sumabs2(pad_fn2, x_ra)) ∇sumabs2(pad_fn2, x)

x = rand(ComplexF32, 4, 4)
x_ra = Reactant.ConcreteRArray(x)

@test @jit(NNlib.pad_constant(x_ra, (1, 1))) NNlib.pad_constant(x, (1, 1))
end

0 comments on commit 36e5bed

Please sign in to comment.