Skip to content

Commit

Permalink
Merge pull request #480 from N5N3/fix
Browse files Browse the repository at this point in the history
Fix inference on `symmatrix` (patch for #472)
  • Loading branch information
mkitti authored Feb 25, 2022
2 parents 8a89ddf + 7fd7fbc commit 0fe41bf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
13 changes: 9 additions & 4 deletions src/b-splines/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,8 @@ symmatrix(h::NTuple{1,Any}) = SMatrix{1,1}(h)
symmatrix(h::NTuple{3,Any}) = SMatrix{2,2}((h[1], h[2], h[2], h[3]))
symmatrix(h::NTuple{6,Any}) = SMatrix{3,3}((h[1], h[2], h[3], h[2], h[4], h[5], h[3], h[5], h[6]))
function symmatrix(h::NTuple{L,Any}) where L
@noinline incommensurate(N,L) = error("$L must be equal to N*(N+1)/2 (N = $N)")
N = floor(Int, (2L)^(1//2))
(N*(N+1))÷2 == L || incommensurate(N,L)
l = Matrix{Int}(undef, N, N)
N = symsize(Val(L))
l = MMatrix{N,N,Int}(undef)
l[:,1] = 1:N
idx = N
for j = 2:N, i = 1:N
Expand All @@ -221,3 +219,10 @@ function symmatrix(h::NTuple{L,Any}) where L
SMatrix{N,N}(h[i] for i in l)
end
end

# Use @generated to force const propagation
@generated function symsize(::Val{L}) where L
N = floor(Int, sqrt(2L))
(N*(N+1))÷2 == L || error("$L must be equal to N*(N+1)/2 (N = $N)")
return :($N)
end
16 changes: 11 additions & 5 deletions test/issues/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,16 @@ using Interpolations, Test, ForwardDiff
@test ForwardDiff.gradient(itp_test, [3.0, 2.0]) Interpolations.gradient(itp, 3.0, 2.0)
end
@testset "issue 469" begin
dims = VERSION > v"1.7" ? 7 : 3 # symmatrix is unstable before 1.7, (^ was llvm based)
A = zeros(Float64, ntuple(_ -> 5, dims))
itp = interpolate(A, BSpline(Quadratic(Reflect(OnCell()))))
@test (@inferred Interpolations.hessian(itp, ntuple(_ -> 1.0, dims)...)) == zeros(dims,dims)
# We have different inference result on different version.
max_dim = VERSION < v"1.3" ? 3 : isdefined(Base, :Any32) ? 7 : 5
for dims in 3:max_dim
A = zeros(Float64, ntuple(_ -> 5, dims))
itp = interpolate(A, BSpline(Quadratic(Reflect(OnCell()))))
@test (@inferred Interpolations.hessian(itp, ntuple(_ -> 1.0, dims)...)) == zeros(dims,dims)
end
@test Interpolations.symsize(Val(36)) == 8
@test Interpolations.symsize(Val(45)) == 9
@test_throws ErrorException Interpolations.symsize(Val(2))
@test_throws ErrorException Interpolations.symsize(Val(33))
end

end

0 comments on commit 0fe41bf

Please sign in to comment.