Skip to content

Commit 534bea3

Browse files
authored
feat: overload LinearAlgebra.kron (#607)
* feat: overload LinearAlgebra.kron * test: kron
1 parent 635f35c commit 534bea3

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ..Reactant:
66
AnyTracedRArray,
77
AnyTracedRMatrix,
88
AnyTracedRVector,
9+
AnyTracedRVecOrMat,
910
unwrapped_eltype,
1011
Ops,
1112
MLIR
@@ -347,4 +348,55 @@ function LinearAlgebra.ldiv!(
347348
return B
348349
end
349350

351+
# Kronecker Product
352+
function LinearAlgebra.kron(
353+
x::AnyTracedRVecOrMat{T1}, y::AnyTracedRVecOrMat{T2}
354+
) where {T1,T2}
355+
x = materialize_traced_array(x)
356+
y = materialize_traced_array(y)
357+
z = similar(x, Base.promote_op(*, T1, T2), LinearAlgebra._kronsize(x, y))
358+
LinearAlgebra.kron!(z, x, y)
359+
return z
360+
end
361+
362+
function LinearAlgebra.kron(x::AnyTracedRVector{T1}, y::AnyTracedRVector{T2}) where {T1,T2}
363+
x = materialize_traced_array(x)
364+
y = materialize_traced_array(y)
365+
z = similar(x, Base.promote_op(*, T1, T2), length(x) * length(y))
366+
LinearAlgebra.kron!(z, x, y)
367+
return z
368+
end
369+
370+
function LinearAlgebra.kron!(C::AnyTracedRVector, A::AnyTracedRVector, B::AnyTracedRVector)
371+
LinearAlgebra.kron!(
372+
reshape(C, length(B), length(A)), reshape(A, 1, length(A)), reshape(B, length(B), 1)
373+
)
374+
return C
375+
end
376+
377+
function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRMatrix)
378+
A = materialize_traced_array(A)
379+
B = materialize_traced_array(B)
380+
381+
final_shape = Int64[size(B, 1), size(A, 1), size(B, 2), size(A, 2)]
382+
383+
A = Ops.broadcast_in_dim(A, Int64[2, 4], final_shape)
384+
B = Ops.broadcast_in_dim(B, Int64[1, 3], final_shape)
385+
386+
C_tmp = Ops.reshape(Ops.multiply(A, B), size(C)...)
387+
set_mlir_data!(C, get_mlir_data(C_tmp))
388+
389+
return C
390+
end
391+
392+
function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRVector, B::AnyTracedRMatrix)
393+
LinearAlgebra._kron!(C, reshape(A, length(A), 1), B)
394+
return C
395+
end
396+
397+
function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRVector)
398+
LinearAlgebra._kron!(C, A, reshape(B, length(B), 1))
399+
return C
400+
end
401+
350402
end

test/integration/linear_algebra.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,17 @@ mul_symmetric(x) = Symmetric(x) * x
169169
@test @jit(fn(x_ra)) fn(x)
170170
end
171171
end
172+
173+
@testset "kron" begin
174+
@testset for T in (Int64, Float64, ComplexF64)
175+
@testset for (x_sz, y_sz) in [
176+
((3, 4), (2, 5)), ((3, 4), (2,)), ((3,), (2, 5)), ((3,), (5,)), ((10,), ())
177+
]
178+
x = x_sz == () ? rand(T) : rand(T, x_sz)
179+
y = y_sz == () ? rand(T) : rand(T, y_sz)
180+
x_ra = Reactant.to_rarray(x; track_numbers=Number)
181+
y_ra = Reactant.to_rarray(y; track_numbers=Number)
182+
@test @jit(kron(x_ra, y_ra)) kron(x, y)
183+
end
184+
end
185+
end

0 commit comments

Comments
 (0)