Skip to content

Commit

Permalink
Add reset!(), comment out wrong rules
Browse files Browse the repository at this point in the history
  • Loading branch information
dfdx committed Jul 3, 2021
1 parent 6af4b62 commit 1fa7d08
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ for i=1:100
end
```

Note that Yota caches gradients and may not see changes to functions
if you redefine them (e.g. in REPL). To reset the cache, invoke:

```julia
Yota.reset!()
```


## ChainRules

Expand Down
17 changes: 9 additions & 8 deletions src/drules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,16 @@ end

# some special broadcasting rules

function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::typeof(+), args...)
return NoTangent(), NoTangent(), [dy for a in args]...
end
@drule Broadcast.broadcasted(f::typeof(+), args::Vararg) ∇broadcasted
# are these rules just incorrect versions of the `∇broadcasted_special` below?
# function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::typeof(+), args...)
# return NoTangent(), NoTangent(), [dy for a in args]...
# end
# @drule Broadcast.broadcasted(f::typeof(+), args::Vararg) ∇broadcasted

function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::typeof(*), args...)
return NoTangent(), NoTangent(), [dy .* a for a in args]...
end
@drule Broadcast.broadcasted(f::typeof(*), args::Vararg) ∇broadcasted
# function ∇broadcasted(dy, ::typeof(Broadcast.broadcasted), f::typeof(*), args...)
# return NoTangent(), NoTangent(), [dy .* a for a in args]...
# end
# @drule Broadcast.broadcasted(f::typeof(*), args::Vararg) ∇broadcasted

# @diffrule getindex(u::AbstractArray, i) u ungetindex(u, dy, i)

Expand Down
4 changes: 2 additions & 2 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ end
Make a single step of backpropagation.
"""
function step_back!(tape::Tape, y::Variable, deriv_todo::Vector{Variable})
# TODO: here's the problem with the constructor_loss test:
# we reach y = __new__(...) twice and update all its fields twice
@debug "step_back!() for $(tape[y])"
df = get_deriv_function(call_signature(tape, tape[y]))
dy = tape.c.derivs[y]
Expand Down Expand Up @@ -234,6 +232,8 @@ end

const GRAD_CACHE = Dict{Any,Any}()

reset!() = empty!(GRAD_CACHE)


"""
grad(f, args...; seed=1)
Expand Down

0 comments on commit 1fa7d08

Please sign in to comment.