Skip to content

Commit

Permalink
apacheGH-39519: [Swift] Fix null count when using reader (apache#39520)
Browse files Browse the repository at this point in the history
Currently the reader is not properly setting the null count when building an array from a stream.  This PR adds a fix for this.

* Closes: apache#39519

Authored-by: Alva Bandy <abandy@live.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
  • Loading branch information
abandy authored and zanmato1984 committed Feb 28, 2024
1 parent 06fbeae commit 4bf43c6
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 47 deletions.
12 changes: 8 additions & 4 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,17 @@ public class ArrowReader {
private func loadPrimitiveData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
do {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset)
length: nullLength, messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer])
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
Expand All @@ -76,10 +78,11 @@ public class ArrowReader {
private func loadVariableData(_ loadInfo: DataLoadInfo) -> Result<ArrowArrayHolder, ArrowError> {
let node = loadInfo.recordBatch.nodes(at: loadInfo.nodeIndex)!
do {
let nullLength = UInt(ceil(Double(node.length) / 8))
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex)
let nullBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex)!
let arrowNullBuffer = makeBuffer(nullBuffer, fileData: loadInfo.fileData,
length: UInt(node.nullCount), messageOffset: loadInfo.messageOffset)
length: nullLength, messageOffset: loadInfo.messageOffset)
try validateBufferIndex(loadInfo.recordBatch, index: loadInfo.bufferIndex + 1)
let offsetBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 1)!
let arrowOffsetBuffer = makeBuffer(offsetBuffer, fileData: loadInfo.fileData,
Expand All @@ -88,7 +91,8 @@ public class ArrowReader {
let valueBuffer = loadInfo.recordBatch.buffers(at: loadInfo.bufferIndex + 2)!
let arrowValueBuffer = makeBuffer(valueBuffer, fileData: loadInfo.fileData,
length: UInt(node.length), messageOffset: loadInfo.messageOffset)
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer])
return makeArrayHolder(loadInfo.field, buffers: [arrowNullBuffer, arrowOffsetBuffer, arrowValueBuffer],
nullCount: UInt(node.nullCount))
} catch let error as ArrowError {
return .failure(error)
} catch {
Expand Down
82 changes: 49 additions & 33 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
import FlatBuffers
import Foundation

private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
private func makeBinaryHolder(_ buffers: [ArrowBuffer],
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBinary), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Int8>.stride)
nullCount: nullCount, stride: MemoryLayout<Int8>.stride)
return .success(ArrowArrayHolder(BinaryArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -30,10 +31,11 @@ private func makeBinaryHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHold
}
}

private func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
private func makeStringHolder(_ buffers: [ArrowBuffer],
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Int8>.stride)
nullCount: nullCount, stride: MemoryLayout<Int8>.stride)
return .success(ArrowArrayHolder(StringArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -43,30 +45,32 @@ private func makeStringHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHold
}

private func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint,
buffers: [ArrowBuffer]
buffers: [ArrowBuffer],
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
switch floatType.precision {
case .single:
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat)
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat, nullCount: nullCount)
case .double:
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble)
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble, nullCount: nullCount)
default:
return .failure(.unknownType("Float precision \(floatType.precision) currently not supported"))
}
}

private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date,
buffers: [ArrowBuffer]
buffers: [ArrowBuffer],
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
if dateType.unit == .day {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Date>.stride)
nullCount: nullCount, stride: MemoryLayout<Date>.stride)
return .success(ArrowArrayHolder(Date32Array(arrowData)))
}

