diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc index 09d9473df4795..bc5ab093b4534 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.cc @@ -18,8 +18,10 @@ #include "arrow/util/utf8.h" #include "arrow/matlab/array/proxy/array.h" +#include "arrow/matlab/array/proxy/wrap.h" #include "arrow/matlab/bit/unpack.h" #include "arrow/matlab/error/error.h" +#include "arrow/matlab/index/validate.h" #include "arrow/matlab/type/proxy/wrap.h" #include "arrow/pretty_print.h" #include "arrow/type_traits.h" @@ -38,7 +40,7 @@ namespace arrow::matlab::array::proxy { REGISTER_METHOD(Array, getValid); REGISTER_METHOD(Array, getType); REGISTER_METHOD(Array, isEqual); - + REGISTER_METHOD(Array, slice); } std::shared_ptr Array::unwrap() { @@ -144,4 +146,36 @@ namespace arrow::matlab::array::proxy { mda::ArrayFactory factory; context.outputs[0] = factory.createScalar(is_equal); } + + void Array::slice(libmexclass::proxy::method::Context& context) { + namespace mda = ::matlab::data; + + mda::StructArray opts = context.inputs[0]; + const mda::TypedArray offset_mda = opts[0]["Offset"]; + const mda::TypedArray length_mda = opts[0]["Length"]; + + const auto matlab_offset = int64_t(offset_mda[0]); + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(arrow::matlab::index::validateSliceOffset(matlab_offset), + context, error::ARRAY_SLICE_NON_POSITIVE_OFFSET); + + // Note: MATLAB uses 1-based indexing, so subtract 1. + const int64_t offset = matlab_offset - 1; + const int64_t length = int64_t(length_mda[0]); + MATLAB_ERROR_IF_NOT_OK_WITH_CONTEXT(arrow::matlab::index::validateSliceLength(length), + context, error::ARRAY_SLICE_NEGATIVE_LENGTH); + + auto sliced_array = array->Slice(offset, length); + const auto type_id = static_cast(sliced_array->type_id()); + MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto sliced_array_proxy, + array::proxy::wrap(sliced_array), + context, error::ARRAY_SLICE_FAILED_TO_CREATE_ARRAY_PROXY); + + const auto proxy_id = libmexclass::proxy::ProxyManager::manageProxy(sliced_array_proxy); + + mda::ArrayFactory factory; + mda::StructArray output = factory.createStructArray({1, 1}, {"ProxyID", "TypeID"}); + output[0]["ProxyID"] = factory.createScalar(proxy_id); + output[0]["TypeID"] = factory.createScalar(type_id); + context.outputs[0] = output; + } } diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/array.h b/matlab/src/cpp/arrow/matlab/array/proxy/array.h index 0ab7b279bc92e..1e3164ed01a72 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/array.h +++ b/matlab/src/cpp/arrow/matlab/array/proxy/array.h @@ -44,6 +44,8 @@ class Array : public libmexclass::proxy::Proxy { void isEqual(libmexclass::proxy::method::Context& context); + void slice(libmexclass::proxy::method::Context& context); + std::shared_ptr array; }; diff --git a/matlab/src/cpp/arrow/matlab/error/error.h b/matlab/src/cpp/arrow/matlab/error/error.h index 5aa8f05c8c315..e6be411b62a05 100644 --- a/matlab/src/cpp/arrow/matlab/error/error.h +++ b/matlab/src/cpp/arrow/matlab/error/error.h @@ -206,4 +206,8 @@ namespace arrow::matlab::error { static const char* ARRAY_VALIDATE_MINIMAL_FAILED = "arrow:array:ValidateMinimalFailed"; static const char* ARRAY_VALIDATE_FULL_FAILED = "arrow:array:ValidateFullFailed"; static const char* ARRAY_VALIDATE_UNSUPPORTED_ENUM = "arrow:array:ValidateUnsupportedEnum"; + static const char* ARRAY_SLICE_NON_POSITIVE_OFFSET = "arrow:array:slice:NonPositiveOffset"; + static const char* ARRAY_SLICE_NEGATIVE_LENGTH = "arrow:array:slice:NegativeLength"; + static const char* ARRAY_SLICE_FAILED_TO_CREATE_ARRAY_PROXY = "arrow:array:slice:FailedToCreateArrayProxy"; + } diff --git a/matlab/src/cpp/arrow/matlab/index/validate.cc b/matlab/src/cpp/arrow/matlab/index/validate.cc index b24653f1b814c..84e8e424e171f 100644 --- a/matlab/src/cpp/arrow/matlab/index/validate.cc +++ b/matlab/src/cpp/arrow/matlab/index/validate.cc @@ -53,4 +53,20 @@ namespace arrow::matlab::index { } return arrow::Status::OK(); } + + arrow::Status validateSliceOffset(const int64_t matlab_offset) { + if (matlab_offset < 1) { + const std::string msg = "Slice offset must be positive"; + return arrow::Status::Invalid(std::move(msg)); + } + return arrow::Status::OK(); + } + + arrow::Status validateSliceLength(const int64_t length) { + if (length < 0) { + const std::string msg = "Slice length must be nonnegative"; + return arrow::Status::Invalid(std::move(msg)); + } + return arrow::Status::OK(); + } } \ No newline at end of file diff --git a/matlab/src/cpp/arrow/matlab/index/validate.h b/matlab/src/cpp/arrow/matlab/index/validate.h index 40e109c19e9ef..2fa88ef8f1b5a 100644 --- a/matlab/src/cpp/arrow/matlab/index/validate.h +++ b/matlab/src/cpp/arrow/matlab/index/validate.h @@ -23,4 +23,7 @@ namespace arrow::matlab::index { arrow::Status validateNonEmptyContainer(const int32_t num_fields); arrow::Status validateInRange(const int32_t matlab_index, const int32_t num_fields); + arrow::Status validateSliceOffset(const int64_t matlab_offset); + arrow::Status validateSliceLength(const int64_t length); + } \ No newline at end of file diff --git a/matlab/src/matlab/+arrow/+array/Array.m b/matlab/src/matlab/+arrow/+array/Array.m index 293ad87ad4316..4402055932b60 100644 --- a/matlab/src/matlab/+arrow/+array/Array.m +++ b/matlab/src/matlab/+arrow/+array/Array.m @@ -98,4 +98,14 @@ function displayScalarObject(obj) tf = obj.Proxy.isEqual(proxyIDs); end end + + methods (Hidden) + function array = slice(obj, offset, length) + sliceStruct = struct(Offset=offset, Length=length); + arrayStruct = obj.Proxy.slice(sliceStruct); + traits = arrow.type.traits.traits(arrow.type.ID(arrayStruct.TypeID)); + proxy = libmexclass.proxy.Proxy(Name=traits.ArrayProxyClassName, ID=arrayStruct.ProxyID); + array = traits.ArrayConstructor(proxy); + end + end end diff --git a/matlab/test/arrow/array/tSlice.m b/matlab/test/arrow/array/tSlice.m new file mode 100644 index 0000000000000..c99503371a41c --- /dev/null +++ b/matlab/test/arrow/array/tSlice.m @@ -0,0 +1,138 @@ +%TSLICE Unit tests verifying the behavior of arrow.array.Array's slice +%method. + +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +classdef tSlice < matlab.unittest.TestCase + + methods(Test) + function BooleanArray(testCase) + % Verify the slice method returns the expected array when + % called on a Boolean Array. + boolArray = arrow.array([true true false true true false], Valid=[1 2 3 6]); + slice = boolArray.slice(int64(2), int64(4)); + testCase.verifyEqual(slice.NumElements, int64(4)); + testCase.verifyEqual(slice.Valid, [true; true; false; false]); + testCase.verifyEqual(toMATLAB(slice), [true; false; false; false]); + end + + function NumericArray(testCase) + % Verify the slice method returns the expected array when + % called on a Numeric Array. + float64Array = arrow.array(1:10, Valid=[2 3 4 5 8 10]); + slice = float64Array.slice(int64(4), int64(5)); + testCase.verifyEqual(slice.NumElements, int64(5)); + testCase.verifyEqual(slice.Valid, [true; true; false; false; true]); + testCase.verifyEqual(toMATLAB(slice), [4; 5; NaN; NaN; 8]); + end + + function DateArray(testCase) + % Verify the slice method returns the expected array when + % called on a Date Array. + import arrow.array.Date32Array + dates = datetime(2023, 11, 8:16); + date32Array = Date32Array.fromMATLAB(dates, Valid=[4 5 6 9]); + slice = date32Array.slice(int64(3), int64(4)); + testCase.verifyEqual(slice.NumElements, int64(4)); + testCase.verifyEqual(slice.Valid, [false; true; true; true]); + expected = [NaT; dates(4:6)']; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function TimeArray(testCase) + % Verify the slice method returns the expected array when + % called on a Time Array. + times = seconds(10:20); + time64Array = arrow.array(times, Valid=[2 4 6 7 8 10]); + slice = time64Array.slice(int64(5), int64(6)); + testCase.verifyEqual(slice.NumElements, int64(6)); + testCase.verifyEqual(slice.Valid, [false; true; true; true; false; true]); + expected = [NaN; times(6:8)'; NaN; times(10)]; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function TimestampArray(testCase) + % Verify the slice method returns the expected array when + % called on a TimestampArray. + dates = datetime(2023, 11, 8:16); + date32Array = arrow.array(dates, Valid=[1 2 4 5 6 8]); + slice = date32Array.slice(int64(5), int64(3)); + testCase.verifyEqual(slice.NumElements, int64(3)); + testCase.verifyEqual(slice.Valid, [true; true; false]); + expected = [dates(5:6)'; NaT]; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function StringArray(testCase) + % Verify the slice method returns the expected array when + % called on a StringArray. + stringArray = arrow.array(["a" "b" "c" "d" "e" "f" "g"], Valid=[1 3 4 5 6]); + slice = stringArray.slice(int64(2), int64(3)); + testCase.verifyEqual(slice.NumElements, int64(3)); + testCase.verifyEqual(slice.Valid, [false; true; true]); + expected = [missing; "c"; "d"]; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function ListArray(testCase) + % Verify the slice method returns the expected array when + % called on a ListArray. + cellArray = {missing, [1, 2, 3], missing, [4, NaN], [6, 7, 8], missing}; + listArray = arrow.array(cellArray); + slice = listArray.slice(int64(2), int64(4)); + testCase.verifyEqual(slice.NumElements, int64(4)); + testCase.verifyEqual(slice.Valid, [true; false; true; true]); + expected = {[1; 2; 3]; missing; [4; NaN]; [6; 7; 8]}; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function StructArray(testCase) + % Verify the slice method returns the expected array when + % called on a StructArray. + numbers = [NaN; 2; 3; 4; 5; 6; 7; NaN; 9; 10]; + text = ["a"; missing; "c"; "d"; "e"; missing; "g"; "h"; "i"; "j"]; + t = table(numbers, text); + structArray = arrow.array(t, Valid=[1 2 3 6 7 8 10]); + slice = structArray.slice(int64(5), int64(4)); + testCase.verifyEqual(slice.NumElements, int64(4)); + testCase.verifyEqual(slice.Valid, [false; true; true; true]); + expected = t(5:8, :); + expected.numbers(1) = NaN; + expected.text(1) = missing; + testCase.verifyEqual(toMATLAB(slice), expected); + end + + function NonPositiveOffsetError(testCase) + % Verify the slice method throws an error whose identifier is + % "arrow:array:slice:NonPositiveOffset" if given a non-positive + % value as the offset. + array = arrow.array(1:10); + fcn = @() array.slice(int64(0), int64(2)); + testCase.verifyError(fcn, "arrow:array:slice:NonPositiveOffset"); + fcn = @() array.slice(int64(-1), int64(2)); + testCase.verifyError(fcn, "arrow:array:slice:NonPositiveOffset"); + end + + function NegativeLengthError(testCase) + % Verify the slice method throws an error whose identifier is + % "arrow:array:slice:NegativeLength" if given a negative value + % as the length. + array = arrow.array(1:10); + fcn = @() array.slice(int64(1), int64(-1)); + testCase.verifyError(fcn, "arrow:array:slice:NegativeLength"); + end + end +end \ No newline at end of file