Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support BlockTridiagonal #24

Merged
merged 11 commits into from
Mar 23, 2019
4 changes: 2 additions & 2 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
julia 0.7
BandedMatrices 0.8.2 0.9
BlockArrays 0.7 0.8
BandedMatrices 0.9
BlockArrays 0.8
FillArrays 0.3
LazyArrays 0.6
36 changes: 18 additions & 18 deletions src/BlockBandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,12 @@ module BlockBandedMatrices
using BlockArrays, BandedMatrices, LazyArrays, FillArrays, SparseArrays
using LinearAlgebra

import BlockArrays: BlockSizes, nblocks, blocksize, blockcheckbounds, global2blockindex,
Block, BlockSlice, getblock, unblock, setblock!, globalrange,
_unblock, _find_block, BlockIndexRange, blocksizes, cumulsizes,
AbstractBlockSizes

import BandedMatrices: isbanded, bandwidths, bandwidth, banded_getindex, colrange,
inbands_setindex!, inbands_getindex, banded_setindex!,
banded_generic_axpy!,
BlasFloat, banded_dense_axpy!, MemoryLayout,
BandedColumnMajor,
BandedSubBandedMatrix, bandeddata, tribandeddata,
_BandedMatrix, colstart, colstop, rowstart, rowstop,
BandedStyle, _fill_lmul!,
_banded_colval, _banded_rowval, _banded_nzval # for sparse

import Base: getindex, setindex!, checkbounds, @propagate_inbounds, convert,
+, *, -, /, \, strides, zeros, size,
unsafe_convert, fill!, length, first, last,
eltype, getindex, to_indices, to_index,
reindex, _maybetail, tail, @_propagate_inbounds_meta,
==, axes, copyto!, similar, OneTo
==, axes, copyto!, similar, OneTo, replace_in_print_matrix

import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted, broadcasted,
materialize, materialize!
Expand All @@ -42,8 +27,23 @@ import LazyArrays: AbstractStridedLayout, ColumnMajor, @lazymul, MatMulMatAdd, M
ArrayMulArrayStyle, AbstractColumnMajor, DenseColumnMajor, ColumnMajor,
DiagonalLayout, MatMulMat

export BandedBlockBandedMatrix, BlockBandedMatrix, BlockTridiagonalMatrix, BlockSkylineMatrix, blockbandwidth, blockbandwidths,
subblockbandwidth, subblockbandwidths, Ones, Zeros, Fill, Block
import BlockArrays: BlockSizes, nblocks, blocksize, blockcheckbounds, global2blockindex,
Block, BlockSlice, getblock, unblock, setblock!, globalrange,
_unblock, _find_block, BlockIndexRange, blocksizes, cumulsizes,
AbstractBlockSizes, sizes_from_blocks

import BandedMatrices: isbanded, bandwidths, bandwidth, banded_getindex, colrange,
inbands_setindex!, inbands_getindex, banded_setindex!,
banded_generic_axpy!,
BlasFloat, banded_dense_axpy!, MemoryLayout,
BandedColumnMajor,
BandedSubBandedMatrix, bandeddata, tribandeddata,
_BandedMatrix, colstart, colstop, rowstart, rowstop,
BandedStyle, _fill_lmul!,
_banded_colval, _banded_rowval, _banded_nzval # for sparse

export BandedBlockBandedMatrix, BlockBandedMatrix, BlockSkylineMatrix, blockbandwidth, blockbandwidths,
subblockbandwidth, subblockbandwidths, Ones, Zeros, Fill, Block, BlockTridiagonal


include("AbstractBlockBandedMatrix.jl")
Expand Down
1 change: 0 additions & 1 deletion src/BlockSkylineMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ struct BlockSkylineMatrix{T, LL<:AbstractVector{Int}, UU<:AbstractVector{Int}} <
end

const BlockBandedMatrix{T} = BlockSkylineMatrix{T, Fill{Int,1,Tuple{OneTo{Int}}}, Fill{Int,1,Tuple{OneTo{Int}}}}
const BlockTridiagonalMatrix{T} = BlockSkylineMatrix{T, Ones{Int,1,Tuple{OneTo{Int}}}, Ones{Int,1,Tuple{OneTo{Int}}}}

