Skip to content

Remove type assertions in evaluation code #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions src/EvaluateEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ which speed up evaluation significantly.
"""
function eval_tree_array(
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
) where {T<:Number}
n = size(cX, 2)
if turbo
@assert T in (Float32, Float64)
Expand All @@ -87,7 +87,7 @@ end

function _eval_tree_array(
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,turbo}
) where {T<:Number,turbo}
n = size(cX, 2)
# First, we see if there are only constants in the tree - meaning
# we can just return the constant result.
Expand Down Expand Up @@ -148,7 +148,7 @@ end

function deg2_eval(
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
) where {T<:Number,F,turbo}
@maybe_turbo turbo for j in indices(cumulator_l)
x = op(cumulator_l[j], cumulator_r[j])::T
cumulator_l[j] = x
Expand All @@ -158,7 +158,7 @@ end

function deg1_eval(
cumulator::AbstractVector{T}, op::F, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
) where {T<:Number,F,turbo}
@maybe_turbo turbo for j in indices(cumulator)
x = op(cumulator[j])::T
cumulator[j] = x
Expand All @@ -168,7 +168,7 @@ end

function deg0_eval(
tree::Node{T}, cX::AbstractMatrix{T}
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
) where {T<:Number}
if tree.constant
n = size(cX, 2)
return (fill(tree.val::T, n), true)
Expand All @@ -179,7 +179,7 @@ end

function deg1_l2_ll0_lr0_eval(
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
) where {T<:Number,F,F2,turbo}
n = size(cX, 2)
if tree.l.l.constant && tree.l.r.constant
val_ll = tree.l.l.val::T
Expand Down Expand Up @@ -229,7 +229,7 @@ end
# op(op2(x)) for x variable or constant
function deg1_l1_ll0_eval(
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
) where {T<:Number,F,F2,turbo}
n = size(cX, 2)
if tree.l.l.constant
val_ll = tree.l.l.val::T
Expand All @@ -254,7 +254,7 @@ end
# op(x, y) for x and y variable/constant
function deg2_l0_r0_eval(
tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
) where {T<:Number,F,turbo}
n = size(cX, 2)
if tree.l.constant && tree.r.constant
val_l = tree.l.val::T
Expand Down Expand Up @@ -297,7 +297,7 @@ end
# op(x, y) for x variable/constant, y arbitrary
function deg2_l0_eval(
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
) where {T<:Number,F,turbo}
n = size(cX, 2)
if tree.l.constant
val = tree.l.val::T
Expand All @@ -319,7 +319,7 @@ end
# op(x, y) for x arbitrary, y variable/constant
function deg2_r0_eval(
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
) where {T<:Number,F,turbo}
n = size(cX, 2)
if tree.r.constant
val = tree.r.val::T
Expand Down Expand Up @@ -347,7 +347,7 @@ over an entire array when the values are all the same.
"""
function _eval_constant_tree(
tree::Node{T}, operators::OperatorEnum
)::Tuple{T,Bool} where {T<:Number}
) where {T<:Number}
if tree.degree == 0
return deg0_eval_constant(tree)
elseif tree.degree == 1
Expand All @@ -357,13 +357,13 @@ function _eval_constant_tree(
end
end

@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Number}
@inline function deg0_eval_constant(tree::Node{T}) where {T<:Number}
return tree.val::T, true
end

function deg1_eval_constant(
tree::Node{T}, op::F, operators::OperatorEnum
)::Tuple{T,Bool} where {T<:Number,F}
) where {T<:Number,F}
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
!complete && return zero(T), false
output = op(cumulator)::T
Expand All @@ -372,7 +372,7 @@ end

