@@ -6,6 +6,7 @@ using ..Reactant:
66 AnyTracedRArray,
77 AnyTracedRMatrix,
88 AnyTracedRVector,
9+ AnyTracedRVecOrMat,
910 unwrapped_eltype,
1011 Ops,
1112 MLIR
@@ -347,4 +348,55 @@ function LinearAlgebra.ldiv!(
347348 return B
348349end
349350
351+ # Kronecker Product
352+ function LinearAlgebra. kron (
353+ x:: AnyTracedRVecOrMat{T1} , y:: AnyTracedRVecOrMat{T2}
354+ ) where {T1,T2}
355+ x = materialize_traced_array (x)
356+ y = materialize_traced_array (y)
357+ z = similar (x, Base. promote_op (* , T1, T2), LinearAlgebra. _kronsize (x, y))
358+ LinearAlgebra. kron! (z, x, y)
359+ return z
360+ end
361+
362+ function LinearAlgebra. kron (x:: AnyTracedRVector{T1} , y:: AnyTracedRVector{T2} ) where {T1,T2}
363+ x = materialize_traced_array (x)
364+ y = materialize_traced_array (y)
365+ z = similar (x, Base. promote_op (* , T1, T2), length (x) * length (y))
366+ LinearAlgebra. kron! (z, x, y)
367+ return z
368+ end
369+
370+ function LinearAlgebra. kron! (C:: AnyTracedRVector , A:: AnyTracedRVector , B:: AnyTracedRVector )
371+ LinearAlgebra. kron! (
372+ reshape (C, length (B), length (A)), reshape (A, 1 , length (A)), reshape (B, length (B), 1 )
373+ )
374+ return C
375+ end
376+
377+ function LinearAlgebra. _kron! (C:: AnyTracedRMatrix , A:: AnyTracedRMatrix , B:: AnyTracedRMatrix )
378+ A = materialize_traced_array (A)
379+ B = materialize_traced_array (B)
380+
381+ final_shape = Int64[size (B, 1 ), size (A, 1 ), size (B, 2 ), size (A, 2 )]
382+
383+ A = Ops. broadcast_in_dim (A, Int64[2 , 4 ], final_shape)
384+ B = Ops. broadcast_in_dim (B, Int64[1 , 3 ], final_shape)
385+
386+ C_tmp = Ops. reshape (Ops. multiply (A, B), size (C)... )
387+ set_mlir_data! (C, get_mlir_data (C_tmp))
388+
389+ return C
390+ end
391+
392+ function LinearAlgebra. _kron! (C:: AnyTracedRMatrix , A:: AnyTracedRVector , B:: AnyTracedRMatrix )
393+ LinearAlgebra. _kron! (C, reshape (A, length (A), 1 ), B)
394+ return C
395+ end
396+
397+ function LinearAlgebra. _kron! (C:: AnyTracedRMatrix , A:: AnyTracedRMatrix , B:: AnyTracedRVector )
398+ LinearAlgebra. _kron! (C, A, reshape (B, length (B), 1 ))
399+ return C
400+ end
401+
350402end
0 commit comments