Skip to content

Commit

Permalink
cat fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Aug 3, 2018
1 parent 73c7cfd commit 13affb1
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ end
end
end

Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)

@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
Expand Down

0 comments on commit 13affb1

Please sign in to comment.