From 0060c86668b72de8a1dbec010cea91246f1cb4cc Mon Sep 17 00:00:00 2001 From: jishnub Date: Thu, 3 Jun 2021 16:46:23 +0400 Subject: [PATCH] make reshape accept Integers and ranges --- base/reshapedarray.jl | 9 ++++----- test/abstractarray.jl | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index 671dd2d86a840..35de26d2463a8 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -107,15 +107,14 @@ julia> reshape(1:6, 2, 3) """ reshape -reshape(parent::AbstractArray, dims::IntOrInd...) = reshape(parent, dims) +reshape(parent::AbstractArray, dims::Union{Integer, Colon, AbstractUnitRange{<:Integer}}...) = reshape(parent, dims) +reshape(parent::AbstractArray, shp::Tuple{Integer, Vararg{Integer}}) = reshape(parent, to_shape(shp)) reshape(parent::AbstractArray, shp::Tuple{Union{Integer,OneTo}, Vararg{Union{Integer,OneTo}}}) = reshape(parent, to_shape(shp)) reshape(parent::AbstractArray, dims::Dims) = _reshape(parent, dims) # Allow missing dimensions with Colon(): reshape(parent::AbstractVector, ::Colon) = parent -reshape(parent::AbstractArray, dims::Int...) = reshape(parent, dims) -reshape(parent::AbstractArray, dims::Union{Int,Colon}...) = reshape(parent, dims) -reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape(parent, _reshape_uncolon(parent, dims)) +reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Integer, Colon, Base.OneTo}}}) = reshape(parent, _reshape_uncolon(parent, dims)) @inline function _reshape_uncolon(A, dims) @noinline throw1(dims) = throw(DimensionMismatch(string("new dimensions $(dims) ", "may have at most one omitted dimension specified by `Colon()`"))) @@ -124,7 +123,7 @@ reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = reshape( pre = _before_colon(dims...) post = _after_colon(dims...) _any_colon(post...) && throw1(dims) - sz, remainder = divrem(length(A), prod(pre)*prod(post)) + sz, remainder = divrem(length(A), prod(to_shape(pre))*prod(to_shape(post))) remainder == 0 || throw2(A, dims) (pre..., Int(sz), post...) end diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 05ab7147e1f7e..dd8607f8454a8 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1410,3 +1410,23 @@ end @test_throws ArgumentError keepat!(a, [2, 1]) @test isempty(keepat!(a, [])) end + +@testset "reshape may mix axes, Integers and colon" begin + for a in Any[1:3, collect(1:3), reshape(collect(1:3), 1, 1, 3)] + r13 = reshape(a, 1, 3) + r31 = reshape(a, 3, 1) + @test reshape(a, Int8(1), Int8(3)) == r13 + @test reshape(a, 1, Int8(3)) == r13 + @test reshape(a, 1, Base.OneTo(3)) == r13 + @test reshape(a, 1, Base.OneTo(Int8(3))) == r13 + @test reshape(a, 1, :) == r13 + @test reshape(a, :, 1) == r31 + @test reshape(a, Base.OneTo(1), :) == r13 + @test reshape(a, :, Base.OneTo(1)) == r31 + @test reshape(a, Base.OneTo(Int8(1)), :) == r13 + @test reshape(a, Base.OneTo(1), Int8(3)) == r13 + @test reshape(a, Base.OneTo(Int8(1)), Int8(3)) == r13 + end + a = 1:3 + @test reshape(a, axes(a, 1), axes(a,2)) == reshape(a, size(a, 1), size(a,2)) +end