diff --git a/Project.toml b/Project.toml index 2254e8d..e6139de 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockDiagonals" uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.28" +version = "0.1.29" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/base_maths.jl b/src/base_maths.jl index c4eb24e..a42e356 100644 --- a/src/base_maths.jl +++ b/src/base_maths.jl @@ -92,14 +92,15 @@ function _mulblocksizes(bblocks, M::AbstractMatrix) return zip(size.(bblocks, 1), Base.Iterators.repeated(size(M, 2), length(bblocks))) end +# avoid ambiguities arising with AbstractVecOrMat Base.:*(B::BlockDiagonal, x::AbstractVector) = _mul(B, x) Base.:*(B::BlockDiagonal, X::AbstractMatrix) = _mul(B, X) -function _mul(B::BlockDiagonal{T}, x::AbstractVecOrMat) where {T} +function _mul(B::BlockDiagonal{T}, x::AbstractVecOrMat{T2}) where {T, T2} _check_matmul_dims(B, x) bblocks = blocks(B) new_blocksizes = _mulblocksizes(bblocks, x) - d = similar.(bblocks, T, new_blocksizes) + d = similar.(bblocks, promote_type(T, T2), new_blocksizes) ed = 0 @inbounds @views for (p, block) in enumerate(bblocks) st = ed + 1 # start @@ -109,11 +110,11 @@ function _mul(B::BlockDiagonal{T}, x::AbstractVecOrMat) where {T} return reduce(vcat, d) end -function Base.:*(M::AbstractMatrix, B::BlockDiagonal{T}) where T +function Base.:*(M::AbstractMatrix{T}, B::BlockDiagonal{T2}) where {T, T2} _check_matmul_dims(M, B) bblocks = blocks(B) new_blocksizes = zip(fill(size(M, 1), length(bblocks)), size.(bblocks, 2)) - d = similar.(bblocks, T, new_blocksizes) + d = similar.(bblocks, promote_type(T, T2), new_blocksizes) ed = 0 @inbounds @views for (p, block) in enumerate(bblocks) st = ed + 1 # start @@ -124,9 +125,9 @@ function Base.:*(M::AbstractMatrix, B::BlockDiagonal{T}) where T end # Diagonal -function Base.:*(B::BlockDiagonal, M::Diagonal)::BlockDiagonal +function Base.:*(B::BlockDiagonal{T}, M::Diagonal{T2})::BlockDiagonal where {T, T2} _check_matmul_dims(B, M) - A = similar(B) + A = similar(B, promote_type(T, T2)) d = parent(M) col = 1 @inbounds @views for (p, block) in enumerate(blocks(B)) @@ -140,9 +141,9 @@ function Base.:*(B::BlockDiagonal, M::Diagonal)::BlockDiagonal return A end -function Base.:*(M::Diagonal, B::BlockDiagonal)::BlockDiagonal +function Base.:*(M::Diagonal{T}, B::BlockDiagonal{T2})::BlockDiagonal where {T, T2} _check_matmul_dims(M, B) - A = similar(B) + A = similar(B, promote_type(T, T2)) d = parent(M) row = 1 @inbounds @views for (p, block) in enumerate(blocks(B)) diff --git a/src/blockdiagonal.jl b/src/blockdiagonal.jl index f418964..b9c2240 100644 --- a/src/blockdiagonal.jl +++ b/src/blockdiagonal.jl @@ -121,6 +121,7 @@ Base.collect(B::BlockDiagonal) = Matrix(B) Base.size(B::BlockDiagonal) = sum(first∘size, blocks(B)), sum(last∘size, blocks(B)) Base.similar(B::BlockDiagonal) = BlockDiagonal(map(similar, blocks(B))) +Base.similar(B::BlockDiagonal, ::Type{T}) where T = BlockDiagonal(map(b -> similar(b, T), blocks(B))) Base.parent(B::BlockDiagonal) = B.blocks @propagate_inbounds function Base.setindex!(B::BlockDiagonal, v, i::Integer, j::Integer) diff --git a/test/base_maths.jl b/test/base_maths.jl index c1cbeab..9b775e7 100644 --- a/test/base_maths.jl +++ b/test/base_maths.jl @@ -16,6 +16,9 @@ using Test a = rand(rng, N) b = rand(rng, N + N1) + b64 = BlockDiagonal([rand(rng, 2, 2), rand(rng, 2, 2)]) + b32 = BlockDiagonal([rand(rng, Float32, 2, 2), rand(rng, Float32, 2, 2)]) + @testset "Addition" begin @testset "BlockDiagonal + BlockDiagonal" begin @test b1 + b1 isa BlockDiagonal @@ -94,6 +97,10 @@ using Test @test b1 * a isa Vector @test b1 * a ≈ Matrix(b1) * a @test_throws DimensionMismatch b1 * b + + # promote_type + @test b32 * rand(4) isa Vector{Float64} + @test b64 * rand(Float32, 4) isa Vector{Float64} end @testset "Vector^T * BlockDiagonal" begin @test a' * b1 isa Adjoint{<:Number, <:Vector} @@ -118,6 +125,12 @@ using Test m = rand(5, 0) @test m' * BlockDiagonal([m]) == m' * m == rand(0, 0) @test m * BlockDiagonal([m']) == m * m' == zeros(5, 5) + + # promote_type + @test b32 * rand(4, 4) isa Matrix{Float64} + @test rand(4, 4) * b32 isa Matrix{Float64} + @test b64 * rand(Float32, 4, 4) isa Matrix{Float64} + @test rand(Float32, 4, 4) * b64 isa Matrix{Float64} end @testset "BlockDiagonal * Diagonal" begin @@ -132,6 +145,10 @@ using Test @test D * b1 isa BlockDiagonal @test D * b1 ≈ D * Matrix(b1) @test_throws DimensionMismatch D′ * b1 + + # promote_type + @test b32 * Diagonal(rand(4)) isa BlockDiagonal{Float64} + @test Diagonal(rand(4)) * b32 isa BlockDiagonal{Float64} end @testset "Non-Square BlockDiagonal * Non-Square BlockDiagonal" begin diff --git a/test/blockdiagonal.jl b/test/blockdiagonal.jl index e43d707..100f11b 100644 --- a/test/blockdiagonal.jl +++ b/test/blockdiagonal.jl @@ -56,6 +56,8 @@ using Test @test similar(b1) isa BlockDiagonal @test size(similar(b1)) == size(b1) @test size.(blocks(similar(b1))) == size.(blocks(b1)) + + @test similar(b1, Float32) isa BlockDiagonal{Float32} end @testset "setindex!" begin