Skip to content

Commit

Permalink
fix broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Jul 12, 2018
1 parent ec8550e commit 17b3313
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,14 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
track(Call(back, tracker.(args)), y)
end

Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)

Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...)
using Base.Broadcast: BroadcastStyle

struct TrackedStyle <: BroadcastStyle end

Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()

function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
bc = Broadcast.flatten(bc)
∇broadcast(bc.f, bc.args...)
end

0 comments on commit 17b3313

Please sign in to comment.