function deg2_eval_constant(
tree::Node{T}, op::F, operators::OperatorEnum
)::Tuple{T,Bool} where {T<:Number,F}
) where {T<:Number,F}
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
!complete && return zero(T), false
(cumulator2, complete2) = _eval_constant_tree(tree.r, operators)
Expand All @@ -388,7 +388,7 @@ Evaluate an expression tree in a way that can be auto-differentiated.
"""
function differentiable_eval_tree_array(
tree::Node{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum
)::Tuple{AbstractVector{T},Bool} where {T<:Number,T1}
) where {T<:Number,T1}
n = size(cX, 2)
if tree.degree == 0
if tree.constant
Expand All @@ -405,7 +405,7 @@ end

function deg1_diff_eval(
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1}
) where {T<:Number,F,T1}
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
@return_on_false complete left
out = op.(left)
Expand All @@ -415,7 +415,7 @@ end

function deg2_diff_eval(
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1}
) where {T<:Number,F,T1}
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
@return_on_false complete left
(right, complete2) = differentiable_eval_tree_array(tree.r, cX, operators)
Expand Down
26 changes: 13 additions & 13 deletions src/EvaluateEquationDerivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function eval_diff_tree_array(
operators::OperatorEnum,
direction::Int;
turbo::Bool=false,
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
) where {T<:Number}
assert_autodiff_enabled(operators)
# TODO: Implement quick check for whether the variable is actually used
# in this tree. Otherwise, return zero.
Expand Down Expand Up @@ -71,7 +71,7 @@ function _eval_diff_tree_array(
operators::OperatorEnum,
direction::Int,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,turbo}
) where {T<:Number,turbo}
evaluation, derivative, complete = if tree.degree == 0
diff_deg0_eval(tree, cX, direction)
elseif tree.degree == 1
Expand All @@ -82,7 +82,7 @@ function _eval_diff_tree_array(
operators.diff_unaops[tree.op],
operators,
direction,
Val(turbo),
turbo ? Val(true) : Val(false),
)
else
diff_deg2_eval(
Expand All @@ -92,7 +92,7 @@ function _eval_diff_tree_array(
operators.diff_binops[tree.op],
operators,
direction,
Val(turbo),
turbo ? Val(true) : Val(false),
)
end
@return_on_false2 complete evaluation derivative
Expand All @@ -101,7 +101,7 @@ end

function diff_deg0_eval(
tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
) where {T<:Number}
n = size(cX, 2)
const_part = deg0_eval(tree, cX)[1]
derivative_part =
Expand All @@ -117,7 +117,7 @@ function diff_deg1_eval(
operators::OperatorEnum,
direction::Int,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
) where {T<:Number,F,dF,turbo}
n = size(cX, 2)
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
tree.l, cX, operators, direction, Val(turbo)
Expand All @@ -143,7 +143,7 @@ function diff_deg2_eval(
operators::OperatorEnum,
direction::Int,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
) where {T<:Number,F,dF,turbo}
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
tree.l, cX, operators, direction, Val(turbo)
)
Expand Down Expand Up @@ -194,7 +194,7 @@ function eval_grad_tree_array(
operators::OperatorEnum;
variable::Bool=false,
turbo::Bool=false,
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
) where {T<:Number}
assert_autodiff_enabled(operators)
n = size(cX, 2)
if variable
Expand Down Expand Up @@ -224,7 +224,7 @@ function eval_grad_tree_array(
operators::OperatorEnum,
::Val{variable},
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
) where {T<:Number,variable,turbo}
evaluation, gradient, complete = _eval_grad_tree_array(
tree, n, n_gradients, index_tree, cX, operators, Val(variable), Val(turbo)
)
Expand Down Expand Up @@ -258,7 +258,7 @@ function _eval_grad_tree_array(
operators::OperatorEnum,
::Val{variable},
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
) where {T<:Number,variable,turbo}
if tree.degree == 0
grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
elseif tree.degree == 1
Expand Down Expand Up @@ -297,7 +297,7 @@ function grad_deg0_eval(
index_tree::NodeIndex,
cX::AbstractMatrix{T},
::Val{variable},
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable}
) where {T<:Number,variable}
const_part = deg0_eval(tree, cX)[1]

if variable == tree.constant
Expand All @@ -321,7 +321,7 @@ function grad_deg1_eval(
operators::OperatorEnum,
::Val{variable},
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
) where {T<:Number,F,dF,variable,turbo}
(cumulator, dcumulator, complete) = eval_grad_tree_array(
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
)
Expand Down Expand Up @@ -350,7 +350,7 @@ function grad_deg2_eval(
operators::OperatorEnum,
::Val{variable},
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
) where {T<:Number,F,dF,variable,turbo}
(cumulator1, dcumulator1, complete) = eval_grad_tree_array(
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
)
Expand Down