Skip to content

Commit

Permalink
Alternative approach to #4270 without a staged function
Browse files Browse the repository at this point in the history
  • Loading branch information
simonster committed Dec 8, 2014
1 parent c701ac4 commit e40318d
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,34 +142,28 @@ reshape(a::AbstractArray, dims::Int...) = reshape(a, dims)
vec(a::AbstractArray) = reshape(a,length(a))
vec(a::AbstractVector) = a

function nextunsqueezeddim(dims, i)
while true
squeeze = false
for dim in dims
if dim == i
squeeze && error("squeezed dims must be unique")
squeeze = true
end
argtail(x, rest...) = rest
tail(x::Tuple) = argtail(x...)

_sub(::(), ::()) = ()
_sub(t::Tuple, ::()) = t
_sub(t::Tuple, s::Tuple) = _sub(tail(t), tail(s))

function squeeze(A::AbstractArray, dims::Dims)
for i in 1:length(dims)
1 <= dims[i] <= ndims(A) || error("squeezed dims must be in range [1, ndims(A)]")
size(A, dims[i]) == 1 || error("squeezed dims must all be size 1")
for j = 1:i-1
dims[j] == dims[i] && error("squeezed dims must be unique")
end
!squeeze && return i
i += 1
end
end

stagedfunction squeeze(A::AbstractArray, dims::Dims)
n = ndims(A)
quote
if !($(Expr(:&&, [:(1 <= dims[$i] <= $n) for i = 1:length(dims)]...)))
error("squeezed dims must be in range [1, ndims(A)]")
elseif !($(Expr(:&&, [:(size(A, dims[$i]) == 1) for i = 1:length(dims)]...)))
error("squeezed dims must all be size 1")
d = ()
for i = 1:ndims(A)
if !in(i, dims)
d = tuple(d..., size(A, i))
end

dim_1 = nextunsqueezeddim(dims, 1)
$([:($(symbol("dim_$i")) = nextunsqueezeddim(dims, $(symbol("dim_$(i-1)"))+1)) for i = 2:n-length(dims)]...)

reshape(A, tuple($([:(size(A, $(symbol("dim_$i")))) for i = 1:n-length(dims)]...)))
end
reshape(A, d::typeof(_sub(size(A), dims)))
end

squeeze(A::AbstractArray, dim::Integer) = squeeze(A, (int(dim),))
Expand Down

0 comments on commit e40318d

Please sign in to comment.