batched_mul! doesn't support arrays of dimension larger than 3 #529

Open
ZhaoFancy opened this issue Aug 31, 2023 · 1 comment


@ZhaoFancy

Motivation and description

`batched_mul!` does not accept arrays with more than three dimensions, whereas the non-mutating `batched_mul` does.

Possible Implementation

No response

@AntonOresten

AntonOresten commented Nov 14, 2024

I looked into this. Here's a possible implementation that reshapes the arrays to 3-D (the reshaped arrays reference the same data, so nothing is copied) and then calls the existing 3-D method:

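```julia
using NNlib
import NNlib: batched_mul!  # extend NNlib's function (not needed inside NNlib itself)

# Collapse all trailing (batch) dimensions into one, then defer to the 3-D method.
function batched_mul!(C::AbstractArray{T,N}, A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N},
        α::Number=one(T), β::Number=zero(T)) where {T,N}
    batch_size = size(C)[3:end]
    @assert batch_size == size(A)[3:end] "batch size has to be the same for arrays."
    @assert batch_size == size(B)[3:end] "batch size has to be the same for arrays."

    # reshape does not copy: C2, A2, B2 reference the data of C, A, B
    C2 = reshape(C, size(C,1), size(C,2), :)
    A2 = reshape(A, size(A,1), size(A,2), :)
    B2 = reshape(B, size(B,1), size(B,2), :)

    batched_mul!(C2, A2, B2, α, β)  # dispatches to the more specific 3-D method
    return C
end
```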
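To check the reshape approach without NNlib, a plain-Julia reference version can be handy. This is a minimal sketch under stated assumptions, not NNlib's code: `batched_mul_ref!` is a hypothetical name, and it multiplies each slice with the standard library's 5-argument `LinearAlgebra.mul!` instead of NNlib's batched kernels, so it is only meant for correctness comparisons.

```julia
using LinearAlgebra

# Hypothetical reference version (not NNlib's implementation): collapse the
# trailing batch dimensions with reshape, then multiply slice by slice.
function batched_mul_ref!(C::AbstractArray{T,N}, A::AbstractArray{<:Any,N},
                          B::AbstractArray{<:Any,N},
                          α::Number=one(T), β::Number=zero(T)) where {T,N}
    batch = size(C)[3:end]
    size(A)[3:end] == batch || throw(DimensionMismatch("batch dims of A and C differ"))
    size(B)[3:end] == batch || throw(DimensionMismatch("batch dims of B and C differ"))
    # For dense Arrays, reshape shares memory, so writing into C3 updates C.
    C3 = reshape(C, size(C, 1), size(C, 2), :)
    A3 = reshape(A, size(A, 1), size(A, 2), :)
    B3 = reshape(B, size(B, 1), size(B, 2), :)
    for k in axes(C3, 3)
        # 5-argument mul! overwrites the slice with A*B*α + C*β
        @views mul!(C3[:, :, k], A3[:, :, k], B3[:, :, k], α, β)
    end
    return C
end

# Usage on 4-D arrays whose batch dimensions are (2, 5)
A = randn(3, 4, 2, 5)
B = randn(4, 6, 2, 5)
C = zeros(3, 6, 2, 5)
batched_mul_ref!(C, A, B)
```

Comparing its result with the NNlib-based method on the same inputs (for example with `≈`) is one way to test the proposed reshape-based `batched_mul!`.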
Benchmarking the proposed method against `batched_mul` (with BenchmarkTools' `@btime`):

```julia-repl
julia> A = randn(30, 40, 30, 60);

julia> B = randn(40, 50, 30, 60);

julia> C = similar(A, 30, 50, 30, 60);

julia> @btime batched_mul(A, B);
  7.785 ms (101 allocations: 20.61 MiB)

julia> @btime batched_mul!(C, A, B);
  6.969 ms (98 allocations: 12.88 KiB)
```
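Most of the gap in allocated memory appears to be the output array: `batched_mul(A, B)` has to allocate a 30×50×30×60 `Float64` result, while `batched_mul!` writes into the preallocated `C`. A quick size check, using only the dimensions from the benchmark above:

```julia
# Bytes in the 30×50×30×60 Float64 array that batched_mul(A, B) returns
nbytes = 30 * 50 * 30 * 60 * sizeof(Float64)   # 21_600_000 bytes
mib = nbytes / 2^20                            # bytes to MiB
println("output array: ", round(mib; digits = 2), " MiB")
```

That is roughly 20.6 MiB, close to the 20.61 MiB reported for `batched_mul`, which suggests the remaining allocations in both calls are small bookkeeping rather than data copies.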
