From 6cdc9e6604694500fdca6f0fb9f3be758a936af7 Mon Sep 17 00:00:00 2001 From: longemen3000 Date: Wed, 30 Aug 2023 17:59:10 -0400 Subject: [PATCH] define ternary functions --- Project.toml | 2 +- src/ForwardDiffOverMeasurements.jl | 70 ++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index c6e422d..174300a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ForwardDiffOverMeasurements" uuid = "9eb8ae02-809e-4b16-afbc-1cadb820c769" authors = ["longemen3000 and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/src/ForwardDiffOverMeasurements.jl b/src/ForwardDiffOverMeasurements.jl index 367511c..99d8685 100644 --- a/src/ForwardDiffOverMeasurements.jl +++ b/src/ForwardDiffOverMeasurements.jl @@ -1,8 +1,10 @@ module ForwardDiffOverMeasurements -using ForwardDiff: Dual, DiffRules, NaNMath, LogExpFunctions, SpecialFunctions +using ForwardDiff: Dual, DiffRules, NaNMath, LogExpFunctions, SpecialFunctions,≺ using Measurements: Measurement -import Base: +,-,/,*,promote_rule +import Base: +,-,/,*,promote_rule +using ForwardDiff: AMBIGUOUS_TYPES, partials, values +using ForwardDiff: ForwardDiff function promote_rule(::Type{Measurement{V}}, ::Type{Dual{T, V, N}}) where {T,V,N} Dual{Measurement{T}, V, N} @@ -20,7 +22,7 @@ function overload_ambiguous_binary(M,f) ∂y = Dual{Tx}(y) $Mf(x,∂y) end - + @inline function $Mf(x::Measurement,y::Dual{Ty}) where {Ty} ∂x = Dual{Ty}(x) $Mf(∂x,y) @@ -28,6 +30,34 @@ function overload_ambiguous_binary(M,f) end end +macro define_ternary_dual_op2(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body) + FD = ForwardDiff + R = Measurement + defs = quote + @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body + @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body + @inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body + @inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body + end + for Q in AMBIGUOUS_TYPES + expr = quote + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body + @inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body + end + append!(defs.args, expr.args) + end + expr = quote + @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body + @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body + @inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body + end + append!(defs.args, expr.args) + return esc(defs) +end + #use DiffRules.jl rules for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) @@ -44,4 +74,38 @@ for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) end end +#ternary overloads +@define_ternary_dual_op2( + Base.hypot, + ForwardDiff.calc_hypot(x, y, z, Txyz), + ForwardDiff.calc_hypot(x, y, z, Txy), + ForwardDiff.calc_hypot(x, y, z, Txz), + ForwardDiff.calc_hypot(x, y, z, Tyz), + ForwardDiff.calc_hypot(x, y, z, Tx), + ForwardDiff.calc_hypot(x, y, z, Ty), + ForwardDiff.calc_hypot(x, y, z, Tz), +) + +@define_ternary_dual_op2( + Base.fma, + ForwardDiff.calc_fma_xyz(x, y, z), # xyz_body + ForwardDiff.calc_fma_xy(x, y, z), # xy_body + ForwardDiff.calc_fma_xz(x, y, z), # xz_body + Base.fma(y, x, z), # yz_body + Dual{Tx}(Base.fma(value(x), y, z), partials(x) * y), # x_body + Base.fma(y, x, z), # y_body + Dual{Tz}(Base.fma(x, y, value(z)), partials(z)) # z_body +) + +@define_ternary_dual_op2( + Base.muladd, + ForwardDiff.calc_muladd_xyz(x, y, z), # xyz_body + ForwardDiff.calc_muladd_xy(x, y, z), # xy_body + ForwardDiff.calc_muladd_xz(x, y, z), # xz_body + Base.muladd(y, x, z), # yz_body + Dual{Tx}(Base.muladd(value(x), y, z), partials(x) * y), # x_body + Base.muladd(y, x, z), # y_body + Dual{Tz}(Base.muladd(x, y, value(z)), partials(z)) # z_body +) + end #module \ No newline at end of file