Skip to content

Commit

Permalink
fix broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Jun 20, 2018
1 parent 5d6465a commit 7b44adb
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,14 @@ function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
end

Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)

Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...)
using Base.Broadcast: BroadcastStyle

struct TrackedStyle <: BroadcastStyle end

Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()

function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
bc = Broadcast.flatten(bc)
tracked_broadcast(bc.f, bc.args...)
end

0 comments on commit 7b44adb

Please sign in to comment.