Skip to content

Commit

Permalink
Merge branch 'master' into dl/planbysize
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Nov 15, 2023
2 parents 77efdb0 + e1db1ca commit ca68b5c
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 155 deletions.
84 changes: 15 additions & 69 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,76 +55,22 @@ end
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))


for op in (:ldiv, :lmul)
op_dim_begin! = Symbol(string(op) * "_dim_begin!")
op_dim_end! = Symbol(string(op) * "_dim_end!")
op! = Symbol(string(op) * "!")
@eval begin
function $op_dim_begin!(α, d::Number, y::AbstractArray{<:Any,N}) where N
# scale just the d-th dimension by permuting it to the first
= PermutedDimsArray(y, _permfirst(d, N))
$op!(α, view(ỹ, 1, ntuple(_ -> :, Val(N-1))...))
end

ldiv_dim_begin!(α, d::Number, y::AbstractVector) = y[1] /= α
function ldiv_dim_begin!(α, d::Number, y::AbstractMatrix)
if isone(d)
ldiv!(α, @view(y[1,:]))
else
ldiv!(α, @view(y[:,1]))
end
end
function ldiv_dim_begin!(α, d::Number, y::AbstractArray{<:Any,3})
if isone(d)
ldiv!(α, @view(y[1,:,:]))
elseif d == 2
ldiv!(α, @view(y[:,1,:]))
else # d == 3
ldiv!(α, @view(y[:,:,1]))
end
end

ldiv_dim_end!(α, d::Number, y::AbstractVector) = y[end] /= α
function ldiv_dim_end!(α, d::Number, y::AbstractMatrix)
if isone(d)
ldiv!(α, @view(y[end,:]))
else
ldiv!(α, @view(y[:,end]))
end
end
function ldiv_dim_end!(α, d::Number, y::AbstractArray{<:Any,3})
if isone(d)
ldiv!(α, @view(y[end,:,:]))
elseif d == 2
ldiv!(α, @view(y[:,end,:]))
else # d == 3
ldiv!(α, @view(y[:,:,end]))
end
end

lmul_dim_begin!(α, d::Number, y::AbstractVector) = y[1] *= α
function lmul_dim_begin!(α, d::Number, y::AbstractMatrix)
if isone(d)
lmul!(α, @view(y[1,:]))
else
lmul!(α, @view(y[:,1]))
end
end
function lmul_dim_begin!(α, d::Number, y::AbstractArray{<:Any,3})
if isone(d)
lmul!(α, @view(y[1,:,:]))
elseif d == 2
lmul!(α, @view(y[:,1,:]))
else # d == 3
lmul!(α, @view(y[:,:,1]))
end
end

lmul_dim_end!(α, d::Number, y::AbstractVector) = y[end] *= α
function lmul_dim_end!(α, d::Number, y::AbstractMatrix)
if isone(d)
lmul!(α, @view(y[end,:]))
else
lmul!(α, @view(y[:,end]))
end
end
function lmul_dim_end!(α, d::Number, y::AbstractArray{<:Any,3})
if isone(d)
lmul!(α, @view(y[end,:,:]))
elseif d == 2
lmul!(α, @view(y[:,end,:]))
else # d == 3
lmul!(α, @view(y[:,:,end]))
function $op_dim_end!(α, d::Number, y::AbstractArray{<:Any,N}) where N
# scale just the d-th dimension by permuting it to the first
= PermutedDimsArray(y, _permfirst(d, N))
$op!(α, view(ỹ, size(ỹ,1), ntuple(_ -> :, Val(N-1))...))
end
end
end

Expand Down
183 changes: 97 additions & 86 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,96 +322,107 @@ using FastTransforms, Test
end

@testset "tensor" begin
X = randn(4,5,6)
= similar(X)
@testset "chebyshevtransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j]) end
@test @inferred(chebyshevtransform(X,1)) @inferred(chebyshevtransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j]) end
@test chebyshevtransform(X,2) chebyshevtransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:]) end
@test chebyshevtransform(X,3) chebyshevtransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j],Val(2)) end
@test @inferred(chebyshevtransform(X,Val(2),1)) @inferred(chebyshevtransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j],Val(2)) end
@test chebyshevtransform(X,Val(2),2) chebyshevtransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:],Val(2)) end
@test chebyshevtransform(X,Val(2),3) chebyshevtransform!(copy(X),Val(2),3)

