Skip to content

Commit

Permalink
fix extrema(A,dim) when length(dim)>1
Browse files Browse the repository at this point in the history
Ref #22118
(cherry picked from commit 4a470cb)
  • Loading branch information
bjarthur authored and vtjnash committed Sep 14, 2017
1 parent 9cf8ac0 commit 01f3a50
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
37 changes: 17 additions & 20 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1440,26 +1440,23 @@ function extrema(A::AbstractArray, dims)
return extrema!(B, A)
end

@generated function extrema!(B, A::AbstractArray{T,N}) where {T,N}
return quote
sA = size(A)
sB = size(B)
@nloops $N i B begin
AI = @nref $N A i
(@nref $N B i) = (AI, AI)
end
Bmax = sB
Istart = Int[sB[i] == 1 != sA[i] ? 2 : 1 for i = 1:ndims(A)]
@inbounds @nloops $N i d->(Istart[d]:size(A,d)) begin
AI = @nref $N A i
@nexprs $N d->(j_d = min(Bmax[d], i_{d}))
BJ = @nref $N B j
if AI < BJ[1]
(@nref $N B j) = (AI, BJ[2])
elseif AI > BJ[2]
(@nref $N B j) = (BJ[1], AI)
end
@noinline function extrema!(B, A)
sA = size(A)
sB = size(B)
for I in CartesianRange(sB)
AI = A[I]
B[I] = (AI, AI)
end
Bmax = CartesianIndex(sB)
@inbounds @simd for I in CartesianRange(sA)
J = min(Bmax,I)
BJ = B[J]
AI = A[I]
if AI < BJ[1]
B[J] = (AI, BJ[2])
elseif AI > BJ[2]
B[J] = (BJ[1], AI)
end
return B
end
return B
end
14 changes: 11 additions & 3 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,17 @@ prod2(itr) = invoke(prod, Tuple{Any}, itr)
@test maximum(collect(Int16(1):Int16(100))) === Int16(100)
@test maximum(Int32[1,2]) === Int32(2)

@test extrema(reshape(1:24,2,3,4),1) == reshape([(1,2),(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24)],1,3,4)
@test extrema(reshape(1:24,2,3,4),2) == reshape([(1,5),(2,6),(7,11),(8,12),(13,17),(14,18),(19,23),(20,24)],2,1,4)
@test extrema(reshape(1:24,2,3,4),3) == reshape([(1,19),(2,20),(3,21),(4,22),(5,23),(6,24)],2,3,1)
A = circshift(reshape(1:24,2,3,4), (0,1,1))
@test extrema(A,1) == reshape([(23,24),(19,20),(21,22),(5,6),(1,2),(3,4),(11,12),(7,8),(9,10),(17,18),(13,14),(15,16)],1,3,4)
@test extrema(A,2) == reshape([(19,23),(20,24),(1,5),(2,6),(7,11),(8,12),(13,17),(14,18)],2,1,4)
@test extrema(A,3) == reshape([(5,23),(6,24),(1,19),(2,20),(3,21),(4,22)],2,3,1)
@test extrema(A,(1,2)) == reshape([(19,24),(1,6),(7,12),(13,18)],1,1,4)
@test extrema(A,(1,3)) == reshape([(5,24),(1,20),(3,22)],1,3,1)
@test extrema(A,(2,3)) == reshape([(1,23),(2,24)],2,1,1)
@test extrema(A,(1,2,3)) == reshape([(1,24)],1,1,1)
@test size(extrema(A,1)) == size(maximum(A,1))
@test size(extrema(A,(1,2))) == size(maximum(A,(1,2)))
@test size(extrema(A,(1,2,3))) == size(maximum(A,(1,2,3)))

# any & all

Expand Down

0 comments on commit 01f3a50

Please sign in to comment.