Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Add Hessians for ScaledInterpolation and tests #269

Merged
merged 9 commits into from
Nov 24, 2018
Merged
19 changes: 19 additions & 0 deletions src/extrapolation/extrapolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ end
end
end

# Hessian of the extrapolation object `etp` at position `x`.
# When `x` is within the bounds of the wrapped interpolant (`parent(etp)`), the call is
# forwarded to that interpolant's `hessian`; for any other position an error is raised.
@inline function hessian(etp::AbstractExtrapolation{T,N}, x::Vararg{Number,N}) where {T,N}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this, thanks for adding it.

itp = parent(etp)
# In-bounds positions need no extrapolation: delegate to the parent interpolant.
if checkbounds(Bool, itp, x...)
hessian(itp, x...)
else
# Out-of-bounds positions are not supported; everything after this call is a
# disabled sketch that never runs.
error("extrapolation of hessian not yet implemented")
# # copied from gradient above, with obvious modifications
# # but final part is missing
# eflag = tcollect(etpflag, etp)
# xs = inbounds_position(eflag, bounds(itp), x, etp, x)
# h = @inbounds hessian(itp, xs...)
# skipni = t->skip_flagged_nointerp(itp, t)
# # not sure if it should be just h here instead of Tuple(h)
# # extrapolate_hessian needs to be written
# # SVector is likely also wrong here
# SVector(extrapolate_hessian.(skipni(eflag), skipni(x), skipni(xs), Tuple(h)))
end
end

# Bounds query for extrapolation objects: this method always answers `true`.
# NOTE(review): the first argument is annotated `::Bool`, which matches a Bool *value*
# (`true`/`false`). Call sites such as `checkbounds(Bool, itp, x...)` pass the type
# `Bool`, which would need a `::Type{Bool}` annotation to reach this method -- confirm
# which of the two was intended.
function checkbounds(::Bool, ::AbstractExtrapolation, I...)
    return true
end

# The last two arguments are just for error-reporting
Expand Down
22 changes: 22 additions & 0 deletions src/scaling/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,28 @@ rescale_gradient(r::UnitRange, g) = g
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects.
""" rescale_gradient

"""
    hessian(sitp::ScaledInterpolation, xs...)

Evaluate the Hessian of the scaled interpolation object `sitp` at the position `xs`.

The position is first checked against the bounds of `sitp` (a bounds error is thrown
when the check fails, unless bounds checking has been elided by the caller). It is then
translated with `coordslookup` and `maybe_clamp` into a position for the wrapped
interpolant `sitp.itp`, whose Hessian is evaluated there and finally passed through
`rescale_hessian_components` together with the interpolation flags and `sitp.ranges`.
"""
@propagate_inbounds function hessian(
    sitp::ScaledInterpolation{T,N}, xs::Vararg{Number,N}
) where {T,N}
    @boundscheck begin
        checkbounds(Bool, sitp, xs...) || Base.throw_boundserror(sitp, xs)
    end
    # Translate the user-facing coordinates into coordinates of the wrapped interpolant.
    looked_up = coordslookup(itpflag(sitp.itp), sitp.ranges, xs)
    parent_pos = maybe_clamp(sitp.itp, looked_up)
    unscaled = hessian(sitp.itp, parent_pos...)
    return rescale_hessian_components(itpflag(sitp.itp), sitp.ranges, unscaled)
end

# Rescale the Hessian `h` of the wrapped interpolant using the step sizes of `ranges`.
# `get_steps(flags, ranges)` supplies one step per dimension whose flag is not `NoInterp`;
# entry (i, j) of `h` is divided by `steps[i] * steps[j]`.
# Presumably this is the second-derivative counterpart of `rescale_gradient` -- confirm.
function rescale_hessian_components(flags, ranges, h)
steps = SVector(get_steps(flags, ranges))
# `steps .* steps'` is the outer product of the step vector with itself.
# NOTE(review): the review thread on this line reports that this broadcast was not
# inferrable for a 1x1 `SMatrix` (attributed there to StaticArrays) -- re-check with
# `@inferred` on the versions in use.
return h ./ (steps .* steps')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out this is the source of non-inferrability, and it's really the fault of StaticArrays. You can replicate the problem with

using StaticArrays, Test
h = SMatrix{1,1}([1.0])
steps = SVector(1.0)
testinf(h, steps) = h ./ (steps .* steps')

julia> testinf(h, steps)
1×1 SArray{Tuple{1,1},Float64,2,1}:
 1.0

julia> @inferred testinf(h, steps)
ERROR: return type SArray{Tuple{1,1},Float64,2,1} does not match inferred return type Any
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at none:0

I can look into fixing this.

(This also implies that NoInterp is a red herring.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

"""
    get_steps(flags, ranges)

