Skip to content

Commit

Permalink
[RFC] Add Hessians for ScaledInterpolation and tests (#269)
Browse files Browse the repository at this point in the history
This also adds hessian support for in-bounds extrapolations. Supporting it for out-of-bounds points remains unimplemented.
  • Loading branch information
dkarrasch authored and timholy committed Nov 24, 2018
1 parent 3909cfb commit d8fbd11
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
19 changes: 19 additions & 0 deletions src/extrapolation/extrapolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ end
end
end

@inline function hessian(etp::AbstractExtrapolation{T,N}, x::Vararg{Number,N}) where {T,N}
itp = parent(etp)
if checkbounds(Bool, itp, x...)
hessian(itp, x...)
else
error("extrapolation of hessian not yet implemented")
# # copied from gradient above, with obvious modifications
# # but final part is missing
# eflag = tcollect(etpflag, etp)
# xs = inbounds_position(eflag, bounds(itp), x, etp, x)
# h = @inbounds hessian(itp, xs...)
# skipni = t->skip_flagged_nointerp(itp, t)
# # not sure if it should be just h here instead of Tuple(h)
# # extrapolate_hessian needs to be written
# # SVector is likely also wrong here
# SVector(extrapolate_hessian.(skipni(eflag), skipni(x), skipni(xs), Tuple(h)))
end
end

checkbounds(::Bool, ::AbstractExtrapolation, I...) = true

# The last two arguments are just for error-reporting
Expand Down
22 changes: 22 additions & 0 deletions src/scaling/scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,28 @@ rescale_gradient(r::UnitRange, g) = g
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects.
""" rescale_gradient

@propagate_inbounds function hessian(sitp::ScaledInterpolation{T,N}, xs::Vararg{Number,N}) where {T,N}
@boundscheck (checkbounds(Bool, sitp, xs...) || Base.throw_boundserror(sitp, xs))
xl = maybe_clamp(sitp.itp, coordslookup(itpflag(sitp.itp), sitp.ranges, xs))
h = hessian(sitp.itp, xl...)
return rescale_hessian_components(itpflag(sitp.itp), sitp.ranges, h)
end

function rescale_hessian_components(flags, ranges, h)
steps = SVector(get_steps(flags, ranges))
return h ./ (steps .* steps')
end

function get_steps(flags, ranges)
if getfirst(flags) isa NoInterp
return get_steps(getrest(flags), Base.tail(ranges))
else
item = step(ranges[1])
return (item, get_steps(getrest(flags), Base.tail(ranges))...)
end
end
get_steps(flags, ::Tuple{}) = ()

### Iteration

struct ScaledIterator{SITPT,CI,WIS}
Expand Down
35 changes: 33 additions & 2 deletions test/scaling/nointerp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Interpolations, Test, LinearAlgebra, Random
xs = -pi:2pi/10:pi
f1(x) = sin(x)
f2(x) = cos(x)
f3(x) = sin(x) .* cos(x)
f3(x) = sin(x) * cos(x)
f(x,y) = y == 1 ? f1(x) : (y == 2 ? f2(x) : (y == 3 ? f3(x) : error("invalid value for y (must be 1, 2 or 3, you used $y)")))
ys = 1:3

Expand All @@ -15,11 +15,41 @@ using Interpolations, Test, LinearAlgebra, Random

for (ix,x0) in enumerate(xs[1:end-1]), y0 in ys
x,y = x0, y0
@test (sitp(x,y),f(x,y),atol=0.05)
@test sitp(x,y) f(x,y) atol=0.05
end

@test length(Interpolations.gradient(sitp, pi/3, 2)) == 1

xs = range(-pi, stop=pi, length=60)[1:end-1]
A = hcat(map(f1, xs), map(f2, xs), map(f3, xs))
itp = interpolate(A, (BSpline(Cubic(Periodic(OnGrid()))), NoInterp()))
sitp = scale(itp, xs, ys)

for x in xs, y in ys
if y in (1,2)
h = @inferred(Interpolations.hessian(sitp, x, y))
@test h[1] -f(x, y) atol=0.01
else # y==3
h = @inferred(Interpolations.hessian(sitp, x, y))
@test h[1] -4*f(x, y) atol=0.01
end
end

@test length(Interpolations.hessian(sitp, pi/3, 2)) == 1

A = hcat(map(f1, xs), map(f2, xs), map(f3, xs))
itp = interpolate(A', (NoInterp(), BSpline(Cubic(Periodic(OnGrid())))))
sitp = scale(itp, ys, xs)
h(y, x) = Interpolations.hessian(sitp, y, x)
h(1, xs[10])
for x in xs[6:end-6], y in ys
if y in (1,2)
@test h(y, x)[1] -f(x, y) atol=0.05
elseif y==3
@test h(y, x)[1] -4*f(x, y) atol=0.05
end
end

# check for case where initial/middle indices are NoInterp but later ones are <:BSpline
isdefined(Random, :seed!) ? Random.seed!(1234) : srand(1234) # `srand` was renamed to `seed!`
z0 = rand(10,10)
Expand All @@ -34,4 +64,5 @@ using Interpolations, Test, LinearAlgebra, Random
sitpb = scale(itpb, 1:10, rng)
@test Interpolations.gradient(sitpa, 3.0, 3) == Interpolations.gradient(sitpb, 3, 3.0)


end
18 changes: 16 additions & 2 deletions test/scaling/scaling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Interpolations
using Test, LinearAlgebra
using Test, LinearAlgebra, StaticArrays

@testset "Scaling" begin
# Model linear interpolation of y = -3 + .5x by interpolating y=x
Expand Down Expand Up @@ -41,7 +41,21 @@ using Test, LinearAlgebra

for x in -pi:.1:pi
g = @inferred(Interpolations.gradient(sitp, x))[1]
@test (cos(x),g,atol=0.05)
@test cos(x) g atol=0.05
end

# Test Hessians of scaled grids
xs = -pi:.1:pi
ys = -pi:.2:pi
zs = sin.(xs) .* sin.(ys')
itp = interpolate(zs, BSpline(Cubic(Line(OnGrid()))))
sitp = @inferred scale(itp, xs, ys)

for x in xs[2:end-1], y in ys[2:end-1]
h = @inferred(Interpolations.hessian(sitp, x, y))
@test issymmetric(h)
@test [-sin(x) * sin(y) cos(x) * cos(y)
cos(x) * cos(y) -sin(x) * sin(y)] h atol=0.03
end

# Verify that return types are reasonable
Expand Down

0 comments on commit d8fbd11

Please sign in to comment.