Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brianguenter/issue67 #75

Merged
merged 6 commits into from
Apr 17, 2024
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ TestItems = "0.1"
julia = "1.8"

[extras]
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"

[targets]
test = ["Test","TestItemRunner","TestItems","Memoize"]
test = ["Test", "TestItemRunner", "TestItems", "Memoize"]
15 changes: 10 additions & 5 deletions src/DerivativeGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,21 @@ struct DerivativeGraph{T<:Integer}
function DerivativeGraph(roots::AbstractVector, index_type::Type=Int64)
postorder_number = IdDict{Node,index_type}()

(postorder_number, nodes, var_array) = postorder(roots)
new_roots = Vector{Node}(undef, length(roots))
for (i, root) in pairs(roots)
new_roots[i] = create_NoOp(root)
end

(postorder_number, nodes, var_array) = postorder(new_roots)

expression_cache = IdDict()

edges = partial_edges(roots, postorder_number, expression_cache, length(var_array), length(roots))
edges = partial_edges(new_roots, postorder_number, expression_cache, length(var_array), length(new_roots))

sort!(var_array, by=x -> postorder_number[x]) #sort by postorder number from lowest to highest

root_index_to_postorder_number = Vector{index_type}(undef, length(roots)) #roots are handled differently than variables because variable node can only occur once in list of variables but root node can occur multiple times in list of roots
for (i, x) in pairs(roots)
root_index_to_postorder_number = Vector{index_type}(undef, length(new_roots)) #new_roots are handled differently than variables because variable node can only occur once in list of variables but root node can occur multiple times in list of new_roots
for (i, x) in pairs(new_roots)
root_index_to_postorder_number[i] = postorder_number[x]
end

Expand Down Expand Up @@ -205,7 +210,7 @@ struct DerivativeGraph{T<:Integer}
return new{index_type}(
postorder_number,
nodes,
roots,
new_roots,
var_array,
root_index_to_postorder_number,
root_postorder_to_index,
Expand Down
16 changes: 16 additions & 0 deletions src/ExpressionGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,22 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val{I}) where {I,N} = Node(1)

function_variable_derivative(a::Node, index::Val{i}) where {i} = check_cache((Differential, children(a)[i]), EXPRESSION_CACHE)

"""When constructing `DerivativeGraph` with repeated values in roots, e.g.,
```julia
@variables x
f = sin(x)
gr = DerivativeGraph([f,f,f])
```
all three of the f values reference the same element. To ensure that `partial_edges` creates an edge for each one of the roots we need a `NoOp` function. The derivative of `NoOp` is 1.0; the sole purpose of this node type is to ensure that the resulting derivative graph has a separate edge for each repeated root value. There are other ways this might be accomplished but this is the simplest, since it can be performed on the original `Node` graph before the recursive `partial_edges` traversal."""
struct NoOp
end

function create_NoOp(child)
return Node(NoOp(), child)
end

derivative(NoOp, arg::Tuple{T}, ::Val{1}) where {T} = 1.0

function derivative(a::Node, index::Val{1})
# if is_variable(a)
# if arity(a) == 0
Expand Down
Loading
Loading