From 91a492757a2ebbe338258ecc1bcbd854e1fc7370 Mon Sep 17 00:00:00 2001 From: sgilmore10 <74676073+sgilmore10@users.noreply.github.com> Date: Wed, 23 Aug 2023 17:14:51 -0400 Subject: [PATCH] GH-37340: [MATLAB] The `column(index)` method of `arrow.tabular.RecordBatch` errors if `index` refers to an `arrow.array.Time32Array` column (#37347) ### Rationale for this change The `column(index)` method of `arrow.tabular.RecordBatch` errors if `index` refers to an `arrow.array.Time32Array` column. ```matlab >> string = arrow.array(["A", "B", "C"]); >> time32 = arrow.array.Time32Array.fromMATLAB(seconds(1:3)); >> rb = arrow.tabular.RecordBatch.fromArrays(string, time32, ColumnNames=["String", "Time32"]) rb = String: [ "A", "B", "C" ] Time32: [ 00:00:01, 00:00:02, 00:00:03 ] >> time32Column = rb.column(2) Error using . Unsupported DataType: time32[s] Error in arrow.tabular.RecordBatch/column (line 63) [proxyID, typeID] = obj.Proxy.getColumnByIndex(args); ``` `column(index)` is throwing this error because the case for `Time32` was not added to `arrow::matlab::array::proxy::wrap(arrowArray)` in #37315. `column(index)` calls this function to create proxy objects around existing arrow arrays. Adding the case for `Time32` will resolve this bug. ### What changes are included in this PR? 1. Updated `arrow::Result<std::shared_ptr<proxy::Array>> wrap(const std::shared_ptr<arrow::Array>& array)` to handle wrapping `arrow::Time32Array` instances within `proxy::Time32Array`s. 2. Added a new test utility called `arrow.internal.test.tabular.createAllSupportedArrayTypes`, which returns two cell arrays: `arrowArrays` and `matlabData`. The `arrowArrays` cell array contains one instance of each concrete subclass of `arrow.array.Array`. The `matlabData` cell array contains the MATLAB arrays used to generate each array in `arrowArrays`. This utility can be used to create an `arrow.tabular.RecordBatch` that contains all the arrow array types that are supported by the MATLAB interface. ### Are these changes tested? Yes. 
Updated the `fromArrays` test cases in `tRecordBatch` to verify the `column(index)` method of `arrow.type.RecordBatch` supports extracting columns of any arrow type (supported by the MATLAB Interface). ### Are there any user-facing changes? Yes. Fixed a bug in `arrow.tabular.RecordBatch`. ### Future Directions 1. #37345 * Closes: #37340 Authored-by: Sarah Gilmore Signed-off-by: Kevin Gurney --- .../src/cpp/arrow/matlab/array/proxy/wrap.cc | 2 + .../+tabular/createAllSupportedArrayTypes.m | 104 ++++++++++++++++++ matlab/test/arrow/tabular/tRecordBatch.m | 39 ++++--- 3 files changed, 125 insertions(+), 20 deletions(-) create mode 100644 matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m diff --git a/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc b/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc index dab09359598d4..e11e9bb7669b1 100644 --- a/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc +++ b/matlab/src/cpp/arrow/matlab/array/proxy/wrap.cc @@ -51,6 +51,8 @@ namespace arrow::matlab::array::proxy { return std::make_shared>(std::static_pointer_cast(array)); case ID::TIMESTAMP: return std::make_shared>(std::static_pointer_cast(array)); + case ID::TIME32: + return std::make_shared>(std::static_pointer_cast(array)); case ID::STRING: return std::make_shared(std::static_pointer_cast(array)); default: diff --git a/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m b/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m new file mode 100644 index 0000000000000..7d3f36cb46e7c --- /dev/null +++ b/matlab/src/matlab/+arrow/+internal/+test/+tabular/createAllSupportedArrayTypes.m @@ -0,0 +1,104 @@ +%CREATEALLSUPPORTEDARRAYTYPES Creates a MATLAB cell array containing all +%the concrete subclasses of arrow.array.Array. Returns a cell array +%containing the MATLAB data from which the arrow arrays were generated +%as second output argument. 
+ +% Licensed to the Apache Software Foundation (ASF) under one or more +% contributor license agreements. See the NOTICE file distributed with +% this work for additional information regarding copyright ownership. +% The ASF licenses this file to you under the Apache License, Version +% 2.0 (the "License"); you may not use this file except in compliance +% with the License. You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +% implied. See the License for the specific language governing +% permissions and limitations under the License. + +function [arrowArrays, matlabData] = createAllSupportedArrayTypes(opts) + arguments + opts.NumRows(1, 1) {mustBeFinite, mustBeNonnegative} = 3; + end + + import arrow.type.ID + import arrow.array.* + + classes = getArrayClassNames(); + numClasses = numel(classes); + arrowArrays = cell(numClasses, 1); + matlabData = cell(numClasses, 1); + + numericArrayToMatlabTypeDict = getArrowArrayToMatlabTypeDictionary(); + + for ii = 1:numel(classes) + name = classes(ii); + if name == "arrow.array.BooleanArray" + matlabData{ii} = randomLogicals(opts.NumRows); + arrowArrays{ii} = BooleanArray.fromMATLAB(matlabData{ii}); + elseif isKey(numericArrayToMatlabTypeDict, name) + matlabType = numericArrayToMatlabTypeDict(name); + matlabData{ii} = randomNumbers(matlabType, opts.NumRows); + cmd = compose("%s.fromMATLAB(matlabData{ii})", name); + arrowArrays{ii} = eval(cmd); + elseif name == "arrow.array.StringArray" + matlabData{ii} = randomStrings(opts.NumRows); + arrowArrays{ii} = StringArray.fromMATLAB(matlabData{ii}); + elseif name == "arrow.array.TimestampArray" + matlabData{ii} = randomDatetimes(opts.NumRows); + arrowArrays{ii} = TimestampArray.fromMATLAB(matlabData{ii}); + elseif name == 
"arrow.array.Time32Array" + matlabData{ii} = randomDurations(opts.NumRows); + arrowArrays{ii} = Time32Array.fromMATLAB(matlabData{ii}); + else + error("arrow:test:SupportedArrayCase", ... + "Missing if-branch for array class " + name); + end + end +end + +function classes = getArrayClassNames() + metaClass = meta.package.fromName("arrow.array").ClassList; + + % Removes all Abstract classes from the list of all subclasses + abstract = [metaClass.Abstract]; + metaClass(abstract) = []; + classes = string({metaClass.Name}); +end + +function dict = getArrowArrayToMatlabTypeDictionary() + pkg = "arrow.array"; + unsignedTypes = compose("UInt%d", power(2, 3:6)); + signedTypes = compose("Int%d", power(2, 3:6)); + floatTypes = compose("Float%d", power(2, 5:6)); + numericTypes = [unsignedTypes, signedTypes, floatTypes]; + keys = compose("%s.%sArray", pkg, numericTypes); + + values = [lower([unsignedTypes, signedTypes]) "single" "double"]; + dict = dictionary(keys, values); +end + +function number = randomNumbers(numberType, numElements) + number = cast(randi(255, [numElements 1]), numberType); +end + +function text = randomStrings(numElements) + text = string(randi(255, [numElements 1])); +end + +function tf = randomLogicals(numElements) + number = randi(2, [numElements 1]) - 1; + tf = logical(number); +end + +function times = randomDurations(numElements) + number = randi(255, [numElements 1]); + times = seconds(number); +end + +function dates = randomDatetimes(numElements) + day = days(randi(255, [numElements 1])); + dates = datetime(2023, 8, 23) + day; +end \ No newline at end of file diff --git a/matlab/test/arrow/tabular/tRecordBatch.m b/matlab/test/arrow/tabular/tRecordBatch.m index 027195d3d8250..d9c3c98652b08 100644 --- a/matlab/test/arrow/tabular/tRecordBatch.m +++ b/matlab/test/arrow/tabular/tRecordBatch.m @@ -28,11 +28,17 @@ function Basic(tc) function SupportedTypes(tc) % Create a table all supported MATLAB types. 
import arrow.internal.test.tabular.createTableWithSupportedTypes + import arrow.type.traits.traits TOriginal = createTableWithSupportedTypes(); arrowRecordBatch = arrow.recordBatch(TOriginal); expectedColumnNames = string(TOriginal.Properties.VariableNames); - tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, TOriginal); + + % For each variable in the input MATLAB table, look up the + % corresponding Arrow Array type using type traits. + expectedArrayClasses = varfun(@(var) traits(string(class(var))).ArrayClassName, ... + TOriginal, OutputFormat="uniform"); + tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, expectedArrayClasses, TOriginal); end function ToMATLAB(tc) @@ -125,19 +131,16 @@ function FromArraysColumnNamesNotProvided(tc) % RecordBatch when given a comma-separated list of % arrow.array.Array values. import arrow.tabular.RecordBatch - import arrow.internal.test.tabular.createTableWithSupportedTypes + import arrow.internal.test.tabular.createAllSupportedArrayTypes - TOriginal = createTableWithSupportedTypes(); - - arrowArrays = cell([1 width(TOriginal)]); - for ii = 1:width(TOriginal) - arrowArrays{ii} = arrow.array(TOriginal.(ii)); - end + [arrowArrays, matlabData] = createAllSupportedArrayTypes(); + TOriginal = table(matlabData{:}); arrowRecordBatch = RecordBatch.fromArrays(arrowArrays{:}); expectedColumnNames = compose("Column%d", 1:width(TOriginal)); TOriginal.Properties.VariableNames = expectedColumnNames; - tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, TOriginal); + expectedArrayClasses = cellfun(@(c) string(class(c)), arrowArrays, UniformOutput=true); + tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, expectedArrayClasses, TOriginal); end function FromArraysWithColumnNamesProvided(tc) @@ -145,19 +148,16 @@ function FromArraysWithColumnNamesProvided(tc) % RecordBatch when given a comma-separated list of % arrow.array.Array values and the ColumnNames nv-pair is provided. 
import arrow.tabular.RecordBatch - import arrow.internal.test.tabular.createTableWithSupportedTypes - - TOriginal = createTableWithSupportedTypes(); + import arrow.internal.test.tabular.createAllSupportedArrayTypes - arrowArrays = cell([1 width(TOriginal)]); - for ii = 1:width(TOriginal) - arrowArrays{ii} = arrow.array(TOriginal.(ii)); - end + [arrowArrays, matlabData] = createAllSupportedArrayTypes(); + TOriginal = table(matlabData{:}); columnNames = compose("MyVar%d", 1:numel(arrowArrays)); arrowRecordBatch = RecordBatch.fromArrays(arrowArrays{:}, ColumnNames=columnNames); TOriginal.Properties.VariableNames = columnNames; - tc.verifyRecordBatch(arrowRecordBatch, columnNames, TOriginal); + expectedArrayClasses = cellfun(@(c) string(class(c)), arrowArrays, UniformOutput=true); + tc.verifyRecordBatch(arrowRecordBatch, columnNames, expectedArrayClasses, TOriginal); end function FromArraysUnequalArrayLengthsError(tc) @@ -226,7 +226,7 @@ function SchemaNoSetter(tc) end methods - function verifyRecordBatch(tc, recordBatch, expectedColumnNames, expectedTable) + function verifyRecordBatch(tc, recordBatch, expectedColumnNames, expectedArrayClasses, expectedTable) tc.verifyEqual(recordBatch.NumColumns, int32(width(expectedTable))); tc.verifyEqual(recordBatch.ColumnNames, expectedColumnNames); convertedTable = recordBatch.table(); @@ -234,8 +234,7 @@ function verifyRecordBatch(tc, recordBatch, expectedColumnNames, expectedTable) for ii = 1:recordBatch.NumColumns column = recordBatch.column(ii); tc.verifyEqual(column.toMATLAB(), expectedTable{:, ii}); - traits = arrow.type.traits.traits(string(class(expectedTable{:, ii}))); - tc.verifyInstanceOf(column, traits.ArrayClassName); + tc.verifyInstanceOf(column, expectedArrayClasses(ii)); end end end