From 19200cb8487f7c44a859f3e525eab3b056d1598f Mon Sep 17 00:00:00 2001 From: Alva Bandy Date: Thu, 6 Jun 2024 19:59:10 -0400 Subject: [PATCH] GH-42020: [Swift] Add Arrow decoding implementation for Swift Codable --- swift/Arrow/Sources/Arrow/ArrowDecoder.swift | 343 ++++++++++++++++++ .../Arrow/Tests/ArrowTests/CodableTests.swift | 168 +++++++++ 2 files changed, 511 insertions(+) create mode 100644 swift/Arrow/Sources/Arrow/ArrowDecoder.swift create mode 100644 swift/Arrow/Tests/ArrowTests/CodableTests.swift diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift new file mode 100644 index 0000000000000..55c0ea1b07bc6 --- /dev/null +++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation + +public class ArrowDecoder: Decoder { + var rbIndex: UInt = 0 + public var codingPath: [CodingKey] = [] + public var userInfo: [CodingUserInfoKey: Any] = [:] + public let rb: RecordBatch + public let nameToCol: [String: ArrowArrayHolder] + public let columns: [ArrowArrayHolder] + public init(_ decoder: ArrowDecoder) { + self.userInfo = decoder.userInfo + self.codingPath = decoder.codingPath + self.rb = decoder.rb + self.columns = decoder.columns + self.nameToCol = decoder.nameToCol + self.rbIndex = decoder.rbIndex + } + + public init(_ rb: RecordBatch) { + self.rb = rb + var colMapping = [String: ArrowArrayHolder]() + var columns = [ArrowArrayHolder]() + for index in 0..(_ type: T.Type) throws -> [T] { + var output = [T]() + for index in 0..(keyedBy type: Key.Type + ) -> KeyedDecodingContainer where Key: CodingKey { + let container = ArrowKeyedDecoding(self, codingPath: codingPath) + return KeyedDecodingContainer(container) + } + + public func unkeyedContainer() -> UnkeyedDecodingContainer { + return ArrowUnkeyedDecoding(self, codingPath: codingPath) + } + + public func singleValueContainer() -> SingleValueDecodingContainer { + return ArrowSingleValueDecoding(self, codingPath: codingPath) + } + + func getCol(_ name: String) throws -> AnyArray { + guard let col = self.nameToCol[name] else { + throw ArrowError.invalid("Column for key \"\(name)\" not found") + } + + guard let anyArray = col.array as? AnyArray else { + throw ArrowError.invalid("Unable to convert array to AnyArray") + } + + return anyArray + } + + func getCol(_ index: Int) throws -> AnyArray { + if index >= self.columns.count { + throw ArrowError.outOfBounds(index: Int64(index)) + } + + guard let anyArray = self.columns[index].array as? AnyArray else { + throw ArrowError.invalid("Unable to convert array to AnyArray") + } + + return anyArray + } + + func doDecode(_ key: CodingKey) throws -> T? { + let array: AnyArray = try self.getCol(key.stringValue) + return array.asAny(self.rbIndex) as? T + } + + func doDecode(_ col: Int) throws -> T? { + let array: AnyArray = try self.getCol(col) + return array.asAny(self.rbIndex) as? T + } +} + +private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { + var codingPath: [CodingKey] + var count: Int? = 0 + var isAtEnd: Bool = false + var currentIndex: Int = 0 + let decoder: ArrowDecoder + + init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { + self.decoder = decoder + self.codingPath = codingPath + self.count = self.decoder.columns.count + } + + mutating func increment() { + self.currentIndex += 1 + self.isAtEnd = self.currentIndex >= self.count! + } + + mutating func decodeNil() throws -> Bool { + defer {increment()} + return try self.decoder.doDecode(self.currentIndex) == nil + } + + mutating func decode(_ type: T.Type) throws -> T where T: Decodable { + if type == Date.self { + defer {increment()} + return try self.decoder.doDecode(self.currentIndex)! + } else if type == Int8.self || type == Int16.self || + type == Int32.self || type == Int64.self || + type == UInt8.self || type == UInt16.self || + type == UInt32.self || type == UInt64.self || + type == String.self || type == Double.self || + type == Float.self { + defer {increment()} + return try self.decoder.doDecode(self.currentIndex)! + } else { + throw ArrowError.invalid("Type \(type) is currently not supported") + } + } + + func nestedContainer( + keyedBy type: NestedKey.Type + ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + throw ArrowError.invalid("Nested decoding is currently not supported.") + } + + func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { + throw ArrowError.invalid("Nested decoding is currently not supported.") + } + + func superDecoder() throws -> Decoder { + throw ArrowError.invalid("super decoding is currently not supported.") + } +} + +private struct ArrowKeyedDecoding: KeyedDecodingContainerProtocol { + var codingPath = [CodingKey]() + var allKeys = [Key]() + let decoder: ArrowDecoder + + init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { + self.decoder = decoder + self.codingPath = codingPath + } + + func contains(_ key: Key) -> Bool { + return self.decoder.nameToCol.keys.contains(key.stringValue) + } + + func decodeNil(forKey key: Key) throws -> Bool { + return try self.decoder.doDecode(key) == nil + } + + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: String.Type, forKey key: Key) throws -> String { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { + throw ArrowError.invalid("Int type is not supported (please use Int8, Int16, Int32 or Int64)") + } + + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { + throw ArrowError.invalid( + "UInt type is not supported (please use UInt8, UInt16, UInt32 or UInt64)") + } + + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { + return try self.decoder.doDecode(key)! + } + + func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { + if type == Date.self { + return try self.decoder.doDecode(key)! + } else { + throw ArrowError.invalid("Type \(type) is currently not supported") + } + } + + func nestedContainer( + keyedBy type: NestedKey.Type, + forKey key: Key + ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + throw ArrowError.invalid("Nested decoding is currently not supported.") + } + + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + throw ArrowError.invalid("Nested decoding is currently not supported.") + } + + func superDecoder() throws -> Decoder { + throw ArrowError.invalid("super decoding is currently not supported.") + } + + func superDecoder(forKey key: Key) throws -> Decoder { + throw ArrowError.invalid("super decoding is currently not supported.") + } +} + +private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { + var codingPath = [CodingKey]() + let decoder: ArrowDecoder + + init(_ decoder: ArrowDecoder, codingPath: [CodingKey]) { + self.decoder = decoder + self.codingPath = codingPath + } + + func decodeNil() -> Bool { + return false + } + + func decode(_ type: Bool.Type) throws -> Bool { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: String.Type) throws -> String { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Double.Type) throws -> Double { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Float.Type) throws -> Float { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Int.Type) throws -> Int { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Int8.Type) throws -> Int8 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Int16.Type) throws -> Int16 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Int32.Type) throws -> Int32 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: Int64.Type) throws -> Int64 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: UInt.Type) throws -> UInt { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: UInt8.Type) throws -> UInt8 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: UInt16.Type) throws -> UInt16 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: UInt32.Type) throws -> UInt32 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: UInt64.Type) throws -> UInt64 { + return try self.decoder.doDecode(0)! + } + + func decode(_ type: T.Type) throws -> T where T: Decodable { + if type == Date.self { + return try self.decoder.doDecode(0)! + } else { + throw ArrowError.invalid("Type \(type) is currently not supported") + } + } +} diff --git a/swift/Arrow/Tests/ArrowTests/CodableTests.swift b/swift/Arrow/Tests/ArrowTests/CodableTests.swift new file mode 100644 index 0000000000000..89f4a40f38e1f --- /dev/null +++ b/swift/Arrow/Tests/ArrowTests/CodableTests.swift @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import XCTest +@testable import Arrow + +final class CodableTests: XCTestCase { + public class TestClass: Codable, EmptyInitalizer { + public var prop0: Bool + public var prop1: Int8 + public var prop2: Int16 + public var prop3: Int32 + public var prop4: Int64 + public var prop5: UInt8 + public var prop6: UInt16 + public var prop7: UInt32 + public var prop8: UInt64 + public var prop9: Float + public var prop10: Double + public var prop11: String + public var prop12: Date + + public required init() { + self.prop0 = false + self.prop1 = 1 + self.prop2 = 2 + self.prop3 = 3 + self.prop4 = 4 + self.prop5 = 5 + self.prop6 = 6 + self.prop7 = 7 + self.prop8 = 8 + self.prop9 = 9 + self.prop10 = 10 + self.prop11 = "11" + self.prop12 = Date.now + } + } + + func testArrowKeyedDecoder() throws { // swiftlint:disable:this function_body_length + let date1 = Date(timeIntervalSinceReferenceDate: 86400 * 5000 + 352) + + let boolBuilder = try ArrowArrayBuilders.loadBoolArrayBuilder() + let int8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let int16Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let int32Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let int64Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let uint8Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let uint16Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let uint32Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let uint64Builder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let floatBuilder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let doubleBuilder: NumberArrayBuilder = try ArrowArrayBuilders.loadNumberArrayBuilder() + let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() + let dateBuilder = try ArrowArrayBuilders.loadDate64ArrayBuilder() + + boolBuilder.append(false, true, false) + int8Builder.append(10, 11, 12) + int16Builder.append(20, 21, 22) + int32Builder.append(30, 31, 32) + int64Builder.append(40, 41, 42) + uint8Builder.append(50, 51, 52) + uint16Builder.append(60, 61, 62) + uint32Builder.append(70, 71, 72) + uint64Builder.append(80, 81, 82) + floatBuilder.append(90.1, 91.1, 92.1) + doubleBuilder.append(100.1, 101.1, 102.1) + stringBuilder.append("test0", "test1", "test2") + dateBuilder.append(date1, date1, date1) + let result = RecordBatch.Builder() + .addColumn("prop0", arrowArray: try boolBuilder.toHolder()) + .addColumn("prop1", arrowArray: try int8Builder.toHolder()) + .addColumn("prop2", arrowArray: try int16Builder.toHolder()) + .addColumn("prop3", arrowArray: try int32Builder.toHolder()) + .addColumn("prop4", arrowArray: try int64Builder.toHolder()) + .addColumn("prop5", arrowArray: try uint8Builder.toHolder()) + .addColumn("prop6", arrowArray: try uint16Builder.toHolder()) + .addColumn("prop7", arrowArray: try uint32Builder.toHolder()) + .addColumn("prop8", arrowArray: try uint64Builder.toHolder()) + .addColumn("prop9", arrowArray: try floatBuilder.toHolder()) + .addColumn("prop10", arrowArray: try doubleBuilder.toHolder()) + .addColumn("prop11", arrowArray: try stringBuilder.toHolder()) + .addColumn("prop12", arrowArray: try dateBuilder.toHolder()) + .finish() + switch result { + case .success(let rb): + let decoder = ArrowDecoder(rb) + var testClasses = try decoder.decode(TestClass.self) + for index in 0.. = try ArrowArrayBuilders.loadNumberArrayBuilder() + int8Builder.append(10, 11, 12) + let result = RecordBatch.Builder() + .addColumn("prop1", arrowArray: try int8Builder.toHolder()) + .finish() + switch result { + case .success(let rb): + let decoder = ArrowDecoder(rb) + let testData = try decoder.decode(Int8.self) + for index in 0.. = try ArrowArrayBuilders.loadNumberArrayBuilder() + let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder() + int8Builder.append(10, 11, 12) + stringBuilder.append("test0", "test1", "test2") + let result = RecordBatch.Builder() + .addColumn("prop1", arrowArray: try int8Builder.toHolder()) + .addColumn("prop2", arrowArray: try stringBuilder.toHolder()) + .finish() + switch result { + case .success(let rb): + let decoder = ArrowDecoder(rb) + let testData = try decoder.decode([Int8: String].self) + var index: Int8 = 0 + for data in testData { + let str = data[10 + index] + XCTAssertEqual(str, "test\(index)") + index += 1 + } + case .failure(let err): + throw err + } + } + +}