Collect, as a tuple, `step(r)` for every range `r` in `ranges` whose corresponding
interpolation flag is not a `NoInterp`. Flags and ranges are consumed in lockstep
(`getfirst`/`getrest` on the flags, first element/`Base.tail` on the ranges); the
recursion ends with an empty tuple once `ranges` is exhausted.
"""
function get_steps(flags, ranges)
    if !(getfirst(flags) isa NoInterp)
        # Interpolated dimension: its step comes first, followed by the remaining steps.
        head_step = step(ranges[1])
        return (head_step, get_steps(getrest(flags), Base.tail(ranges))...)
    end
    # Non-interpolated dimension: contributes nothing, continue with the remainder.
    return get_steps(getrest(flags), Base.tail(ranges))
end
# Base case: no ranges left, so there are no steps to report.
get_steps(flags, ::Tuple{}) = ()

### Iteration

struct ScaledIterator{SITPT,CI,WIS}
Expand Down
35 changes: 33 additions & 2 deletions test/scaling/nointerp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Interpolations, Test, LinearAlgebra, Random
xs = -pi:2pi/10:pi
f1(x) = sin(x)
f2(x) = cos(x)
f3(x) = sin(x) .* cos(x)
f3(x) = sin(x) * cos(x)
f(x,y) = y == 1 ? f1(x) : (y == 2 ? f2(x) : (y == 3 ? f3(x) : error("invalid value for y (must be 1, 2 or 3, you used $y)")))
ys = 1:3

Expand All @@ -15,11 +15,41 @@ using Interpolations, Test, LinearAlgebra, Random

for (ix,x0) in enumerate(xs[1:end-1]), y0 in ys
x,y = x0, y0
@test ≈(sitp(x,y),f(x,y),atol=0.05)
@test sitp(x,y) ≈ f(x,y) atol=0.05
end

@test length(Interpolations.gradient(sitp, pi/3, 2)) == 1

xs = range(-pi, stop=pi, length=60)[1:end-1]
A = hcat(map(f1, xs), map(f2, xs), map(f3, xs))
itp = interpolate(A, (BSpline(Cubic(Periodic(OnGrid()))), NoInterp()))
sitp = scale(itp, xs, ys)

for x in xs, y in ys
if y in (1,2)
h = @inferred(Interpolations.hessian(sitp, x, y))
@test h[1] ≈ -f(x, y) atol=0.01
else # y==3
h = @inferred(Interpolations.hessian(sitp, x, y))
@test h[1] ≈ -4*f(x, y) atol=0.01
end
end

@test length(Interpolations.hessian(sitp, pi/3, 2)) == 1

A = hcat(map(f1, xs), map(f2, xs), map(f3, xs))
itp = interpolate(A', (NoInterp(), BSpline(Cubic(Periodic(OnGrid())))))
sitp = scale(itp, ys, xs)
h(y, x) = Interpolations.hessian(sitp, y, x)
h(1, xs[10])
for x in xs[6:end-6], y in ys
if y in (1,2)
@test h(y, x)[1] ≈ -f(x, y) atol=0.05
elseif y==3
@test h(y, x)[1] ≈ -4*f(x, y) atol=0.05
end
end

# check for case where initial/middle indices are NoInterp but later ones are <:BSpline
isdefined(Random, :seed!) ? Random.seed!(1234) : srand(1234) # `srand` was renamed to `seed!`
z0 = rand(10,10)
Expand All @@ -34,4 +64,5 @@ using Interpolations, Test, LinearAlgebra, Random
sitpb = scale(itpb, 1:10, rng)
@test Interpolations.gradient(sitpa, 3.0, 3) == Interpolations.gradient(sitpb, 3, 3.0)


end
18 changes: 16 additions & 2 deletions test/scaling/scaling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Interpolations
using Test, LinearAlgebra
using Test, LinearAlgebra, StaticArrays

@testset "Scaling" begin
# Model linear interpolation of y = -3 + .5x by interpolating y=x
Expand Down Expand Up @@ -41,7 +41,21 @@ using Test, LinearAlgebra

for x in -pi:.1:pi
g = @inferred(Interpolations.gradient(sitp, x))[1]
@test ≈(cos(x),g,atol=0.05)
@test cos(x) ≈ g atol=0.05
end

# Test Hessians of scaled grids
xs = -pi:.1:pi
ys = -pi:.2:pi
zs = sin.(xs) .* sin.(ys')
itp = interpolate(zs, BSpline(Cubic(Line(OnGrid()))))
sitp = @inferred scale(itp, xs, ys)

for x in xs[2:end-1], y in ys[2:end-1]
h = @inferred(Interpolations.hessian(sitp, x, y))
@test issymmetric(h)
@test [-sin(x) * sin(y) cos(x) * cos(y)
cos(x) * cos(y) -sin(x) * sin(y)] ≈ h atol=0.03
end

# Verify that return types are reasonable
Expand Down