From cadb9171b11db6f22291254843c64fe49ebbf091 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Sat, 18 Jan 2020 13:54:07 +0100 Subject: [PATCH] fix single slice gives empty array instead of single element array closes #1. --- Sources/NdArray/NdArray.swift | 15 +++++++++++---- Tests/NdArrayTests/subscriptTests.swift | 8 ++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/Sources/NdArray/NdArray.swift b/Sources/NdArray/NdArray.swift index eee197a..197c82b 100644 --- a/Sources/NdArray/NdArray.swift +++ b/Sources/NdArray/NdArray.swift @@ -65,7 +65,7 @@ open class NdArray: CustomDebugStringConvertible, internal convenience init(empty shape: [Int], order: Contiguous = .C) { let n = shape.isEmpty ? 0 : shape.reduce(1, *) self.init(empty: n) - var success = reshape([n]) + var success = reshape([n]) assert(success, "could not reshape from [\(self.shape)] to \(n)") success = reshape(shape, order: order) assert(success, "could not reshape form [\(self.shape)] to \(shape)") @@ -386,9 +386,16 @@ open class NdArray: CustomDebugStringConvertible, // here we reduce the shape, hence slice = 0 let slice = NdArraySlice(self, startIndex: startIndex, sliced: 0) // drop leading shape 1 - slice.shape = Array(slice.shape[1...]) - slice.strides = Array(slice.strides[1...]) - slice.count = slice.len + let shape = [Int](slice.shape[1...]) + if shape.isEmpty { + slice.shape = [1] + slice.strides = [1] + slice.count = 1 + } else { + slice.shape = shape + slice.strides = Array(slice.strides[1...]) + slice.count = slice.len + } return slice } set { diff --git a/Tests/NdArrayTests/subscriptTests.swift b/Tests/NdArrayTests/subscriptTests.swift index ff1fcd0..408425d 100644 --- a/Tests/NdArrayTests/subscriptTests.swift +++ b/Tests/NdArrayTests/subscriptTests.swift @@ -935,4 +935,12 @@ class NdArraySliceSubscriptTests: XCTestCase { XCTAssertEqual(NdArray(copy: a[0], order: .C).dataArray, [0, 1, 2, 3, 4, 5]) } } + + func testSingleElementSlice1d() { + let a = NdArray.range(to: 4) + let s: NdArraySlice = a[2] + XCTAssertEqual(s.shape, [1]) + XCTAssertEqual(s.dataArray, [2]) + } + }