ForwardDiff with min, max, and clamp #640

Open
smkatz12 opened this issue Apr 30, 2024 · 3 comments

Comments

@smkatz12

Hi! We are working on a textbook chapter on reachability analysis and are interested in support for intervals when using ForwardDiff.jl to compute gradients and Hessians of functions that use min, max, and clamp. The current behavior with these functions does not cause an error but is incorrect. We have written the functions below to fix this. These functions could be added to IntervalArithmeticForwardDiffExt.jl.

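using ForwardDiff: Dual, value, partials
using IntervalArithmetic            # Interval, inf, sup
using IntervalArithmetic.Symbols    # `..` lives here in IntervalArithmetic v0.22

# max(x, y) for a first-order Dual whose value is an interval
function Base.max(x::Dual{T,V,N}, y::AbstractFloat) where {T,V<:Interval,N}
    if sup(value(x)) < y
        # x < y on the whole interval: constant y, zero derivative
        return Dual{T,V,N}(y..y, (0..0) * partials(x))
    elseif inf(value(x)) > y
        # x > y on the whole interval: pass x through unchanged
        return Dual{T,V,N}(value(x), (1..1) * partials(x))
    else
        # the kink x == y lies inside the interval: slope factor is in [0, 1]
        return Dual{T,V,N}(y..sup(value(x)), (0..1) * partials(x))
    end
end

# max(x, y) for a nested Dual, as used by ForwardDiff.hessian
function Base.max(x::Dual{T,Dual{T2,V2,N2},N}, y::AbstractFloat) where {T,T2,V2<:Interval,N2,N}
    if sup(value(value(x))) < y
        return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(y..y), (0..0) * partials(x))
    elseif inf(value(value(x))) > y
        return Dual{T,Dual{T2,V2,N2},N}(value(x), (1..1) * partials(x))
    else
        # scale the inner partials by [0, 1] too, consistent with the first-order method
        return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(y..sup(value(value(x))), (0..1) * partials(value(x))), (0..1) * partials(x))
    end
end

# min(x, y) for a first-order Dual whose value is an interval
function Base.min(x::Dual{T,V,N}, y::AbstractFloat) where {T,V<:Interval,N}
    if inf(value(x)) > y
        # x > y on the whole interval: constant y, zero derivative
        return Dual{T,V,N}(y..y, (0..0) * partials(x))
    elseif sup(value(x)) < y
        # x < y on the whole interval: pass x through unchanged
        return Dual{T,V,N}(value(x), (1..1) * partials(x))
    else
        # the kink x == y lies inside the interval: slope factor is in [0, 1]
        return Dual{T,V,N}(inf(value(x))..y, (0..1) * partials(x))
    end
end

# min(x, y) for a nested Dual, as used by ForwardDiff.hessian
function Base.min(x::Dual{T,Dual{T2,V2,N2},N}, y::AbstractFloat) where {T,T2,V2<:Interval,N2,N}
    if inf(value(value(x))) > y
        return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(y..y), (0..0) * partials(x))
    elseif sup(value(value(x))) < y
        return Dual{T,Dual{T2,V2,N2},N}(value(x), (1..1) * partials(x))
    else
        # scale the inner partials by [0, 1] too, consistent with the first-order method
        return Dual{T,Dual{T2,V2,N2},N}(Dual{T2,V2,N2}(inf(value(value(x)))..y, (0..1) * partials(value(x))), (0..1) * partials(x))
    end
end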

function Base.clamp(i::Dual{T,V,N}, lo::AbstractFloat, hi::AbstractFloat) where {T,V<:Interval,N}
    return min(max(i, lo), hi)
end

function Base.clamp(i::Dual{T,Dual{T2,V2,N2},N}, lo::AbstractFloat, hi::AbstractFloat) where {T,T2,V2<:Interval,N2,N}
    return min(max(i, lo), hi)
end
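The case analysis behind these methods can be checked without ForwardDiff at all. Below is a minimal, package-free sketch: `Iv`, `max_with_slope`, `min_with_slope` and `clamp_with_slope` are hypothetical names introduced only for illustration, `Iv` stands in for IntervalArithmetic's `Interval`, and outward rounding, decorations and the guarantee flag are all ignored. Each function returns an enclosure of the value together with the factor (0, 1, or [0, 1]) that multiplies the incoming partials.

```julia
# Hypothetical minimal closed interval [lo, hi]; a stand-in for IntervalArithmetic's Interval.
# No rounding control, decorations, or guarantee flag.
struct Iv
    lo::Float64
    hi::Float64
end

# Enclosures of max(x, c) and of its slope factor d/dx max(x, c), for x in X.
function max_with_slope(X::Iv, c::Float64)
    if X.hi < c
        return Iv(c, c), Iv(0.0, 0.0)      # always picks c
    elseif X.lo > c
        return X, Iv(1.0, 1.0)             # always picks x
    else
        return Iv(c, X.hi), Iv(0.0, 1.0)   # kink at x == c lies inside X
    end
end

# min(x, c) is the mirror image of max.
function min_with_slope(X::Iv, c::Float64)
    if X.lo > c
        return Iv(c, c), Iv(0.0, 0.0)
    elseif X.hi < c
        return X, Iv(1.0, 1.0)
    else
        return Iv(X.lo, c), Iv(0.0, 1.0)
    end
end

# clamp(x, lo, hi) = min(max(x, lo), hi); the chain rule multiplies the slope factors.
function clamp_with_slope(X::Iv, lo::Float64, hi::Float64)
    Y, s1 = max_with_slope(X, lo)
    Z, s2 = min_with_slope(Y, hi)
    # both factors lie in [0, 1], so their product is [s1.lo*s2.lo, s1.hi*s2.hi]
    return Z, Iv(s1.lo * s2.lo, s1.hi * s2.hi)
end
```

For the `x[2]` input of the example that follows, `clamp_with_slope(Iv(1.0, 2.0), 1.5, 2.5)` returns `(Iv(1.5, 2.0), Iv(0.0, 1.0))`: the lower bound 1.5 cuts through [1, 2], so the slope factor widens to [0, 1] while the upper bound 2.5 is never active.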

Example:

function f(x)
    return x[1]^2 + clamp(x[2], 1.5, 2.5)^3
end

ForwardDiff.hessian(f, [2..3, 1..2])

Current output:

2×2 Matrix{Interval{Float64}}:
 [2, 2]   [0, 0]
 [0, 0]  [6, 12]

Correct output using code above:

2×2 Matrix{Interval{Float64}}:
 [2, 2]   [0, 0]
 [0, 0]  [0, 12]
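A hand check of the bottom-right entry, assuming (as the methods above do) that clamp contributes a slope factor in [0, 1] and no second-derivative term from its kinks:

```latex
\begin{aligned}
g(x_2) &= c(x_2)^3, \qquad c(x_2) = \operatorname{clamp}(x_2, 1.5, 2.5) \\
g''(x_2) &= 6\,c\,(c')^2 + 3\,c^2\,c'' \\
x_2 \in [1, 2] &\;\Rightarrow\; c \in [1.5, 2],\quad c' \in [0, 1],\quad c'' = 0 \\
g'' &\in 6 \cdot [1.5, 2] \cdot [0, 1] = [0, 12]
\end{aligned}
```

Pointwise, g'' = 0 for x_2 < 1.5 and g'' = 6 x_2 ∈ (9, 12] for x_2 > 1.5, so [0, 12] encloses every value. The old output [6, 12] is what treating clamp as the identity gives (6 · [1, 2]), and it misses 0.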
@OlivierHnt
Member

Thx for opening an issue.

What version of IntervalArithmetic are you using? On 0.22.11 the example you gave returns an error, as it should, since no one has implemented the functions you describe.

That being said, it would be nice to improve support for ForwardDiff. Then we must also address the question of decorations.

@smkatz12
Author

smkatz12 commented May 2, 2024

Thanks for the quick reply! I was using an old version of IntervalArithmetic. I updated and confirmed that I now get an error when running that code. In terms of adding support for these functions, do you want us to submit a pull request? What are your recommendations for decorations?

@Kolaru
Collaborator

Kolaru commented May 3, 2024

A PR would be great :)

We'll review the PR in more detail, but there are a bunch of subtleties to take into account:

  • The decoration (as you mention): the decoration should be :com (the functions are continuous and well defined[1]).
  • The guarantee: when floating-point numbers and intervals are mixed, the guarantee should always be false (in particular, we need to be careful when creating an interval "by hand" from floating-point numbers).
  • Interval function: it seems we don't currently support clamp at all; it would be nice to add it.
  • Interval-interval inputs: it would be good to have the derivative for things like min(1..2, 0.5..3), taken with respect to either argument.
  • .. has been moved to a submodule (IntervalArithmetic.Symbols); internally, we currently prefer the interval constructor.
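The last two points (clamp as a plain interval function, and min of two intervals differentiated with respect to either argument) can be sketched as follows. This is a package-free illustration only: `Iv`, `min_with_partials` and `clamp_interval` are hypothetical names, `Iv` stands in for `Interval`, and decorations, the guarantee flag and rounding are deliberately left out, since those are exactly the subtleties listed above.

```julia
# Hypothetical minimal closed interval [lo, hi]; a stand-in for IntervalArithmetic's Interval.
struct Iv
    lo::Float64
    hi::Float64
end

# min(A, B) over all a in A, b in B, plus enclosures of ∂min/∂a and ∂min/∂b.
function min_with_partials(A::Iv, B::Iv)
    val = Iv(min(A.lo, B.lo), min(A.hi, B.hi))   # exact range of min over the box A × B
    if A.hi < B.lo                               # a < b everywhere: min is a
        return val, Iv(1.0, 1.0), Iv(0.0, 0.0)
    elseif B.hi < A.lo                           # b < a everywhere: min is b
        return val, Iv(0.0, 0.0), Iv(1.0, 1.0)
    else                                         # A and B overlap: either may be the min
        return val, Iv(0.0, 1.0), Iv(0.0, 1.0)
    end
end

# clamp is monotone in x, so clamping both endpoints gives the exact range.
clamp_interval(X::Iv, lo::Float64, hi::Float64) =
    Iv(clamp(X.lo, lo, hi), clamp(X.hi, lo, hi))
```

With the example from the list, `min_with_partials(Iv(1.0, 2.0), Iv(0.5, 3.0))` returns the value `Iv(0.5, 2.0)` and the factor `Iv(0.0, 1.0)` for both arguments, because the two intervals overlap.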

I hope this is not overwhelming, please let us know if you need help or advice about anything.