@test @inferred(chebyshevtransform(X)) @inferred(chebyshevtransform!(copy(X))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,1),2),3)
@test @inferred(chebyshevtransform(X,Val(2))) @inferred(chebyshevtransform!(copy(X),Val(2))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)
end

@testset "ichebyshevtransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j]) end
@test @inferred(ichebyshevtransform(X,1)) @inferred(ichebyshevtransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j]) end
@test ichebyshevtransform(X,2) ichebyshevtransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:]) end
@test ichebyshevtransform(X,3) ichebyshevtransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j],Val(2)) end
@test @inferred(ichebyshevtransform(X,Val(2),1)) @inferred(ichebyshevtransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j],Val(2)) end
@test ichebyshevtransform(X,Val(2),2) ichebyshevtransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:],Val(2)) end
@test ichebyshevtransform(X,Val(2),3) ichebyshevtransform!(copy(X),Val(2),3)

@test @inferred(ichebyshevtransform(X)) @inferred(ichebyshevtransform!(copy(X))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,1),2),3)
@test @inferred(ichebyshevtransform(X,Val(2))) @inferred(ichebyshevtransform!(copy(X),Val(2))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)

@test ichebyshevtransform(chebyshevtransform(X)) X
@test chebyshevtransform(ichebyshevtransform(X)) X
end

@testset "chebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j]) end
@test @inferred(chebyshevutransform(X,1)) @inferred(chebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j]) end
@test chebyshevutransform(X,2) chebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:]) end
@test chebyshevutransform(X,3) chebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(chebyshevutransform(X,Val(2),1)) @inferred(chebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j],Val(2)) end
@test chebyshevutransform(X,Val(2),2) chebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:],Val(2)) end
@test chebyshevutransform(X,Val(2),3) chebyshevutransform!(copy(X),Val(2),3)

@test @inferred(chebyshevutransform(X)) @inferred(chebyshevutransform!(copy(X))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,1),2),3)
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)
@testset "3D" begin
X = randn(4,5,6)
= similar(X)
@testset "chebyshevtransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j]) end
@test @inferred(chebyshevtransform(X,1)) @inferred(chebyshevtransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j]) end
@test chebyshevtransform(X,2) chebyshevtransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:]) end
@test chebyshevtransform(X,3) chebyshevtransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevtransform(X[:,k,j],Val(2)) end
@test @inferred(chebyshevtransform(X,Val(2),1)) @inferred(chebyshevtransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevtransform(X[k,:,j],Val(2)) end
@test chebyshevtransform(X,Val(2),2) chebyshevtransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevtransform(X[k,j,:],Val(2)) end
@test chebyshevtransform(X,Val(2),3) chebyshevtransform!(copy(X),Val(2),3)

@test @inferred(chebyshevtransform(X)) @inferred(chebyshevtransform!(copy(X))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,1),2),3)
@test @inferred(chebyshevtransform(X,Val(2))) @inferred(chebyshevtransform!(copy(X),Val(2))) chebyshevtransform(chebyshevtransform(chebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)
end

@testset "ichebyshevtransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j]) end
@test @inferred(ichebyshevtransform(X,1)) @inferred(ichebyshevtransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j]) end
@test ichebyshevtransform(X,2) ichebyshevtransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:]) end
@test ichebyshevtransform(X,3) ichebyshevtransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevtransform(X[:,k,j],Val(2)) end
@test @inferred(ichebyshevtransform(X,Val(2),1)) @inferred(ichebyshevtransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevtransform(X[k,:,j],Val(2)) end
@test ichebyshevtransform(X,Val(2),2) ichebyshevtransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevtransform(X[k,j,:],Val(2)) end
@test ichebyshevtransform(X,Val(2),3) ichebyshevtransform!(copy(X),Val(2),3)

@test @inferred(ichebyshevtransform(X)) @inferred(ichebyshevtransform!(copy(X))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,1),2),3)
@test @inferred(ichebyshevtransform(X,Val(2))) @inferred(ichebyshevtransform!(copy(X),Val(2))) ichebyshevtransform(ichebyshevtransform(ichebyshevtransform(X,Val(2),1),Val(2),2),Val(2),3)

