From f2d56c7a5bc4dfbfd649abc090c440abd51f85e9 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 8 Feb 2024 14:22:48 +0800 Subject: [PATCH] SubArray: avoid invalid elimination of singleton indices (#53228) close #53209 (cherry picked from commit 4d0a469e6f8fe30a5e152ac7c93b7569c41e3c39) --- base/subarray.jl | 14 ++++++++------ test/subarray.jl | 13 ++++++++++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/base/subarray.jl b/base/subarray.jl index 901410e908d1e..3fdabca88e79f 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -127,6 +127,11 @@ _maybe_reshape_parent(A::AbstractArray, ::NTuple{1, Bool}) = reshape(A, Val(1)) _maybe_reshape_parent(A::AbstractArray{<:Any,1}, ::NTuple{1, Bool}) = reshape(A, Val(1)) _maybe_reshape_parent(A::AbstractArray{<:Any,N}, ::NTuple{N, Bool}) where {N} = A _maybe_reshape_parent(A::AbstractArray, ::NTuple{N, Bool}) where {N} = reshape(A, Val(N)) +# The trailing singleton indices could be eliminated after bounds checking. +rm_singleton_indices(ndims::Tuple, J1, Js...) = (J1, rm_singleton_indices(IteratorsMD._splitrest(ndims, index_ndims(J1)), Js...)...) +rm_singleton_indices(::Tuple{}, ::ScalarIndex, Js...) = rm_singleton_indices((), Js...) +rm_singleton_indices(::Tuple) = () + """ view(A, inds...) @@ -173,15 +178,12 @@ julia> view(2:5, 2:3) # returns a range as type is immutable 3:4 ``` """ -function view(A::AbstractArray{<:Any,N}, I::Vararg{Any,M}) where {N,M} +function view(A::AbstractArray, I::Vararg{Any,M}) where {M} @inline J = map(i->unalias(A,i), to_indices(A, I)) @boundscheck checkbounds(A, J...) - if length(J) > ndims(A) && J[N+1:end] isa Tuple{Vararg{Int}} - # view([1,2,3], :, 1) does not need to reshape - return unsafe_view(A, J[1:N]...) - end - unsafe_view(_maybe_reshape_parent(A, index_ndims(J...)), J...) + J′ = rm_singleton_indices(ntuple(Returns(true), Val(ndims(A))), J...) + unsafe_view(_maybe_reshape_parent(A, index_ndims(J′...)), J′...) end # Ranges implement getindex to return recomputed ranges; use that for views, too (when possible) diff --git a/test/subarray.jl b/test/subarray.jl index e22c1394cbfc2..11d18185dd753 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -762,9 +762,9 @@ end @testset "issue #41221: view(::Vector, :, 1)" begin v = randn(3) - @test view(v,:,1) == v - @test parent(view(v,:,1)) === v - @test parent(view(v,2:3,1,1)) === v + @test @inferred(view(v,:,1)) == v + @test parent(@inferred(view(v,:,1))) === v + @test parent(@inferred(view(v,2:3,1,1))) === v @test_throws BoundsError view(v,:,2) @test_throws BoundsError view(v,:,1,2) @@ -772,3 +772,10 @@ end @test view(m, 1:2, 3, 1, 1) == m[1:2, 3] @test parent(view(m, 1:2, 3, 1, 1)) === m end + +@testset "issue #53209: avoid invalid elimination of singleton indices" begin + A = randn(4,5) + @test A[CartesianIndices(()), :, 3] == @inferred(view(A, CartesianIndices(()), :, 3)) + @test parent(@inferred(view(A, :, 3, 1, CartesianIndices(()), 1))) === A + @test_throws BoundsError view(A, :, 3, 2, CartesianIndices(()), 1) +end