Broadcasting and constructors #119

Open
torfjelde opened this issue Jan 13, 2022 · 8 comments


@torfjelde

torfjelde commented Jan 13, 2022

There seems to be something weird going on when broadcasting constructors over tracked Reals:

julia> using Tracker, Distributions

julia> m = first(param(zeros(1)))
0.0 (tracked)

julia> s = first(param(ones(1)))
1.0 (tracked)

julia> typeof(Normal.(m, s))
Tracker.Tracked{Normal{Float64}}

Strangely, though, the following seems to work just fine:

julia> struct TwoFields{T1,T2}
           x::T1
           y::T2
       end

julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)

Possibly related to: #65

@devmotion
Contributor

Maybe https://github.com/JuliaStats/Distributions.jl/blob/master/src/common.jl#L149 messes with Tracker's broadcasting? What happens if you add the same definition for your custom struct?

@torfjelde
Author

Looks fine:

julia> using Tracker

julia> struct TwoFields{T1,T2}
           x::T1
           y::T2
       end

julia> Broadcast.broadcastable(d::TwoFields) = Ref(d)

julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)

@devmotion
Contributor

OK, I figured out why this is happening: It's caused by the heuristic in

(eltype(y) <: Real && eltype(y) !== Bool) || return y
and the definition eltype(::Normal{T}) = T (strictly, Distributions defines eltype on the type, as recommended in the Julia docs, and instances fall back to that method).

If one defines eltype for TwoFields so that it returns a subtype of Real, the same behaviour can be observed, e.g.

julia> Base.eltype(::Type{TwoFields{T1,T2}}) where {T1,T2} = Base.promote_type(T1, T2)

julia> typeof(TwoFields.(first(param(zeros(1))), first(param(zeros(1)))))
Tracker.Tracked{TwoFields{Float64, Float64}}
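To see the heuristic in isolation, here is a dependency-free sketch. `keeps_tracking` is a hypothetical name that only mirrors the quoted condition (it is not a Tracker function), and `MyNormal` is a hypothetical stand-in for `Distributions.Normal` that reproduces just its `eltype` definition:

```julia
# Hypothetical predicate mirroring the quoted check: the broadcast result `y`
# keeps its tracking only if its eltype is a Real other than Bool.
keeps_tracking(y) = eltype(y) <: Real && eltype(y) !== Bool

# Hypothetical stand-in for Distributions.Normal: eltype is defined on the
# type, and instances fall back to it via Base's eltype(x) = eltype(typeof(x)).
struct MyNormal{T<:Real}
    μ::T
    σ::T
end
Base.eltype(::Type{MyNormal{T}}) where {T} = T

# Plain struct with no eltype method: eltype falls back to Any.
struct TwoFields{T1,T2}
    x::T1
    y::T2
end

keeps_tracking(MyNormal(0.0, 1.0))   # true: eltype is Float64
keeps_tracking(TwoFields(0.0, 0.0))  # false: eltype is Any
keeps_tracking([1.0, 2.0])           # true: eltype is Float64
keeps_tracking([true, false])        # false: eltype is Bool
```

The struct itself is never inspected as a container; only what `eltype` reports matters, which is why defining `eltype` on `TwoFields` is enough to flip the outcome.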

@torfjelde
Author

Aaaah..

So is the fix to also check y isa AbstractArray or something?

@devmotion
Contributor

I'm not completely sure about the motivation for this check, but to me it seems the heuristic is supposed to drop tracking information in cases where the output is known to be non-differentiable. I think the heuristic should rather err on the side of caution and avoid silently dropping tracking information. For example, it could be restricted to y isa Union{Bool,AbstractArray{<:Bool}}. I'm worried, though, that changes to this heuristic could break many downstream packages.
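A sketch contrasting the current condition with the suggested restriction, assuming the condition behaves exactly as quoted earlier in the thread. Both predicates are hypothetical names, not Tracker functions, and `MyBernoulli` is a hypothetical stand-in that mimics `eltype(::Bernoulli) = Bool` from Distributions:

```julia
# Current condition (as quoted): tracking is dropped unless the output's
# eltype is a Real other than Bool.
drops_tracking_current(y) = !(eltype(y) <: Real && eltype(y) !== Bool)

# Suggested restriction: drop tracking only for Bool outputs.
drops_tracking_proposed(y) = y isa Union{Bool,AbstractArray{<:Bool}}

# Hypothetical stand-ins: a plain struct (eltype falls back to Any) and a
# Bernoulli-like type that reports eltype Bool.
struct TwoFields{T1,T2}
    x::T1
    y::T2
end
struct MyBernoulli{T<:Real}
    p::T
end
Base.eltype(::Type{<:MyBernoulli}) = Bool

# Each line yields (current, proposed):
(drops_tracking_current(true), drops_tracking_proposed(true))                        # (true, true)
(drops_tracking_current([true, false]), drops_tracking_proposed([true, false]))      # (true, true)
(drops_tracking_current([1.0, 2.0]), drops_tracking_proposed([1.0, 2.0]))            # (false, false)
(drops_tracking_current(TwoFields(0.0, 0.0)), drops_tracking_proposed(TwoFields(0.0, 0.0)))  # (true, false)
(drops_tracking_current(MyBernoulli(0.3)), drops_tracking_proposed(MyBernoulli(0.3)))        # (true, false)
```

Under the restricted rule, only values that actually are Bools (or arrays of Bools) would lose tracking at this step; structs whose `eltype` happens to be Bool or `Any` would no longer be dropped silently.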

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jan 14, 2022

Anything with eltype Bool gets rejected (including Bools themselves), so as long as the conditional remains the same, we should stay compatible with the rest of the packages, unless someone is relying on the behaviour in the MWE. To be safe, I would test Turing/SciML against such a branch regardless.

@devmotion
Contributor

> Anything with eltype Bool gets rejected (including bools themselves), so as long as the conditional remains the same,

The fix for Distributions (and e.g. samplers in general: https://docs.julialang.org/en/v1/stdlib/Random/#A-simple-sampler-without-pre-computed-data) requires changing, i.e., probably restricting, the conditional: eltype is not only used for arrays or standard containers but also in other settings, which causes the problems in this issue. For instance, eltype(::Bernoulli) = Bool (https://github.com/JuliaStats/Distributions.jl/blob/5cb0bfc0383180341c19a478ba859190b5b728a0/src/univariate/discrete/bernoulli.jl#L43), but it is completely reasonable to differentiate loglikelihood(Bernoulli(p), x) = sum(x) * log(p) + (length(x) - sum(x)) * log1p(-p) with respect to the parameter p, where p is a TrackedReal and x is an untracked Array{Bool}. So a fix would have to ensure that not everything with eltype Bool drops the tracking information.
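To make the Bernoulli point concrete, here is a plain-Julia sketch (no Tracker or Distributions) of the log-likelihood written above. `bernoulli_loglik` and `dloglik_dp` are hypothetical names; the analytic derivative is checked against a central finite difference to show that the gradient with respect to `p` is nonzero even though `x` is Bool data:

```julia
# Log-likelihood of Bool data x under Bernoulli(p), as written in the comment.
bernoulli_loglik(p, x::AbstractArray{Bool}) =
    sum(x) * log(p) + (length(x) - sum(x)) * log1p(-p)

# Analytic derivative with respect to p: sum(x)/p - (n - sum(x))/(1 - p).
dloglik_dp(p, x::AbstractArray{Bool}) =
    sum(x) / p - (length(x) - sum(x)) / (1 - p)

x = [true, false, true, true]  # untracked Bool data: sum(x) == 3, length(x) == 4
p = 0.4

# Central finite difference as a sanity check on the analytic gradient.
h = 1e-6
fd = (bernoulli_loglik(p + h, x) - bernoulli_loglik(p - h, x)) / (2h)

println(dloglik_dp(p, x))  # 3/0.4 - 1/0.6 ≈ 5.8333
println(fd)                # agrees with the analytic value to several digits
```

A rule that drops tracking whenever `eltype` is Bool would lose exactly this nonzero gradient once `Bernoulli(p)` is constructed by broadcasting over a tracked `p`.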

@ToucheSir
Member

I don't quite understand what ∇broadcast is doing and why it differs from https://github.com/FluxML/Zygote.jl/blob/v0.6.51/src/lib/broadcast.jl#L188-L207, but perhaps Tracker could borrow some of Zygote's logic here.

4 participants