Skip to content

Commit

Permalink
[Nonlinear] add support for univariate sign (#2444)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Feb 27, 2024
1 parent 41dce2e commit 4be42b4
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/FileFormats/NL/NLExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ const _UNARY_SPECIAL_CASES = Dict{Symbol,Function}(
:asech => (x) -> :(acosh(1 / $x)),
:acsch => (x) -> :(asinh(1 / $x)),
:acoth => (x) -> :(atanh(1 / $x)),
:sign => (x) -> :(ifelse($x >= 0, 1, -1)),
)

"""
Expand Down
4 changes: 2 additions & 2 deletions src/Nonlinear/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ The list of univariate operators that are supported by default.
julia> import MathOptInterface as MOI
julia> MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS
72-element Vector{Symbol}:
73-element Vector{Symbol}:
:+
:-
:abs
:sign
:sqrt
:cbrt
:abs2
:inv
:log
:log10
:log2
:airybi
:airyaiprime
Expand Down
1 change: 1 addition & 0 deletions src/Nonlinear/univariate_expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const SYMBOLIC_UNIVARIATE_EXPRESSIONS = Tuple{Symbol,Expr,Any}[
(:+, :(one(x)), :(zero(x))),
(:-, :(-one(x)), :(zero(x))),
(:abs, :(ifelse(x >= 0, one(x), -one(x))), :(zero(x))),
(:sign, :(zero(x)), :(zero(x))),
(:sqrt, :(0.5 / sqrt(x)), :((0.5 * -(0.5 / sqrt(x))) / sqrt(x) ^ 2)),
(:cbrt, :(0.3333333333333333 / cbrt(x) ^ 2), :((0.3333333333333333 * -(2 * (0.3333333333333333 / cbrt(x) ^ 2) * cbrt(x))) / (cbrt(x) ^ 2) ^ 2)),
(:abs2, :(2x), :((typeof(x))(2))),
Expand Down
3 changes: 2 additions & 1 deletion src/Nonlinear/univariate_expressions_generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ open("univariate_expressions.jl", "w") do io
const SYMBOLIC_UNIVARIATE_EXPRESSIONS = Tuple{Symbol,Expr,Any}[
(:+, :(one(x)), :(zero(x))),
(:-, :(-one(x)), :(zero(x))),
(:abs, :(ifelse(x >= 0, one(x), -one(x))), :(zero(x))),""",
(:abs, :(ifelse(x >= 0, one(x), -one(x))), :(zero(x))),
(:sign, :(zero(x)), :(zero(x))),""",
)
for (op, deriv) in Calculus.symbolic_derivatives_1arg()
f = Expr(:call, op, :x)
Expand Down
8 changes: 8 additions & 0 deletions test/FileFormats/NL/NL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ function test_nlexpr_scalarnonlinearfunction_unary_special_case()
return
end

function test_nlexpr_scalarnonlinearfunction_unary_special_case_sign()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:sign, Any[x])
expr = NL._NLExpr(:(ifelse($x >= 0, 1, -1)))
_test_nlexpr(f, expr.nonlinear_terms, Dict(x => 0.0), 0.0)
return
end

function test_nlexpr_scalarnonlinearfunction_binary_special_case()
x = MOI.VariableIndex(1)
f = MOI.ScalarNonlinearFunction(:\, Any[x, 1])
Expand Down
26 changes: 26 additions & 0 deletions test/Nonlinear/Nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,32 @@ function test_automatic_differentiation_backend()
return
end

function test_univariate_sign()
f(y, p) = sign(y) * abs(y)^p
∇f(y, p) = p * abs(y)^(p - 1)
∇²f(y, p) = sign(y) * p * (p - 1) * abs(y)^(p - 2)
for p in (-0.5, 0.5, 2.0)
x = MOI.VariableIndex(1)
model = MOI.Nonlinear.Model()
MOI.Nonlinear.set_objective(model, :(sign($x) * abs($x)^$p))
evaluator = MOI.Nonlinear.Evaluator(
model,
MOI.Nonlinear.SparseReverseMode(),
[x],
)
MOI.initialize(evaluator, [:Grad, :Hess])
for y in (-10.0, -1.2, 1.2, 10.0)
@test MOI.eval_objective(evaluator, [y]) f(y, p)
g = [NaN]
MOI.eval_objective_gradient(evaluator, g, [y])
@test g[1] ∇f(y, p)
H = zeros(length(MOI.hessian_objective_structure(evaluator)))
MOI.eval_hessian_objective(evaluator, H, [y])
@test H[1] ∇²f(y, p)
end
end
end

end # TestNonlinear

TestNonlinear.runtests()
Expand Down

0 comments on commit 4be42b4

Please sign in to comment.