diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 04be331fdb..7111d780fd 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -159,10 +159,10 @@ end end end -Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) -Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) -Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) -Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) +Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims) +Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) +Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) +Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims) @grad function cat(Xs...; dims) cat(data.(Xs)..., dims = dims), function (Δ)