diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 6105dcb71..530d8c1a9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -86,7 +86,7 @@ function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} ) end if isa(rhs, Number) - attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T})) + attr = MLIR.IR.DenseElementsAttribute(fill(T(rhs))) return TracedRNumber{T}( (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) ) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 5b2f8bf1f..91a1a8d33 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -157,4 +157,9 @@ end pad_fn2 = Base.Fix2(NNlib.pad_constant, (1, 0, 1, 3)) @test @jit(∇sumabs2(pad_fn2, x_ra)) ≈ ∇sumabs2(pad_fn2, x) + + x = rand(ComplexF32, 4, 4) + x_ra = Reactant.ConcreteRArray(x) + + @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1)) end