grad() fails for simple case since v0.5.0 #90

Closed
EvoArt opened this issue Jun 28, 2021 · 8 comments

EvoArt commented Jun 28, 2021

Maybe I'm using the package wrong; I'm clueless about AD in general. I installed Yota today and can't seem to get simple gradients, whereas on v0.2.0 (accidentally installed in another environment due to conflicts) all seems fine.

julia> using Yota

julia> f(x) = x^2
f (generic function with 1 method)

julia> grad(f,3)
ERROR: Neither ChainRules pullback, nor native Yota derivative found for op %10 = literal_pow(^, %2, %9)::Int64
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] step_back!(tape::Ghost.Tape{Yota.GradCtx}, y::Ghost.Variable, deriv_todo::Vector{Ghost.Variable})
   @ Yota C:\Users\arn203\.julia\packages\Yota\tL6l7\src\grad.jl:129
 [3] back!(tape::Ghost.Tape{Yota.GradCtx})
   @ Yota C:\Users\arn203\.julia\packages\Yota\tL6l7\src\grad.jl:174
 [4] gradtape!(tape::Ghost.Tape{Yota.GradCtx})
   @ Yota C:\Users\arn203\.julia\packages\Yota\tL6l7\src\grad.jl:195
 [5] gradtape(f::typeof(f), args::Int64)
   @ Yota C:\Users\arn203\.julia\packages\Yota\tL6l7\src\grad.jl:208
 [6] grad(f::typeof(f), args::Int64)
   @ Yota C:\Users\arn203\.julia\packages\Yota\tL6l7\src\grad.jl:239
 [7] top-level scope
   @ REPL[6]:1

(PGM) pkg> status
      Status `C:\Users\arn203\OneDrive - University of Exeter\Documents\PGM\Project.toml`
  [cbdf2221] AlgebraOfGraphics v0.4.4
  [6e4b80f9] BenchmarkTools v1.0.0
  [49dc2e85] Calculus v0.5.1
  [31c24e10] Distributions v0.24.18
  [ea4f424c] Gen v0.4.3
  [bdcacae8] LoopVectorization v0.12.47
  [c7f686f2] MCMCChains v4.12.0
  [2913bbd2] StatsBase v0.33.8
  [4c63d2b9] StatsFuns v0.9.8
  [f3b207a7] StatsPlots v0.14.21
  [bc48ee85] Tullio v0.2.14
  [cd998857] Yota v0.5.0  
  [37e2e46d] LinearAlgebra
EvoArt changed the title from "fails for simple case since v0.5.0" to "grad() fails for simple case since v0.5.0" on Jun 28, 2021
Owner

dfdx commented Jun 28, 2021

Thanks for reporting it! It turns out ChainRules doesn't have a rule for Base.literal_pow(), which x ^ n is lowered to when n is an integer literal. A quick fix is to add this method manually:

function ChainRules.rrule(::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}) where p
    function literal_pow_pullback(dy)
        return ZeroTangent(), ZeroTangent(), (p * x ^ (p - 1) * dy), NoTangent()
    end
    return Base.literal_pow(^, x, Val(p)), literal_pow_pullback
end

Yota.update_chainrules_primitives!()

You only need update_chainrules_primitives!() if you define this rrule after Yota is loaded, so that Yota updates its cache. I think we will get rid of this inconvenience soon, but not today.


It's interesting that other packages using ChainRules never reported this issue. My guess is that they either have their own derivatives for it (as in Zygote), or overload ^ directly without lowering it to Base.literal_pow().

@oxinabox do you think we just need to add the rrule() to ChainRules, or do you know a more elegant way?
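
The lowering and the derivative used in the rule above can be checked with Base alone, without Yota or ChainRules. This is only a minimal sketch: `literal_pow_with_pullback` is a hypothetical helper name introduced here for illustration, not part of either package, and it assumes a real `x` and an integer literal exponent.

```julia
# x^2 with a literal exponent is lowered to a call of this form:
@assert Base.literal_pow(^, 3, Val(2)) == 3^2

# Hypothetical helper (not a Yota or ChainRules API): returns the value and
# a pullback for x^p, treating the literal exponent p as a constant, so no
# derivative with respect to p is computed.
function literal_pow_with_pullback(x::Real, ::Val{p}) where {p}
    y = Base.literal_pow(^, x, Val(p))
    pullback(dy) = p * x^(p - 1) * dy
    return y, pullback
end

y, back = literal_pow_with_pullback(3.0, Val(2))
dx = back(1.0)  # derivative of x^2 at x = 3.0, seeded with dy = 1.0
```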

@oxinabox
Contributor

This seems the correct way.

Maybe we can use a @scalar_rule for it.

It seems right that it should have different code than x^p, since we don't need the derivative wrt p as it is a literal.

I am not sure why Nabla didn't hit this as a failure.
I will have to test that; maybe it is a missing test.
Maybe there is still an old Nabla-only rule for it that I forgot to delete.

Author

EvoArt commented Jun 29, 2021

Thanks for the quick response! The following works for me:

function ChainRules.rrule(::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}) where p
    function literal_pow_pullback(dy)
        return Yota.ZeroTangent(), Yota.ZeroTangent(), (p * x ^ (p - 1) * dy), Yota.NoTangent()
    end
    return Base.literal_pow(^, x, Val(p)), literal_pow_pullback
end

Yota.update_chainrules_primitives!()

I had to add the Yota namespace. Would there be any benefit to using ChainRulesCore.ZeroTangent() instead?

@oxinabox
Contributor

Would there be any benefit to using ChainRulesCore.ZeroTangent() instead?

Yes, that only works because Yota brings ChainRulesCore.ZeroTangent into its own namespace.
But doing so is not part of Yota's public API.
Yota might stop doing that and instead qualify all uses of the name,
or even do import ChainRulesCore.ZeroTangent as ZZ,
which would break code that relied on Yota.ZeroTangent returning a ChainRulesCore.ZeroTangent.

Relatedly, it should be ChainRulesCore.rrule rather than ChainRules.rrule.
That only works because, when checking whether overloading is allowed, Julia only checks that the name is qualified with some module that contains it, not necessarily the one that defined it.
That code would similarly break if ChainRules.jl changed to doing import ChainRulesCore.rrule as RR
(that seems incredibly unlikely).
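
Putting the advice above together, the manual rule might look roughly like the sketch below, with every name taken from ChainRulesCore directly. Assumptions: only ChainRulesCore is loaded (a ChainRules release that already ships a rule for `Base.literal_pow` could overlap with this definition), `x` is real, and the exponent is a positive integer literal. `NoTangent()` is used for the two function arguments and for the `Val`, which matches the output of the @scalar_rule version later in this thread.

```julia
using ChainRulesCore

# Sketch of the manual rule, qualified with the module that owns rrule.
function ChainRulesCore.rrule(
    ::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p}
) where {p}
    y = Base.literal_pow(^, x, Val(p))
    function literal_pow_pullback(dy)
        # No tangents for literal_pow itself, for ^, or for the Val exponent.
        return NoTangent(), NoTangent(), p * x^(p - 1) * dy, NoTangent()
    end
    return y, literal_pow_pullback
end

y, back = ChainRulesCore.rrule(Base.literal_pow, ^, 3.0, Val(2))
tangents = back(1.0)
```

When the rule is defined after Yota has been loaded, Yota.update_chainrules_primitives!() is still needed, as described earlier in the thread.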

Owner

dfdx commented Jun 29, 2021

I'm trying to convert the rrule into a @scalar_rule, but it doesn't work very well:

julia> @scalar_rule(Base.literal_pow(f::typeof(^), x, p::Val{y} where y),
            (NoTangent(), (ifelse(iszero(x), zero(Ω), y * Ω / x), NoTangent())),
        )
ERROR: LoadError: BoundsError: attempt to access 2-element Vector{Any} at index [3]
Stacktrace:
  [1] getindex
    @ ./array.jl:801 [inlined]
  [2] #59
    @ ./none:0 [inlined]
  [3] iterate
    @ ./generator.jl:47 [inlined]
  [4] collect(itr::Base.Generator{Tuple{Expr}, ChainRulesCore.var"#59#61"{Int64}})
    @ Base ./array.jl:678
  [5] (::ChainRulesCore.var"#58#60"{Tuple{Expr}, Vector{Expr}})(input_i::Int64)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/e5hAX/src/rule_definition_tools.jl:190
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] collect_to!(dest::Vector{Expr}, itr::Base.Generator{UnitRange{Int64}, ChainRulesCore.var"#58#60"{Tuple{Expr}, Vector{Expr}}}, offs::Int64, st::Int64)
    @ Base ./array.jl:724
  [8] collect_to_with_first!(dest::Vector{Expr}, v1::Expr, itr::Base.Generator{UnitRange{Int64}, ChainRulesCore.var"#58#60"{Tuple{Expr}, Vector{Expr}}}, st::Int64)
    @ Base ./array.jl:702
  [9] _collect(c::UnitRange{Int64}, itr::Base.Generator{UnitRange{Int64}, ChainRulesCore.var"#58#60"{Tuple{Expr}, Vector{Expr}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:696
 [10] collect_similar(cont::UnitRange{Int64}, itr::Base.Generator{UnitRange{Int64}, ChainRulesCore.var"#58#60"{Tuple{Expr}, Vector{Expr}}})
    @ Base ./array.jl:606
 [11] map
    @ ./abstractarray.jl:2294 [inlined]
 [12] scalar_rrule_expr(__source__::LineNumberNode, f::Expr, call::Expr, setup_stmts::Tuple{Nothing}, inputs::Vector{Expr}, partials::Tuple{Expr})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/e5hAX/src/rule_definition_tools.jl:189
 [13] var"@scalar_rule"(__source__::LineNumberNode, __module__::Module, call::Any, maybe_setup::Any, partials::Vararg{Any, N} where N)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/e5hAX/src/rule_definition_tools.jl:89
...

@oxinabox Is it something @scalar_rule is supposed to handle? Am I using it wrong?

Contributor

oxinabox commented Jun 30, 2021

@scalar_rule is not designed to handle functions that have arguments that are not simply scalars.
^ and Val are not scalars.
It was basically made to make porting things from DiffRules.jl easy.
Arguments that are not annotated with types get restricted to Number,
and there is no support for where, since scalars do not have type parameters.

The following works, but I wouldn't say that it is necessarily something that the public API promises will keep working.
(It might, it might not; I would need to think more.)

julia> @scalar_rule(
           Base.literal_pow(op::Any, x::Real, p::Val),
           (@setup y=only(typeof(p).parameters)),
           (NoTangent(), ifelse(iszero(x), zero(Ω), y * Ω / x), NoTangent())
       )

julia> rrule(Base.literal_pow, ^, 1, Val(2))
(1, var"#literal_pow_pullback#8"{Int64, Int64, Int64}(1, 2, 1))

julia> rrule(Base.literal_pow, ^, 1, Val(2))[2](1)
(NoTangent(), NoTangent(), 2.0, NoTangent())

Owner

dfdx commented Jun 30, 2021

Maybe a bit more stable version:

val_param(::Val{P}) where P = P
@scalar_rule(
       Base.literal_pow(op::Any, x::Real, p::Val),
       (@setup y=val_param(p)),
       (NoTangent(), ifelse(iszero(x), zero(Ω), y * Ω / x), NoTangent())
)
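
As a rough check of the two @scalar_rule variants above, the Base-only sketch below compares the two ways of reading the exponent out of the Val, and the partial `y * Ω / x` against the `p * x^(p - 1)` form from the manual rrule. The names `manual_partial` and `scalar_rule_partial` are hypothetical, introduced only for this comparison.

```julia
# Two ways to extract the exponent from a Val, as used in the comments above.
val_param(::Val{P}) where {P} = P
p = Val(3)
@assert val_param(p) == 3
@assert only(typeof(p).parameters) == 3

# Partial from the manual rrule: n * x^(n - 1)
manual_partial(x, n) = n * x^(n - 1)

# Partial from the @scalar_rule version: n * Ω / x with Ω = x^n,
# guarded so that x == 0 returns zero instead of NaN.
function scalar_rule_partial(x, n)
    Ω = x^n
    return ifelse(iszero(x), zero(Ω), n * Ω / x)
end

a = manual_partial(2.0, 3)
b = scalar_rule_partial(2.0, 3)
```

The two forms agree for nonzero x. At x == 0 the guarded form always returns zero, which differs from the manual form when the exponent is 1 (where the derivative is 1).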

Owner

dfdx commented Jul 3, 2021

Fixed in JuliaDiff/ChainRules.jl#464

dfdx closed this as completed on Jul 3, 2021