@test ichebyshevtransform(chebyshevtransform(X)) X
@test chebyshevtransform(ichebyshevtransform(X)) X
end

@testset "chebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j]) end
@test @inferred(chebyshevutransform(X,1)) @inferred(chebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j]) end
@test chebyshevutransform(X,2) chebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:]) end
@test chebyshevutransform(X,3) chebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(chebyshevutransform(X,Val(2),1)) @inferred(chebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j],Val(2)) end
@test chebyshevutransform(X,Val(2),2) chebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:],Val(2)) end
@test chebyshevutransform(X,Val(2),3) chebyshevutransform!(copy(X),Val(2),3)

@test @inferred(chebyshevutransform(X)) @inferred(chebyshevutransform!(copy(X))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,1),2),3)
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)
end

@testset "ichebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j]) end
@test @inferred(ichebyshevutransform(X,1)) @inferred(ichebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j]) end
@test ichebyshevutransform(X,2) ichebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:]) end
@test ichebyshevutransform(X,3) ichebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(ichebyshevutransform(X,Val(2),1)) @inferred(ichebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j],Val(2)) end
@test ichebyshevutransform(X,Val(2),2) ichebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:],Val(2)) end
@test ichebyshevutransform(X,Val(2),3) ichebyshevutransform!(copy(X),Val(2),3)

@test @inferred(ichebyshevutransform(X)) @inferred(ichebyshevutransform!(copy(X))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,1),2),3)
@test @inferred(ichebyshevutransform(X,Val(2))) @inferred(ichebyshevutransform!(copy(X),Val(2))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)

@test ichebyshevutransform(chebyshevutransform(X)) X
@test chebyshevutransform(ichebyshevutransform(X)) X
end

X = randn(1,1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
end

@testset "ichebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j]) end
@test @inferred(ichebyshevutransform(X,1)) @inferred(ichebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j]) end
@test ichebyshevutransform(X,2) ichebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:]) end
@test ichebyshevutransform(X,3) ichebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(ichebyshevutransform(X,Val(2),1)) @inferred(ichebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j],Val(2)) end
@test ichebyshevutransform(X,Val(2),2) ichebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:],Val(2)) end
@test ichebyshevutransform(X,Val(2),3) ichebyshevutransform!(copy(X),Val(2),3)

@test @inferred(ichebyshevutransform(X)) @inferred(ichebyshevutransform!(copy(X))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,1),2),3)
@test @inferred(ichebyshevutransform(X,Val(2))) @inferred(ichebyshevutransform!(copy(X),Val(2))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)

@test ichebyshevutransform(chebyshevutransform(X)) X
@test chebyshevutransform(ichebyshevutransform(X)) X
@testset "4D" begin
X = randn(2,3,4,5)
= similar(X)
for trans in (chebyshevtransform, ichebyshevtransform, chebyshevutransform, ichebyshevutransform)
for k = axes(X,2), j = axes(X,3), l = axes(X,4) X̃[:,k,j,l] = trans(X[:,k,j,l]) end
@test @inferred(trans(X,1))
@test @inferred(trans(X)) trans(trans(trans(trans(X,1),2),3),4)
end
end

X = randn(1,1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
end

@testset "Integer" begin
@test chebyshevtransform([1,2,3]) == chebyshevtransform([1.,2,3])
@test chebyshevtransform([1,2,3], Val(2)) == chebyshevtransform([1.,2,3], Val(2))
Expand Down

0 comments on commit ca68b5c

Please sign in to comment.