list of operations that grad does not work with #93

Open
jw3126 opened this issue Jul 1, 2021 · 5 comments

jw3126 commented Jul 1, 2021

Here is a list of some operations that did not work for me. I am wondering about the errors that mention ChainRules in their message. For instance, in the sum example, I guess we are tracing too deep into the sum implementation, even though a higher-level rrule for sum already exists:
https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl#L9

@dfdx Maybe the trace used for gradtape should have an is_primitive that checks if the signature is covered by an rrule? I realize that Yota.is_primitive !== Ghost.is_primitive and that such a check already exists. I think the issue is what should happen when one starts tracing with a call that is already primitive. It is not obvious what the best design is. Currently, such a call is entered anyway, which is why e.g. sum([1.0]) fails.

################################################################################
Yota.gradtape(sum, [1.0])
fails
No deriative rule found for op %42 = mapreduce(identity, add_sum, %2)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(identity), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(sum, abs2, [1.0])
fails
No deriative rule found for op %30 = mapreduce(%2, add_sum, %3)::Float64, try defining it using ChainRules.rrule(::typeof(mapreduce), ::typeof(abs2), ::typeof(Base.add_sum), ::Vector{Float64}) = ...
################################################################################
Yota.gradtape(identity, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(sin, 1.0)
fails
MethodError: Cannot `convert` an object of type Float64 to an object of type Ghost.Variable
Closest candidates are:
  convert(::Type{T}, ::T) where T at essentials.jl:205
  Ghost.Variable(::Any, ::Any) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:22
################################################################################
Yota.gradtape(*, 1.0)
fails
MethodError: no method matching call_signature(::Tape{Yota.GradCtx}, ::Ghost.Input)
Closest candidates are:
  call_signature(::Tape, ::Ghost.Call) at /home/jan/.julia/packages/Ghost/S5BSq/src/tape.jl:516
################################################################################
Yota.gradtape(*, 1.0, 2.0)
fails
No deriative rule found for op %4 = mul_float(%2, %3)::Float64, try defining it using ChainRules.rrule(::Core.IntrinsicFunction, ::Float64, ::Float64) = ...

dfdx (Owner) commented Jul 3, 2021

You are absolutely right - there's no way to represent a single primitive f(args...) as a tape, at least as a tape different from the one for args -> f(args...). I see several options here:

  1. Leave it as is, letting people trace the code of primitives even if sometimes it will confuse them.
  2. Forbid tracing the primitives. But what if it is just what somebody wanted to do?
  3. Show a warning explaining that it's probably not what a user wants, but letting them do it anyway.

I lean towards the last option, since it's unlikely that somebody will trace primitives outside the REPL, and warnings in the REPL are usually fine. I will let this idea mature for the next couple of days, though.

jw3126 (Author) commented Jul 3, 2021

I think that the tape args -> f(args...) option sounds natural. One way to think about a tape is that it is a list of all primitive calls that occur. If the entry point was already primitive, then it is just this one primitive call.
I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same. This again would be consistent with tracing a primitive call returning the tape with just that call. What drawbacks do you see with this?

Also one could add to the list:
1b. Throw an error by default, but allow that error to be disabled with a keyword, so that tracing into a primitive works as it does currently. I generally favor an error that must be explicitly disabled over a warning.
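
A minimal sketch of the two behaviours discussed above, written against a toy model rather than the actual Yota or Ghost API. All names (`ToyCall`, `TOY_PRIMITIVES`, `toy_is_primitive`, `toy_trace`, the `strict` keyword) are hypothetical. By default, a primitive entry point yields a tape containing just that one call; with `strict = true`, it throws an error instead, roughly as in option 1b.

```julia
# Hypothetical toy model, not the Yota/Ghost API.
struct ToyCall
    fn::Any
    args::Tuple
end

# Assumed primitive set; a real tracer would consult its rule table instead.
const TOY_PRIMITIVES = (sin, cos, exp)

toy_is_primitive(f) = f in TOY_PRIMITIVES

# If the entry point is already primitive, record it as a one-call "tape".
# With `strict = true`, refuse to trace a primitive and throw instead.
function toy_trace(f, args...; strict::Bool = false)
    if toy_is_primitive(f)
        strict && throw(ArgumentError("$f is a primitive; call with strict = false to record it"))
        return [ToyCall(f, args)]
    end
    # Tracing into non-primitive code is deliberately left out of this sketch.
    error("tracing non-primitive code is out of scope for this sketch")
end

tape = toy_trace(sin, 1.0)
```

Here `tape` holds a single `ToyCall` whose `fn` is `sin`. Whether the error or the one-call tape should be the default is exactly the design question raised in this thread.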

dfdx (Owner) commented Jul 4, 2021

> I also expected that if you have function f(x); g(x) end then trace(f,x) and trace(g,x) would be the same.

trace() already works like this:

julia> g(x) = 2x
g (generic function with 1 method)

julia> f(x) = g(x)
f (generic function with 1 method)

julia> trace(f, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(f)
  inp %2::Float64
  %3 = *(2, %2)::Float64

julia> trace(g, 1.0)[2]
Tape{Dict{Any, Any}}
  inp %1::typeof(g)
  inp %2::Float64
  %3 = *(2, %2)::Float64

grad() behaves similarly, with the exception of caching.

> One way to think about a tape is that it is a list of all primitive calls that occur.

The first input to a tape is usually the object being called. In the case of args -> f(args...), this object is an anonymous function, which is fine. In the case of a primitive, it's unclear what we should put there instead.

The most straightforward way is to wrap the primitive into an anonymous function, but that would break the assumption that tape[V(1)].fn == f, which may be useful for introspection and downstream transformations. It would also break on closures/callable structs.

The same applies to skipping the first argument altogether.

Putting the primitive itself as the first input also sounds weird - it will look like a recursive function which it's not.

On the other hand, tracing a primitive function doesn't seem to be a big use case, so raising an error or warning sounds like a reasonable solution to me, at least until we hit a real case where it's not enough.
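
The concern about wrapping a primitive in an anonymous function can be illustrated in plain Julia, independent of Yota. This is only a sketch of one reading of the remark above: `wrap` and `Scale` are hypothetical names, and the snippet does not touch a real tape.

```julia
# Hypothetical helper, not part of Yota or Ghost:
# wrap a callable in an anonymous function, i.e. the args -> f(args...) form.
wrap(f) = (args...) -> f(args...)

wrapped_sin = wrap(sin)

same_value    = wrapped_sin(1.0) == sin(1.0)  # the wrapper computes the same result
same_identity = wrapped_sin === sin           # but it is a different object than sin

# A callable struct, to illustrate the closures/callable structs remark.
struct Scale
    k::Float64
end
(s::Scale)(x) = s.k * x

s = Scale(2.0)
wrapped_s = wrap(s)

still_a_scale = wrapped_s isa Scale           # the wrapper hides the struct's type
```

The wrapper returns the same values, but it is neither identical to `sin` nor an instance of `Scale`, so code that inspects the first tape input and expects the original callable would see the closure instead.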

lassepe commented Feb 15, 2022

Perhaps another entry for the list: it seems that grad currently does not work with LinearAlgebra.Adjoint:

julia> A = rand(100, 100)
julia> x = rand(100)
julia> Yota.grad(x -> 0.5 * x' * A * x, x)
ERROR: LoadError: No deriative rule found for op %8 = *(0.5, %7, %2)::Float64, try defining it using 

	ChainRulesCore.rrule(::typeof(*), ::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}}, ::Vector{Float64}) = ...
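
For reference, the gradient this call would be expected to return can be written by hand: for f(x) = 0.5 * x' * A * x, the gradient with respect to x is 0.5 * (A + A') * x. The sketch below checks that formula against central finite differences using only the LinearAlgebra standard library. It does not call Yota, and the helper names `quadform`, `quadform_grad`, and `fd_grad` are hypothetical.

```julia
using LinearAlgebra

# The function from the report, with explicit parentheses.
quadform(A, x) = 0.5 * (x' * A * x)

# Analytic gradient of 0.5 * x' * A * x with respect to x.
quadform_grad(A, x) = 0.5 * (A + A') * x

# Central finite differences as an independent check.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

A = [2.0 1.0; 0.0 3.0]
x = [1.0, -2.0]

g_analytic = quadform_grad(A, x)
g_numeric  = fd_grad(v -> quadform(A, v), x)
```

For this small non-symmetric `A`, both results should agree up to rounding, which gives a reference value to compare against once a rule for the three-argument product is available.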

dfdx (Owner) commented Feb 15, 2022

Thanks! I posted the corresponding issue in JuliaDiff/ChainRules.jl#589 as other AD systems may benefit from such a rule too.