let arrowData = try ArrowData(ArrowType(ArrowType.ArrowString), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Date>.stride)
nullCount: nullCount, stride: MemoryLayout<Date>.stride)
return .success(ArrowArrayHolder(Date64Array(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -76,19 +80,20 @@ private func makeDateHolder(_ dateType: org_apache_arrow_flatbuf_Date,
}

private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,
buffers: [ArrowBuffer]
buffers: [ArrowBuffer],
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
if timeType.unit == .second || timeType.unit == .millisecond {
let arrowUnit: ArrowTime32Unit = timeType.unit == .second ? .seconds : .milliseconds
let arrowData = try ArrowData(ArrowTypeTime32(arrowUnit), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Time32>.stride)
nullCount: nullCount, stride: MemoryLayout<Time32>.stride)
return .success(ArrowArrayHolder(FixedArray<Time32>(arrowData)))
}

let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds
let arrowData = try ArrowData(ArrowTypeTime64(arrowUnit), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<Time64>.stride)
nullCount: nullCount, stride: MemoryLayout<Time64>.stride)
return .success(ArrowArrayHolder(FixedArray<Time64>(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -97,10 +102,11 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,
}
}

private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
private func makeBoolHolder(_ buffers: [ArrowBuffer],
nullCount: UInt) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<UInt8>.stride)
nullCount: nullCount, stride: MemoryLayout<UInt8>.stride)
return .success(ArrowArrayHolder(BoolArray(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -111,11 +117,12 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder

private func makeFixedHolder<T>(
_: T.Type, buffers: [ArrowBuffer],
arrowType: ArrowType.Info
arrowType: ArrowType.Info,
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<T>.stride)
nullCount: nullCount, stride: MemoryLayout<T>.stride)
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
} catch let error as ArrowError {
return .failure(error)
Expand All @@ -124,9 +131,10 @@ private func makeFixedHolder<T>(
}
}

func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity function_body_length
_ field: org_apache_arrow_flatbuf_Field,
buffers: [ArrowBuffer]
buffers: [ArrowBuffer],
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
let type = field.typeType
switch type {
Expand All @@ -135,45 +143,53 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
let bitWidth = intType.bitWidth
if bitWidth == 8 {
if intType.isSigned {
return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8)
return makeFixedHolder(Int8.self, buffers: buffers,
arrowType: ArrowType.ArrowInt8, nullCount: nullCount)
} else {
return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8)
return makeFixedHolder(UInt8.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt8, nullCount: nullCount)
}
} else if bitWidth == 16 {
if intType.isSigned {
return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16)
return makeFixedHolder(Int16.self, buffers: buffers,
arrowType: ArrowType.ArrowInt16, nullCount: nullCount)
} else {
return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16)
return makeFixedHolder(UInt16.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt16, nullCount: nullCount)
}
} else if bitWidth == 32 {
if intType.isSigned {
return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32)
return makeFixedHolder(Int32.self, buffers: buffers,
arrowType: ArrowType.ArrowInt32, nullCount: nullCount)
} else {
return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32)
return makeFixedHolder(UInt32.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt32, nullCount: nullCount)
}
} else if bitWidth == 64 {
if intType.isSigned {
return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64)
return makeFixedHolder(Int64.self, buffers: buffers,
arrowType: ArrowType.ArrowInt64, nullCount: nullCount)
} else {
return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64)
return makeFixedHolder(UInt64.self, buffers: buffers,
arrowType: ArrowType.ArrowUInt64, nullCount: nullCount)
}
}
return .failure(.unknownType("Int width \(bitWidth) currently not supported"))
case .bool:
return makeBoolHolder(buffers)
return makeBoolHolder(buffers, nullCount: nullCount)
case .floatingpoint:
let floatType = field.type(type: org_apache_arrow_flatbuf_FloatingPoint.self)!
return makeFloatHolder(floatType, buffers: buffers)
return makeFloatHolder(floatType, buffers: buffers, nullCount: nullCount)
case .utf8:
return makeStringHolder(buffers)
return makeStringHolder(buffers, nullCount: nullCount)
case .binary:
return makeBinaryHolder(buffers)
return makeBinaryHolder(buffers, nullCount: nullCount)
case .date:
let dateType = field.type(type: org_apache_arrow_flatbuf_Date.self)!
return makeDateHolder(dateType, buffers: buffers)
return makeDateHolder(dateType, buffers: buffers, nullCount: nullCount)
case .time:
let timeType = field.type(type: org_apache_arrow_flatbuf_Time.self)!
return makeTimeHolder(timeType, buffers: buffers)
return makeTimeHolder(timeType, buffers: buffers, nullCount: nullCount)
default:
return .failure(.unknownType("Type \(type) currently not supported"))
}
Expand Down
40 changes: 33 additions & 7 deletions swift/Arrow/Tests/ArrowTests/IPCTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,16 @@ func makeSchema() -> ArrowSchema {
return schemaBuilder.addField("col1", type: ArrowType(ArrowType.ArrowUInt8), isNullable: true)
.addField("col2", type: ArrowType(ArrowType.ArrowString), isNullable: false)
.addField("col3", type: ArrowType(ArrowType.ArrowDate32), isNullable: false)
.addField("col4", type: ArrowType(ArrowType.ArrowInt32), isNullable: false)
.addField("col5", type: ArrowType(ArrowType.ArrowFloat), isNullable: false)
.finish()
}