# Auxiliary outer constructors
@inline _BlockBandedMatrix(data::AbstractVector, bs::BlockBandedSizes) =
Expand Down
50 changes: 50 additions & 0 deletions src/interfaceimpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,53 @@ BroadcastStyle(::Type{<:SubKron{<:Any,<:Any,B,Block1,Block1}}) where B =

@inline bandwidths(V::SubKron{<:Any,<:Any,<:Any,Block1,Block1}) =
subblockbandwidths(parent(V))




# Block Tridiagonal
const BlockTridiagonal{T,VT<:Matrix{T}} = BlockMatrix{T,<:Tridiagonal{VT}}

BlockTridiagonal(A,B,C) = mortar(Tridiagonal(A,B,C))

function sizes_from_blocks(A::Tridiagonal, _)
# for k = 1:length(A.du)
# size(A.du[k],1) == sz[1][k] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
# size(A.du[k],2) == sz[2][k+1] || throw(ArgumentError("block sizes of upper diagonal inconsisent with diagonal"))
# size(A.dl[k],1) == sz[1][k+1] || throw(ArgumentError("block sizes of lower diagonal inconsisent with diagonal"))
# size(A.dl[k],2) == sz[2][k] || throw(ArgumentError("block sizes of lower diagonal inconsisent with diagonal"))
# end
BlockSizes(size.(A.d, 1), size.(A.d,2))
end

@inline function getblock(block_arr::BlockTridiagonal{T,VT}, K::Int, J::Int) where {T,VT<:AbstractMatrix}
@boundscheck blockcheckbounds(block_arr, K, J)
abs(J-K) ≥ 2 && return convert(VT, Zeros{T}(blocksize(block_arr,(K,J))))
block_arr.blocks[K,J]
end

function checksquareblocks(A)
m,n = cumulsizes(blocksizes(A))
m == n || throw(DimensionMismatch("blocks are not square: block dimensions are $(blocksizes(A))"))
m
end

for op in (:-, :+)
@eval begin
function $op(A::BlockTridiagonal, λ::UniformScaling)
checksquareblocks(A)
mortar(Tridiagonal(A.blocks.dl, broadcast($op, A.blocks.d, Ref(λ)), A.blocks.du))
end
function $op(λ::UniformScaling, A::BlockTridiagonal)
checksquareblocks(A)
mortar(Tridiagonal(A.blocks.dl, broadcast($op, Ref(λ), A.blocks.d), A.blocks.du))
end
end
end

function replace_in_print_matrix(A::BlockTridiagonal, i::Integer, j::Integer, s::AbstractString)
bi = global2blockindex(A.block_sizes, (i, j))
I,J = bi.I
i,j = bi.α
-1 ≤ J-I ≤ 1 ? s : Base.replace_with_centered_mark(s)
end
2 changes: 1 addition & 1 deletion test/test_bandedblockbanded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ import BlockBandedMatrices: _BandedBlockBandedMatrix, blockcolrange, blockrowran
end

@testset "BandedBlockBanded with BlockMatrix" begin
WithBlockMatrix{T} = BandedBlockBandedMatrix{T, BlockMatrix{T, Matrix{T}}}
WithBlockMatrix{T} = BandedBlockBandedMatrix{T, BlockMatrix{T, Matrix{Matrix{T}}}}
args = ([1, 2, 3], [2, 2, 1]), (1, 1), (1, 1)
A = WithBlockMatrix{Int64}(undef, args...)
B = BandedBlockBandedMatrix{Int64}(undef, A.block_sizes)
Expand Down
12 changes: 12 additions & 0 deletions test/test_misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,15 @@ size(F::FiniteDifference) = (F.n,F.n)
@test subblockbandwidths(D*D_xx) == subblockbandwidths(D_xx)
end
end

@testset "Block Tridiagonal" begin
A = BlockTridiagonal(fill([1 2],3), fill([3 4],4), fill([4 5],3))
@test A[Block(1,1)] == [3 4]
@test @inferred(A[Block(1,2)]) == [4 5]
@test @inferred(getblock(A,1,3)) == @inferred(A[Block(1,3)]) == [0 0]
@test_throws DimensionMismatch A+I
A = BlockTridiagonal(fill([1 2; 1 2],3), fill([3 4; 3 4],4), fill([4 5; 4 5],3))
@test A+I == I+A == mortar(Tridiagonal(fill([1 2; 1 2],3), fill([4 4; 3 5],4), fill([4 5; 4 5],3))) == Matrix(A) + I
end