`batched_mul` (`⊠`) works on CPU arrays but throws a `MethodError` on `CuArray`s:

julia> using CUDA, NNlib
julia> A, B = randn(2,5,17), randn(5,9,17);
julia> A ⊠ B;
julia> Ag = CuArray(A);
julia> Bg = CuArray(B);
julia> Ag ⊠ Bg
ERROR: MethodError: no method matching _batched_gemm!(::Type{CuArray{Float64, 3}}, ::Char, ::Char, ::Float64, ::CuArray{Float64, 3}, ::CuArray{Float64, 3}, ::Float64, ::CuArray{Float64, 3})
Closest candidates are:
_batched_gemm!(::Type{var"#s117"} where var"#s117"<:Array, ::Char, ::Char, ::Number, ::Any, ::Any, ::Number, ::Any) at /groups/scicompsoft/home/arthurb/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:260
Stacktrace:
[1] _batched_try_gemm!(#unused#::Type{CuArray{Float64, 3}}, C::CuArray{Float64, 3}, A::CuArray{Float64, 3}, B::CuArray{Float64, 3}, α::Float64, β::Float64)
@ NNlib ~/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:256
[2] _batched_mul!(#unused#::Type{CuArray{Float64, 3}}, C::CuArray{Float64, 3}, A::CuArray{Float64, 3}, B::CuArray{Float64, 3}, α::Float64, β::Float64)
@ NNlib ~/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:219
[3] batched_mul!(C::CuArray{Float64, 3}, A::CuArray{Float64, 3}, B::CuArray{Float64, 3}, α::Float64, β::Float64) (repeats 2 times)
@ NNlib ~/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:213
[4] _batched_mul(#unused#::Type{CuArray{Float64, 3}}, A::CuArray{Float64, 3}, B::CuArray{Float64, 3})
@ NNlib ~/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:69
[5] batched_mul(A::CuArray{Float64, 3}, B::CuArray{Float64, 3})
@ NNlib ~/.julia/packages/NNlib/LiXUT/src/batched/batchedmul.jl:56
[6] top-level scope
@ REPL[6]:1
[7] top-level scope
@ ~/.julia/packages/CUDA/3VnCC/src/initialization.jl:81
(@v1.6) pkg> st
Status `~/.julia/environments/v1.6/Project.toml`
[6e4b80f9] BenchmarkTools v1.0.0
[052768ef] CUDA v3.2.1
[34f1f09b] ClusterManagers v0.4.0
[31c24e10] Distributions v0.25.2
[5903a43b] Infiltrator v0.3.0
[4138dd39] JLD v0.12.3
[872c559c] NNlib v0.7.20
[132c30aa] ProfileSVG v0.2.1
julia> VERSION
v"1.6.1"
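A likely cause, assuming this matches the NNlib ≥ 0.7.19 reorganization: the CUDA-specific methods (including the GPU `_batched_gemm!`) were split out into the separate NNlibCUDA.jl package, so on NNlib v0.7.20 they are only defined once that package is loaded. A minimal sketch of the workaround, assuming NNlibCUDA has been added to the environment:

```julia
using CUDA, NNlib
using NNlibCUDA   # assumption: provides the CuArray methods for batched_mul via CUBLAS

Ag = CuArray(randn(2, 5, 17))
Bg = CuArray(randn(5, 9, 17))
Ag ⊠ Bg |> size   # should now dispatch to the GPU method, giving (2, 9, 17)
```

If that is the issue, the error message could perhaps mention NNlibCUDA when the missing method is hit with a `CuArray` argument.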
help?> batched_mul
search: batched_mul batched_vec batched_adjoint batched_transpose
batched_mul(A, B) -> C
A ⊠ B # \boxtimes
Batched matrix multiplication. Result has C[:,:,k] == A[:,:,k] * B[:,:,k] for all k. If
size(B,3) == 1 then instead C[:,:,k] == A[:,:,k] * B[:,:,1], and similarly for A.
To transpose each matrix, apply batched_transpose to the array, or batched_adjoint for
conjugate-transpose:
julia> A, B = randn(2,5,17), randn(5,9,17);
julia> A ⊠ B |> size
(2, 9, 17)
julia> batched_adjoint(A) |> size
(5, 2, 17)
julia> batched_mul(A, batched_adjoint(randn(9,5,17))) |> size
(2, 9, 17)
julia> A ⊠ randn(5,9,1) |> size
(2, 9, 17)
julia> batched_transpose(A) == PermutedDimsArray(A, (2,1,3))
true
The equivalent PermutedDimsArray may be used in place of batched_transpose. Other
permutations are also handled by BLAS, provided that the batch index k is not the first
dimension of the underlying array. Thus PermutedDimsArray(::Array, (1,3,2)) and
PermutedDimsArray(::Array, (3,1,2)) are fine.
However, A = PermutedDimsArray(::Array, (3,2,1)) is not acceptable to BLAS, since the batch
dimension is the contiguous one: stride(A,3) == 1. This will be copied, as doing so is
faster than batched_mul_generic!.
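The stride condition described above can be checked directly; a minimal sketch:

```julia
X = randn(17, 5, 2)
A = PermutedDimsArray(X, (3, 2, 1))  # size (2, 5, 17); the batch index maps to X's first dim
stride(A, 3)                          # 1 — the batch dimension is contiguous,
                                      # so BLAS cannot stride over it and batched_mul copies A first
```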
Both this copy and batched_mul_generic! produce @debug messages, and setting for instance
ENV["JULIA_DEBUG"] = NNlib will display them.
────────────────────────────────────────────────────────────────────────────────────────────
batched_mul(A::Array{T,3}, B::Matrix)
batched_mul(A::Matrix, B::Array{T,3})
A ⊠ B
This is always matrix-matrix multiplication, but either A or B may lack a batch index.
• When B is a matrix, result has C[:,:,k] == A[:,:,k] * B[:,:] for all k.
• When A is a matrix, then C[:,:,k] == A[:,:] * B[:,:,k]. This can also be done by
reshaping and calling *, for instance A ⊡ B using TensorCore.jl, but is
implemented here using batched_gemm instead of gemm.
julia> randn(16,8,32) ⊠ randn(8,4) |> size
(16, 4, 32)
julia> randn(16,8,32) ⊠ randn(8,4,1) |> size # equivalent
(16, 4, 32)
julia> randn(16,8) ⊠ randn(8,4,32) |> size
(16, 4, 32)
See also batched_vec to regard B as a batch of vectors, A[:,:,k] * B[:,k].
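For completeness, a quick sketch of `batched_vec` with arbitrary sizes (assuming the signature `batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)` described above):

```julia
using NNlib

A = randn(2, 5, 17)
b = randn(5, 17)
batched_vec(A, b) |> size   # (2, 17): column k is A[:,:,k] * b[:,k]
```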