diff --git a/NEWS.md b/NEWS.md index 77ef0b766..fc998538f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -26,6 +26,25 @@ 2.0 4.0 ``` +* `NDArray` `getindex`/`setindex!` linear indexing support and `first` for extracting scalar value. (#TBD) + + ```julia + julia> x = mx.zeros(2, 5) + + julia> x[5] = 42 # do synchronization and set the value + ``` + + ```julia + julia> y = x[5] # actually, getindex won't do synchronization, but REPL's showing did it for you + 1 mx.NDArray{Float32} @ CPU0: + 42.0 + + julia> first(y) # do sync and get the value + 42.0f0 + + julia> y[] # this is available, also + 42.0f0 + ``` * Elementwise power of `NDArray`. (#293) * `x.^2` * `2.^x` diff --git a/src/ndarray.jl b/src/ndarray.jl index 3e7625e3f..fb495af34 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -312,6 +312,9 @@ function eltype(arr :: T) where T <: Union{NDArray, MX_NDArrayHandle} end end +@inline _first(arr::NDArray) = try_get_shared(arr, sync = :read) |> first + +Base.first(arr::NDArray) = _first(arr) """ slice(arr :: NDArray, start:stop) @@ -341,37 +344,58 @@ function slice(arr :: NDArray, slice::UnitRange{Int}) return NDArray(MX_NDArrayHandle(hdr_ref[]), arr.writable) end +function _at(handle::Union{MX_NDArrayHandle, MX_handle}, idx::Integer) + h_ref = Ref{MX_handle}(C_NULL) + @mxcall(:MXNDArrayAt, (MX_handle, MX_uint, Ref{MX_handle}), + handle, idx, h_ref) + h_ref[] +end + import Base: setindex! """ - setindex!(arr :: NDArray, val, idx) + setindex!(arr::NDArray, val, idx) -Assign values to an `NDArray`. Elementwise assignment is not implemented, only the following -scenarios are supported +Assign values to an `NDArray`. +The following scenarios are supported + +* single value assignment via linear indexing: `arr[42] = 24` * `arr[:] = val`: whole array assignment, `val` could be a scalar or an array (Julia `Array` or `NDArray`) of the same shape. * `arr[start:stop] = val`: assignment to a *slice*, `val` could be a scalar or an array of the same shape to the slice. See also [`slice`](@ref). """ -function setindex!(arr :: NDArray, val :: Real, ::Colon) - @assert(arr.writable) +function setindex!(arr::NDArray, val::Real, idx::Integer) + # linear indexing + @assert arr.writable + _set_value(out=arr[idx], src=val) +end + +function setindex!(arr::NDArray, val::Real, ::Colon) + @assert arr.writable _set_value(out=arr, src=convert(eltype(arr), val)) - return arr end -function setindex!(arr :: NDArray, val :: Array{T}, ::Colon) where T<:Real + +function setindex!(arr::NDArray, val::Array{T}, ::Colon) where T<:Real + @assert arr.writable copy!(arr, val) end -function setindex!(arr :: NDArray, val :: NDArray, ::Colon) + +function setindex!(arr::NDArray, val::NDArray, ::Colon) + @assert arr.writable copy!(arr, val) end -function setindex!(arr :: NDArray, val :: Union{T,Array{T},NDArray}, idx::UnitRange{Int}) where T<:Real + +function setindex!(arr::NDArray, val::Union{T,Array{T},NDArray}, + idx::UnitRange{Int}) where T<:Real + @assert arr.writable setindex!(slice(arr, idx), val, Colon()) end import Base: getindex """ - getindex(arr :: NDArray, idx) + getindex(arr::NDArray, idx) Shortcut for [`slice`](@ref). A typical use is to write @@ -396,18 +420,43 @@ which furthur translates into create a **copy** of the sub-array for Julia `Array`, while for `NDArray`, this is a *slice* that shares the memory. """ -function getindex(arr :: NDArray, ::Colon) +function getindex(arr::NDArray, ::Colon) return arr end """ -Shortcut for [`slice`](@ref). **NOTE** the behavior for Julia's built-in index slicing is to create a -copy of the sub-array, while here we simply call `slice`, which shares the underlying memory. +Shortcut for [`slice`](@ref). +**NOTE** the behavior for Julia's built-in index slicing is to create a +copy of the sub-array, while here we simply call `slice`, +which shares the underlying memory. """ -function getindex(arr :: NDArray, idx::UnitRange{Int}) +function getindex(arr::NDArray, idx::UnitRange{Int}) slice(arr, idx) end +getindex(arr::NDArray) = _first(arr) + +function getindex(arr::NDArray, idx::Integer) + # linear indexing + len = length(arr) + size_ = size(arr) + + if idx <= 0 || idx > len + throw(BoundsError( + "attempt to access $(join(size_, 'x')) NDArray at index $(idx)")) + end + + idx -= 1 + offsets = size_[1:end-1] |> reverse โˆ˜ cumprod โˆ˜ collect + handle = arr.handle + for offset โˆˆ offsets + handle = _at(handle, idx รท offset) + idx %= offset + end + + _at(handle, idx) |> MX_NDArrayHandle |> x -> NDArray(x, arr.writable) +end + import Base: copy!, copy, convert, deepcopy """ diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl index 2299e0483..ac9090e3c 100644 --- a/test/unittest/ndarray.jl +++ b/test/unittest/ndarray.jl @@ -85,6 +85,68 @@ function test_slice() @test copy(mx.slice(array, 2:3)) == [1 1; 1 1] end +function test_linear_idx() + info("NDArray::getindex::linear indexing") + let A = reshape(collect(1:30), 3, 10) + x = mx.NDArray(A) + + @test copy(x) == A + @test copy(x[1]) == [1] + @test copy(x[2]) == [2] + @test copy(x[3]) == [3] + @test copy(x[12]) == [12] + @test copy(x[13]) == [13] + @test copy(x[14]) == [14] + + @test_throws BoundsError x[-1] + @test_throws BoundsError x[0] + @test_throws BoundsError x[31] + @test_throws BoundsError x[42] + end + + let A = reshape(collect(1:24), 3, 2, 4) + x = mx.NDArray(A) + + @test copy(x) == A + @test copy(x[1]) == [1] + @test copy(x[2]) == [2] + @test copy(x[3]) == [3] + @test copy(x[11]) == [11] + @test copy(x[12]) == [12] + @test copy(x[13]) == [13] + @test copy(x[14]) == [14] + end + + info("NDArray::setindex!::linear indexing") + let A = reshape(collect(1:24), 3, 2, 4) + x = mx.NDArray(A) + + @test copy(x) == A + + x[4] = -4 + @test copy(x[4]) == [-4] + + x[11] = -11 + @test copy(x[11]) == [-11] + + x[24] = 42 + @test copy(x[24]) == [42] + end +end # function test_linear_idx + +function test_first() + info("NDArray::first") + let A = reshape(collect(1:30), 3, 10) + x = mx.NDArray(A) + + @test x[] == 1 + @test x[5][] == 5 + + @test first(x) == 1 + @test first(x[5]) == 5 + end +end # function test_first + function test_plus() dims = rand_dims() t1, a1 = rand_tensors(dims) @@ -668,6 +730,8 @@ end test_assign() test_copy() test_slice() + test_linear_idx() + test_first() test_plus() test_minus() test_mul()