Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ogauthe committed Sep 17, 2024
1 parent dc762be commit 32d0f11
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using NDTensors.GradedAxes:
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
using LinearAlgebra: adjoint
using Random: randn!
function blockdiagonal!(f, a::AbstractArray)
for i in 1:minimum(blocksize(a))
Expand Down Expand Up @@ -38,8 +39,6 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
@test nstored(b) == 32
@test block_nstored(b) == 2
# TODO: Have to investigate why this fails
# on Julia v1.6, or drop support for v1.6.
for i in 1:ndims(a)
@test axes(b, i) isa GradedOneTo
end
Expand Down Expand Up @@ -158,11 +157,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for ax in axes(b)
@test ax isa GradedUnitRangeDual
@test a[:, :] isa BlockSparseArray # broken in 1.6
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
end

I = [Block(1)[1:1]]
@test_broken a[I, :]
@test_broken a[:, I]
Expand All @@ -179,9 +178,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for ax in axes(b)
@test ax isa GradedUnitRangeDual
@test a[:, :] isa BlockSparseArray # broken in 1.6
for i in 1:2
@test axes(b, i) isa GradedUnitRangeDual
@test_broken axes(a[:, :], i) isa GradedUnitRangeDual
end

I = [Block(1)[1:1]]
Expand All @@ -191,7 +191,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test_broken GradedAxes.isdual(axes(a[I, I], 1))
end

@testset "BlockedUnitRange" begin
@testset "BlockedUnitRange" begin # self dual
r = blockedrange([2, 2])
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
Expand All @@ -201,8 +201,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
@test a[:, :] isa BlockSparseArray
for ax in axes(b)
@test ax isa BlockedOneTo
for i in 1:2
@test axes(b, i) isa BlockedOneTo
@test axes(a[:, :], i) isa BlockedOneTo
end

I = [Block(1)[1:1]]
Expand All @@ -226,7 +227,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test ax isa typeof(dual(r))
end

@test a[:, :] isa BlockSparseArray
@test a[:, :] isa BlockSparseArray # broken in 1.6
@test axes(a[:, :]) isa Tuple{BlockedOneTo,BlockedOneTo} # broken in 1.6

I = [Block(1)[1:1]]
@test size(a[I, :]) == (1, 4)
Expand Down
10 changes: 7 additions & 3 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@eval module $(gensym())
using BlockArrays:
Block,
BlockedOneTo,
blockaxes,
blockedrange,
blockfirsts,
Expand Down Expand Up @@ -32,19 +33,22 @@ Base.isless(c1::U1, c2::U1) = c1.n < c2.n

@testset "AbstractUnitRange" begin
a0 = OneToOne()
@test gradedisequal(a0, dual(a0))
@test !isdual(a0)
@test dual(a0) isa OneToOne
@test gradedisequal(a0, dual(a0))

a = 1:3
ad = dual(a)
@test !isdual(ad)
@test !isdual(a)
@test !isdual(ad)
@test ad isa UnitRange
@test gradedisequal(ad, a)

a = blockedrange([2, 3])
ad = dual(a)
@test !isdual(ad)
@test !isdual(a)
@test !isdual(ad)
@test ad isa BlockedOneTo
@test gradedisequal(ad, a)
end

Expand Down

0 comments on commit 32d0f11

Please sign in to comment.