-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Add Hessians for ScaledInterpolation and tests #269
Changes from all commits
3c69f88
8fb17fc
74ac76d
1fb7907
35cef09
1e33cb2
09fa5c5
42f8008
877e59f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,6 +126,28 @@ rescale_gradient(r::UnitRange, g) = g | |
Implements the chain rule dy/dx = dy/du * du/dx for use when calculating gradients with scaled interpolation objects. | ||
""" rescale_gradient | ||
|
||
@propagate_inbounds function hessian(sitp::ScaledInterpolation{T,N}, xs::Vararg{Number,N}) where {T,N} | ||
@boundscheck (checkbounds(Bool, sitp, xs...) || Base.throw_boundserror(sitp, xs)) | ||
xl = maybe_clamp(sitp.itp, coordslookup(itpflag(sitp.itp), sitp.ranges, xs)) | ||
h = hessian(sitp.itp, xl...) | ||
return rescale_hessian_components(itpflag(sitp.itp), sitp.ranges, h) | ||
end | ||
|
||
function rescale_hessian_components(flags, ranges, h) | ||
steps = SVector(get_steps(flags, ranges)) | ||
return h ./ (steps .* steps') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It turns out this is the source of non-inferrability, and it's really the fault of StaticArrays. You can replicate the problem with using StaticArrays, Test
h = SMatrix{1,1}([1.0])
steps = SVector(1.0)
testinf(h, steps) = h ./ (steps .* steps')
julia> testinf(h, steps)
1×1 SArray{Tuple{1,1},Float64,2,1}:
1.0
julia> @inferred testinf(h, steps)
ERROR: return type SArray{Tuple{1,1},Float64,2,1} does not match inferred return type Any
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] top-level scope at none:0 I can look into fixing this. (This also implies that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
end | ||
|
||
function get_steps(flags, ranges) | ||
if getfirst(flags) isa NoInterp | ||
return get_steps(getrest(flags), Base.tail(ranges)) | ||
else | ||
item = step(ranges[1]) | ||
return (item, get_steps(getrest(flags), Base.tail(ranges))...) | ||
end | ||
end | ||
get_steps(flags, ::Tuple{}) = () | ||
|
||
### Iteration | ||
|
||
struct ScaledIterator{SITPT,CI,WIS} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this, thanks for adding it.