func makeRecordBatch() throws -> RecordBatch {
let uint8Builder: NumberArrayBuilder<UInt8> = try ArrowArrayBuilders.loadNumberArrayBuilder()
uint8Builder.append(10)
uint8Builder.append(22)
uint8Builder.append(33)
uint8Builder.append(nil)
uint8Builder.append(nil)
uint8Builder.append(44)
let stringBuilder = try ArrowArrayBuilders.loadStringArrayBuilder()
stringBuilder.append("test10")
Expand All @@ -85,13 +87,28 @@ func makeRecordBatch() throws -> RecordBatch {
date32Builder.append(date2)
date32Builder.append(date1)
date32Builder.append(date2)
let intHolder = ArrowArrayHolder(try uint8Builder.finish())
let int32Builder: NumberArrayBuilder<Int32> = try ArrowArrayBuilders.loadNumberArrayBuilder()
int32Builder.append(1)
int32Builder.append(2)
int32Builder.append(3)
int32Builder.append(4)
let floatBuilder: NumberArrayBuilder<Float> = try ArrowArrayBuilders.loadNumberArrayBuilder()
floatBuilder.append(211.112)
floatBuilder.append(322.223)
floatBuilder.append(433.334)
floatBuilder.append(544.445)

let uint8Holder = ArrowArrayHolder(try uint8Builder.finish())
let stringHolder = ArrowArrayHolder(try stringBuilder.finish())
let date32Holder = ArrowArrayHolder(try date32Builder.finish())
let int32Holder = ArrowArrayHolder(try int32Builder.finish())
let floatHolder = ArrowArrayHolder(try floatBuilder.finish())
let result = RecordBatch.Builder()
.addColumn("col1", arrowArray: intHolder)
.addColumn("col1", arrowArray: uint8Holder)
.addColumn("col2", arrowArray: stringHolder)
.addColumn("col3", arrowArray: date32Holder)
.addColumn("col4", arrowArray: int32Holder)
.addColumn("col5", arrowArray: floatHolder)
.finish()
switch result {
case .success(let recordBatch):
Expand Down Expand Up @@ -182,15 +199,20 @@ final class IPCFileReaderTests: XCTestCase {
XCTAssertEqual(recordBatches.count, 1)
for recordBatch in recordBatches {
XCTAssertEqual(recordBatch.length, 4)
XCTAssertEqual(recordBatch.columns.count, 3)
XCTAssertEqual(recordBatch.schema.fields.count, 3)
XCTAssertEqual(recordBatch.columns.count, 5)
XCTAssertEqual(recordBatch.schema.fields.count, 5)
XCTAssertEqual(recordBatch.schema.fields[0].name, "col1")
XCTAssertEqual(recordBatch.schema.fields[0].type.info, ArrowType.ArrowUInt8)
XCTAssertEqual(recordBatch.schema.fields[1].name, "col2")
XCTAssertEqual(recordBatch.schema.fields[1].type.info, ArrowType.ArrowString)
XCTAssertEqual(recordBatch.schema.fields[2].name, "col3")
XCTAssertEqual(recordBatch.schema.fields[2].type.info, ArrowType.ArrowDate32)
XCTAssertEqual(recordBatch.schema.fields[3].name, "col4")
XCTAssertEqual(recordBatch.schema.fields[3].type.info, ArrowType.ArrowInt32)
XCTAssertEqual(recordBatch.schema.fields[4].name, "col5")
XCTAssertEqual(recordBatch.schema.fields[4].type.info, ArrowType.ArrowFloat)
let columns = recordBatch.columns
XCTAssertEqual(columns[0].nullCount, 2)
let dateVal =
"\((columns[2].array as! AsString).asString(0))" // swiftlint:disable:this force_cast
XCTAssertEqual(dateVal, "2014-09-10 00:00:00 +0000")
Expand Down Expand Up @@ -227,13 +249,17 @@ final class IPCFileReaderTests: XCTestCase {
case .success(let result):
XCTAssertNotNil(result.schema)
let schema = result.schema!
XCTAssertEqual(schema.fields.count, 3)
XCTAssertEqual(schema.fields.count, 5)
XCTAssertEqual(schema.fields[0].name, "col1")
XCTAssertEqual(schema.fields[0].type.info, ArrowType.ArrowUInt8)
XCTAssertEqual(schema.fields[1].name, "col2")
XCTAssertEqual(schema.fields[1].type.info, ArrowType.ArrowString)
XCTAssertEqual(schema.fields[2].name, "col3")
XCTAssertEqual(schema.fields[2].type.info, ArrowType.ArrowDate32)
XCTAssertEqual(schema.fields[3].name, "col4")
XCTAssertEqual(schema.fields[3].type.info, ArrowType.ArrowInt32)
XCTAssertEqual(schema.fields[4].name, "col5")
XCTAssertEqual(schema.fields[4].type.info, ArrowType.ArrowFloat)
case.failure(let error):
throw error
}
Expand Down
Loading

0 comments on commit 4bf43c6

Please sign in to comment.