Skip to content

Commit

Permalink
implement new array assignment shape matching rule. fixes #4048, fixes
Browse files Browse the repository at this point in the history
…#4383

this rule ignores singleton dimensions, and allows the last dimension of
one side to match all trailing dimensions of the other.
  • Loading branch information
JeffBezanson committed Dec 23, 2013
1 parent def4b95 commit d62f455
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
15 changes: 1 addition & 14 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,7 @@ function setindex!(A::Array, x, I::Union(Real,AbstractArray)...)
assign_cache = Dict()
end
X = x
nel = 1
for idx in I
nel *= length(idx)
end
if length(X) != nel
throw(DimensionMismatch(""))
end
if ndims(X) > 1
for i = 1:length(I)
if size(X,i) != length(I[i])
throw(DimensionMismatch(""))
end
end
end
setindex_shape_check(X, I...)
gen_array_index_map(assign_cache, storeind -> quote
A[$storeind] = X[refind]
refind += 1
Expand Down
82 changes: 72 additions & 10 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,83 @@ index_shape(I::Real...) = ()
index_shape(i, I...) = tuple(length(i), index_shape(I...)...)

# check for valid sizes in A[I...] = X where X <: AbstractArray
# we want to allow dimensions that are equal up to permutation, but only
# for permutations that leave array elements in the same linear order.
# those are the permutations that preserve the order of the non-singleton
# dimensions.
function setindex_shape_check(X::AbstractArray, I...)
li = length(I)
ii = 1
nel = 1
for idx in I
nel *= length(idx)
end
if length(X) != nel
error("dimensions must match")
end
if ndims(X) > 1
for i = 1:length(I)
if size(X,i) != length(I[i])
error("dimensions must match")
xi = 1
ndx = ndims(X)
match = true
while ii < li
lii = length(I[ii])::Int
ii += 1
if lii != 1
nel *= lii
local lxi
while true
lxi = size(X,xi)
xi += 1
if lxi != 1 || xi > ndx
break
end
end
if xi > ndx
trailing = lii
while ii <= li
lii = length(I[ii])::Int
trailing *= lii
ii += 1
end
# X's last dimension can match all the trailing indexes
if lxi == trailing && match
return
else
throw(DimensionMismatch(""))
end
else
if lxi != lii
match = false

This comment has been minimized.

Copy link
@timholy

timholy Dec 26, 2013

Member

As far as I can tell, to get here both lxi and lii have to be bigger than 1. So it's not obvious to me why this can't throw immediately---once match gets set to false, there's no way for it to get set to true, and AFAICT all exit paths from this function result in an error if match isn't true. By that logic, match isn't even necessary.

Assuming I'm wrong in this analysis 😄, a comment might be in order.

end
end
end
end

# last index can match X's trailing dimensions
lii = length(I[ii])::Int
nel *= lii
if lii != trailingsize(X,xi)
match = false
end

if !(match && length(X)==nel)
throw(DimensionMismatch(""))
end
end

setindex_shape_check(X::AbstractArray) = (length(X)==1 || throw(DimensionMismatch("")))

setindex_shape_check(X::AbstractArray, i) =
(length(X)==length(i) || throw(DimensionMismatch("")))

setindex_shape_check{T}(X::AbstractArray{T,1}, i) =
(length(X)==length(i) || throw(DimensionMismatch("")))

setindex_shape_check{T}(X::AbstractArray{T,1}, i, j) =
(length(X)==length(i)*length(j) || throw(DimensionMismatch("")))

function setindex_shape_check{T}(X::AbstractArray{T,2}, i, j)
li, lj = length(i), length(j)
if length(X) != li*lj
throw(DimensionMismatch(""))
end
sx1 = size(X,1)
if !(li == 1 || li == sx1 || sx1 == 1)
throw(DimensionMismatch(""))
end
end

# convert to integer index
Expand Down

0 comments on commit d62f455

Please sign in to comment.