From 4faa0299ccd4f7ebbf9469f803c4c6b990459d78 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Mon, 9 Oct 2017 23:52:27 +0800 Subject: [PATCH 1/2] ndarray: getindex/setindex! linear indexing ```julia x = mx.zeros(2, 5) x[5] = 42 ``` --- NEWS.md | 7 +++++ src/ndarray.jl | 64 +++++++++++++++++++++++++++++++++------- test/unittest/ndarray.jl | 50 +++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index 773c2471c..ab9d83c96 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,13 @@ * `deepcopy` for NDArray (#273) +* NDArray getindex/setindex! linear indexing (#TBD) + e.g. + ```julia + x = mx.zeros(2, 5) + x[5] = 42 + ``` + ## API Changes * `reshape` of NDArray share the same interface with Base (#272). diff --git a/src/ndarray.jl b/src/ndarray.jl index eba7e2169..17b89e70b 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -336,31 +336,52 @@ function slice(arr :: NDArray, slice::UnitRange{Int}) return NDArray(MX_NDArrayHandle(hdr_ref[]), arr.writable) end +function _at(handle::Union{MX_NDArrayHandle, MX_handle}, idx::Integer) + h_ref = Ref{MX_handle}(C_NULL) + @mxcall(:MXNDArrayAt, (MX_handle, MX_uint, Ref{MX_handle}), + handle, idx, h_ref) + h_ref[] +end + import Base: setindex! """ setindex!(arr :: NDArray, val, idx) -Assign values to an `NDArray`. Elementwise assignment is not implemented, only the following -scenarios are supported +Assign values to an `NDArray`. +The following scenarios are supported + +* single value assignment via linear indexing: `arr[42] = 24` * `arr[:] = val`: whole array assignment, `val` could be a scalar or an array (Julia `Array` or `NDArray`) of the same shape. * `arr[start:stop] = val`: assignment to a *slice*, `val` could be a scalar or an array of the same shape to the slice. See also [`slice`](@ref). """ -function setindex!(arr :: NDArray, val :: Real, ::Colon) - @assert(arr.writable) +function setindex!(arr::NDArray, val::Real, idx::Integer) + # linear indexing + @assert arr.writable + _set_value(out=arr[idx], src=val) +end + +function setindex!(arr::NDArray, val::Real, ::Colon) + @assert arr.writable _set_value(out=arr, src=convert(eltype(arr), val)) - return arr end -function setindex!{T<:Real}(arr :: NDArray, val :: Array{T}, ::Colon) + +function setindex!{T<:Real}(arr::NDArray, val::Array{T}, ::Colon) + @assert arr.writable copy!(arr, val) end -function setindex!(arr :: NDArray, val :: NDArray, ::Colon) + +function setindex!(arr::NDArray, val::NDArray, ::Colon) + @assert arr.writable copy!(arr, val) end -function setindex!{T<:Real}(arr :: NDArray, val :: Union{T,Array{T},NDArray}, idx::UnitRange{Int}) + +function setindex!{T<:Real}(arr::NDArray, val::Union{T,Array{T},NDArray}, + idx::UnitRange{Int}) + @assert arr.writable setindex!(slice(arr, idx), val, Colon()) end @@ -396,13 +417,36 @@ function getindex(arr :: NDArray, ::Colon) end """ -Shortcut for [`slice`](@ref). **NOTE** the behavior for Julia's built-in index slicing is to create a -copy of the sub-array, while here we simply call `slice`, which shares the underlying memory. +Shortcut for [`slice`](@ref). +**NOTE** the behavior for Julia's built-in index slicing is to create a +copy of the sub-array, while here we simply call `slice`, +which shares the underlying memory. """ function getindex(arr :: NDArray, idx::UnitRange{Int}) slice(arr, idx) end +function getindex(arr :: NDArray, idx::Integer) + # linear indexing + len = length(arr) + size_ = size(arr) + + if idx <= 0 || idx > len + throw(BoundsError( + "attempt to access $(join(size_, 'x')) NDArray at index $(idx)")) + end + + idx -= 1 + offsets = size_[1:end-1] |> reverse โˆ˜ cumprod โˆ˜ collect + handle = arr.handle + for offset โˆˆ offsets + handle = _at(handle, idx รท offset) + idx %= offset + end + + _at(handle, idx) |> MX_NDArrayHandle |> x -> NDArray(x, arr.writable) +end + import Base: copy!, copy, convert, deepcopy """ diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 2185d920c..706f78353 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -85,6 +85,55 @@ function test_slice() @test copy(mx.slice(array, 2:3)) == [1 1; 1 1] end +function test_linear_idx() + info("NDArray::getindex::linear indexing") + let A = reshape(collect(1:30), 3, 10) + x = mx.NDArray(A) + + @test copy(x) == A + @test copy(x[1]) == [1] + @test copy(x[2]) == [2] + @test copy(x[3]) == [3] + @test copy(x[12]) == [12] + @test copy(x[13]) == [13] + @test copy(x[14]) == [14] + + @test_throws BoundsError x[-1] + @test_throws BoundsError x[0] + @test_throws BoundsError x[31] + @test_throws BoundsError x[42] + end + + let A = reshape(collect(1:24), 3, 2, 4) + x = mx.NDArray(A) + + @test copy(x) == A + @test copy(x[1]) == [1] + @test copy(x[2]) == [2] + @test copy(x[3]) == [3] + @test copy(x[11]) == [11] + @test copy(x[12]) == [12] + @test copy(x[13]) == [13] + @test copy(x[14]) == [14] + end + + info("NDArray::setindex!::linear indexing") + let A = reshape(collect(1:24), 3, 2, 4) + x = mx.NDArray(A) + + @test copy(x) == A + + x[4] = -4 + @test copy(x[4]) == [-4] + + x[11] = -11 + @test copy(x[11]) == [-11] + + x[24] = 42 + @test copy(x[24]) == [42] + end +end # function test_linear_idx + function test_plus() dims = rand_dims() t1, a1 = rand_tensors(dims) @@ -432,6 +481,7 @@ end test_assign() test_copy() test_slice() + test_linear_idx() test_plus() test_minus() test_mul() From 62bd2133afb844887fdb0e64a68e742cbcbed7ce Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Wed, 8 Nov 2017 19:47:45 +0800 Subject: [PATCH 2/2] ndarray: implement first --- NEWS.md | 21 +++++++++++++++++---- src/ndarray.jl | 15 ++++++++++----- test/unittest/ndarray.jl | 14 ++++++++++++++ 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/NEWS.md b/NEWS.md index a1b9fca32..112b1b5e0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -26,11 +26,24 @@ 2.0 4.0 ``` -* NDArray getindex/setindex! linear indexing (#TBD) - e.g. +* `NDArray` `getindex`/`setindex!` linear indexing support and `first` for extracting scalar value. (#TBD) + ```julia - x = mx.zeros(2, 5) - x[5] = 42 + julia> x = mx.zeros(2, 5) + + julia> x[5] = 42 # do synchronization and set the value + ``` + + ```julia + julia> y = x[5] # actually, getindex won't do synchronization, but REPL's showing did it for you + 1 mx.NDArray{Float32} @ CPU0: + 42.0 + + julia> first(y) # do sync and get the value + 42.0f0 + + julia> y[] # this is available, also + 42.0f0 ``` ## API Changes diff --git a/src/ndarray.jl b/src/ndarray.jl index ced344ad2..008c9556c 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -312,6 +312,9 @@ function eltype{T <: Union{NDArray, MX_NDArrayHandle}}(arr :: T) end end +@inline _first(arr::NDArray) = try_get_shared(arr, sync = :read) |> first + +Base.first(arr::NDArray) = _first(arr) """ slice(arr :: NDArray, start:stop) @@ -351,7 +354,7 @@ end import Base: setindex! """ - setindex!(arr :: NDArray, val, idx) + setindex!(arr::NDArray, val, idx) Assign values to an `NDArray`. The following scenarios are supported @@ -392,7 +395,7 @@ end import Base: getindex """ - getindex(arr :: NDArray, idx) + getindex(arr::NDArray, idx) Shortcut for [`slice`](@ref). A typical use is to write @@ -417,7 +420,7 @@ which furthur translates into create a **copy** of the sub-array for Julia `Array`, while for `NDArray`, this is a *slice* that shares the memory. """ -function getindex(arr :: NDArray, ::Colon) +function getindex(arr::NDArray, ::Colon) return arr end @@ -427,11 +430,13 @@ Shortcut for [`slice`](@ref). copy of the sub-array, while here we simply call `slice`, which shares the underlying memory. """ -function getindex(arr :: NDArray, idx::UnitRange{Int}) +function getindex(arr::NDArray, idx::UnitRange{Int}) slice(arr, idx) end -function getindex(arr :: NDArray, idx::Integer) +getindex(arr::NDArray) = _first(arr) + +function getindex(arr::NDArray, idx::Integer) # linear indexing len = length(arr) size_ = size(arr) diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 12a8fa335..80d0a6d68 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -134,6 +134,19 @@ function test_linear_idx() end end # function test_linear_idx +function test_first() + info("NDArray::first") + let A = reshape(collect(1:30), 3, 10) + x = mx.NDArray(A) + + @test x[] == 1 + @test x[5][] == 5 + + @test first(x) == 1 + @test first(x[5]) == 5 + end +end # function test_first + function test_plus() dims = rand_dims() t1, a1 = rand_tensors(dims) @@ -641,6 +654,7 @@ end test_copy() test_slice() test_linear_idx() + test_first() test_plus() test_minus() test_mul()