list of operations that grad does not work with #93
You are absolutely right: there's no way to represent a single primitive
I lean towards the last option, since it's unlikely that anyone will trace primitives outside the REPL, and warnings in the REPL are usually fine. I'll let this idea mature for the next couple of days, though.
I think that the tape

Also, one could add to the list:
```julia
julia> g(x) = 2x
g (generic function with 1 method)

julia> f(x) = g(x)
f (generic function with 1 method)

julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(f)
inp %2::Float64
%3 = *(2, %2)::Float64

julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
inp %1::typeof(g)
inp %2::Float64
%3 = *(2, %2)::Float64
```
The first input to a tape is usually the object being called. In the case of … The most straightforward way is to wrap the primitive in an anonymous function, but that would break the assumption that … The same applies to skipping the first argument altogether. Putting the primitive itself as the first input also seems odd: the tape would look like a recursive function, which it isn't. On the other hand, tracing a primitive function doesn't seem to be a major use case, so raising an error or a warning sounds like a reasonable solution to me, at least until we hit a real case where that isn't enough.
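A minimal sketch of the "raise an error or warning" option discussed above. The names `trace_or_complain`, `isprim` and `tracer` are hypothetical placeholders, not Yota or Ghost APIs: `isprim` stands in for whatever primitive check the tracer uses, and `tracer` for the actual tracing function. By default the guard warns and then traces anyway (the call is entered, as happens today); `strict = true` turns the warning into an error.

```julia
# Hypothetical guard around a tracer. `isprim(f, args...)` stands in for the
# real primitive check and `tracer(f, args...)` for the real tracing function;
# neither is a Yota or Ghost API.
function trace_or_complain(isprim, tracer, f, args...; strict::Bool = false)
    if isprim(f, args...)
        msg = "$(f) is already a primitive for argument types $(map(typeof, args)); " *
              "tracing it may not produce a useful tape"
        strict && throw(ArgumentError(msg))  # "raise an error" option
        @warn msg                            # "warning" option (fine in the REPL)
    end
    # Fall through to the tracer, i.e. enter the call anyway.
    return tracer(f, args...)
end
```

The warning-by-default choice mirrors the argument that primitives are mostly traced interactively, where a warning is enough.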
Perhaps another entry for the list: it seems that grad currently does not work with the following:

```julia
julia> using Yota

julia> A = rand(100, 100)

julia> x = rand(100)

julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using
ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
```
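For reference, the math the missing rule has to implement is small. Writing the failing op as z = a * xt * y, with a scalar `a`, a row vector `xt = x'` and a column vector `y`, we get z = a * dot(x, y), so the cotangents are dz * dot(x, y) for `a`, dz * a * y' for `xt` and dz * a * x for `y`. The sketch below uses only the LinearAlgebra standard library; `mul3`, `mul3_pullback` and `quadform_grad` are hypothetical names, and an actual `ChainRulesCore.rrule` method for this signature would additionally return `NoTangent()` for the function itself.

```julia
using LinearAlgebra

# Forward pass of the 3-argument product z = a * xt * y (hypothetical helper):
# `a` is a scalar, `xt` a row vector (Adjoint of a vector), `y` a column vector.
mul3(a::Real, xt::Adjoint{<:Real,<:AbstractVector}, y::AbstractVector{<:Real}) =
    a * dot(parent(xt), y)

# Hypothetical pullback: given the output cotangent dz, return the cotangents
# of (a, xt, y). Here z = a * sum(x[i] * y[i]) with x = parent(xt).
function mul3_pullback(a::Real, xt::Adjoint{<:Real,<:AbstractVector},
                       y::AbstractVector{<:Real}, dz::Real)
    x = parent(xt)
    abar = dz * dot(x, y)      # ∂z/∂a  = dot(x, y)
    xtbar = (dz * a) * y'      # ∂z/∂xt = a * y' (a row vector, like xt)
    ybar = (dz * a) * x        # ∂z/∂y  = a * x
    return abar, xtbar, ybar
end

# Usage: gradient of 0.5 * x' * A * x, where the tape records xt = x' * A
# followed by the 3-argument call *(0.5, xt, x).
function quadform_grad(A::AbstractMatrix, x::AbstractVector)
    xt = x' * A                                  # equals (A' * x)'
    _, xtbar, ybar = mul3_pullback(0.5, xt, x, 1.0)
    # xt depends on x through A' * x, so its cotangent flows back through A.
    return A * vec(xtbar) + ybar
end
```

Chaining the pullback back through xt = x' * A recovers the familiar gradient 0.5 * (A + A') * x of the quadratic form.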
Thanks! I opened a corresponding issue, JuliaDiff/ChainRules.jl#589, since other AD systems may benefit from such a rule too.
Here is a list of some operations that did not work for me. I'm curious about the errors whose messages mention ChainRules. For instance, in the sum example, I guess we are tracing too deep into the sum implementation, since a higher-level sum rule already exists:
https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9
@dfdx
Maybe the `trace` used for the `grad` tape should have an `is_primitive` that checks whether the signature is covered by an `rrule`? I realize that `Yota.is_primitive !== Ghost.is_primitive`, and there is already such a rule. I think the real issue is what should happen when one starts tracing with a call that is already primitive. It's not obvious what the best design is. Currently, such a call is entered anyway; this is why, for example, `sum([1.0])` fails.
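A toy illustration of the signature-based check described above: whether a call is primitive can be decided from argument types alone, by asking which `rrule` method dispatch would select and comparing it with the generic fallback. Everything here is defined locally and hypothetical: `rrule` is a stand-in for `ChainRulesCore.rrule` (whose generic fallback likewise returns `nothing`), and this is not how `Yota.is_primitive` is actually implemented. With a rule for `sum` over vectors in place, a top-level call such as `sum([1.0])` is reported as primitive, which is exactly the case where a tracer would need to stop (or warn) instead of entering the call.

```julia
# Toy stand-in for ChainRulesCore.rrule, defined locally for illustration:
# the generic fallback returns `nothing`, meaning "no rule for this call".
rrule(f, args...) = nothing

# A hand-written rule for summing a real vector: d(sum(x))/dx[i] == 1.
function rrule(::typeof(sum), x::AbstractVector{<:Real})
    sum_pullback(dy) = (nothing, fill(dy, length(x)))  # (cotangent of sum, of x)
    return sum(x), sum_pullback
end

# The fallback method, found by asking which method covers a call with a
# single argument of unknown type.
const RRULE_FALLBACK = which(rrule, Tuple{Any})

# Hypothetical check: a call is primitive if dispatch on its argument types
# selects some rrule method other than the fallback. Only types are inspected;
# the rule itself is not executed.
is_primitive(f, args...) =
    which(rrule, Tuple{typeof(f), map(typeof, args)...}) !== RRULE_FALLBACK
```

Here `is_primitive(sum, [1.0])` is true while `is_primitive(sin, 1.0)` is false, since only `sum` on a vector has a rule in this toy registry.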