diff --git a/src/aggregation.jl b/src/aggregation.jl index f88496e..af9e40a 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -83,10 +83,8 @@ function transform_with(flag::LogJacFlag, transformation::ArrayTransformation, x end function transform_with(flag::LogJacFlag, t::ArrayTransformation{Identity}, x, index) - # TODO use version below when https://github.com/FluxML/Flux.jl/issues/416 is fixed - # y = reshape(copy(x), t.dims) index′ = index+dimension(t) - y = reshape(map(identity, x[index:(index′-1)]), t.dims) + y = reshape(x[index:(index′-1)], t.dims) y, logjac_zero(flag, robust_eltype(x)), index′ end