Skip to content

Commit

Permalink
Update wrappers.jl
Browse files Browse the repository at this point in the history
Fix incorrect definition of m and n in gemv_strided_batched!
  • Loading branch information
kose-y authored Aug 28, 2024
1 parent bbe625b commit 8b54f85
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,9 @@ for (fname, fname_64, eltyin, eltyout) in (
if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
end
m = size(A, trans == 'N' ? 1 : 2)
n = size(A, trans == 'N' ? 2 : 1)
if m != size(y, 1) || n != size(x, 1)
m = size(A, 1)
n = size(A, 2)
if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m)
throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))"))
end

Expand Down

0 comments on commit 8b54f85

Please sign in to comment.