Skip to content

Commit

Permalink
apacheGH-37726: [Swift] Update flight behavior to be similar to exist…
Browse files Browse the repository at this point in the history
…ing impls
  • Loading branch information
abandy committed Oct 25, 2023
1 parent 25e5d19 commit eedd214
Show file tree
Hide file tree
Showing 14 changed files with 5,530 additions and 235 deletions.
54 changes: 44 additions & 10 deletions swift/Arrow/Sources/Arrow/ArrowReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class ArrowReader {
}

/// Accumulates what an `ArrowReader` has decoded so far: the schema and the record batches.
public class ArrowReaderResult {
    /// Raw FlatBuffers schema stored by `fromMessage` when a schema message is handled;
    /// record-batch messages handled afterwards are decoded against it.
    fileprivate var messageSchema: org_apache_arrow_flatbuf_Schema?
    /// Decoded Arrow schema; `nil` until a schema has been loaded.
    public var schema: ArrowSchema?
    /// Record batches, appended in the order they were decoded.
    public var batches: [RecordBatch] = []
}
Expand Down Expand Up @@ -95,19 +96,14 @@ public class ArrowReader {
}
}

private func loadRecordBatch(_ message: org_apache_arrow_flatbuf_Message,
schema: org_apache_arrow_flatbuf_Schema,
arrowSchema: ArrowSchema,
data: Data,
messageEndOffset: Int64
) -> Result<RecordBatch, ArrowError> {
let recordBatch = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)
let nodesCount = recordBatch?.nodesCount ?? 0
private func loadRecordBatch(_ recordBatch: org_apache_arrow_flatbuf_RecordBatch, schema: org_apache_arrow_flatbuf_Schema,
arrowSchema: ArrowSchema, data: Data, messageEndOffset: Int64) -> Result<RecordBatch, ArrowError> {
let nodesCount = recordBatch.nodesCount
var bufferIndex: Int32 = 0
var columns: [ArrowArrayHolder] = []
for nodeIndex in 0 ..< nodesCount {
let field = schema.fields(at: nodeIndex)!
let loadInfo = DataLoadInfo(recordBatch: recordBatch!, field: field,
let loadInfo = DataLoadInfo(recordBatch: recordBatch, field: field,
nodeIndex: nodeIndex, bufferIndex: bufferIndex,
fileData: data, messageOffset: messageEndOffset)
var result: Result<ArrowArrayHolder, ArrowError>
Expand Down Expand Up @@ -172,7 +168,8 @@ public class ArrowReader {
switch message.headerType {
case .recordbatch:
do {
let recordBatch = try loadRecordBatch(message, schema: footer.schema!, arrowSchema: result.schema!,
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
let recordBatch = try loadRecordBatch(rbMessage, schema: footer.schema!, arrowSchema: result.schema!,
data: fileData, messageEndOffset: messageEndOffset).get()
result.batches.append(recordBatch)
} catch let error as ArrowError {
Expand Down Expand Up @@ -203,4 +200,41 @@ public class ArrowReader {
return .failure(.unknownError("Error loading file: \(error)"))
}
}

/// Creates an empty `ArrowReaderResult`, suitable for passing to `fromMessage`.
///
/// - Returns: A result object with no schema and no batches.
public static func MakeArrowReaderResult() -> ArrowReaderResult {
    let emptyResult = ArrowReaderResult()
    return emptyResult
}

/// Decodes one IPC message (FlatBuffers header plus body bytes) into `result`.
///
/// A schema message stores both the decoded `ArrowSchema` and the raw FlatBuffers schema
/// on `result`. A record-batch message is decoded against that stored schema and appended
/// to `result.batches`. Any other header type is reported as an error.
///
/// - Parameters:
///   - dataHeader: Serialized FlatBuffers `Message` describing the payload.
///   - dataBody: Body bytes accompanying the message; used as the buffer data when
///     loading a record batch.
///   - result: Accumulator that is updated in place.
/// - Returns: `.success` when the message was consumed, otherwise `.failure` with the
///   reason (missing header, record batch before schema, load error, unhandled type).
public func fromMessage(_ dataHeader: Data, dataBody: Data, result: ArrowReaderResult) -> Result<Void, ArrowError> {
    let mbb = ByteBuffer(data: dataHeader)
    let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: mbb)
    switch message.headerType {
    case .schema:
        // The header accessor is optional; report a malformed message instead of trapping.
        guard let sMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self) else {
            return .failure(.unknownError("Schema message is missing its header"))
        }
        switch loadSchema(sMessage) {
        case .success(let schema):
            result.schema = schema
            // Keep the raw schema as well: loadRecordBatch needs it for later batches.
            result.messageSchema = sMessage
            return .success(())
        case .failure(let error):
            return .failure(error)
        }
    case .recordbatch:
        guard let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self) else {
            return .failure(.unknownError("Record batch message is missing its header"))
        }
        // A record batch can only be decoded once a schema message has been processed.
        guard let messageSchema = result.messageSchema, let arrowSchema = result.schema else {
            return .failure(.unknownError("Record batch message received before a schema message"))
        }
        // The body is passed on its own, so buffer offsets are relative to its start.
        switch loadRecordBatch(rbMessage, schema: messageSchema, arrowSchema: arrowSchema,
                               data: dataBody, messageEndOffset: 0) {
        case .success(let recordBatch):
            result.batches.append(recordBatch)
            return .success(())
        case .failure(let error):
            return .failure(error)
        }

    default:
        return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
    }
}

}
26 changes: 13 additions & 13 deletions swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ private func makeFloatHolder(_ floatType: org_apache_arrow_flatbuf_FloatingPoint
) -> Result<ArrowArrayHolder, ArrowError> {
switch floatType.precision {
case .single:
return makeFixedHolder(Float.self, buffers: buffers)
return makeFixedHolder(Float.self, buffers: buffers, arrowType: ArrowType.ArrowFloat)
case .double:
return makeFixedHolder(Double.self, buffers: buffers)
return makeFixedHolder(Double.self, buffers: buffers, arrowType: ArrowType.ArrowDouble)
default:
return .failure(.unknownType("Float precision \(floatType.precision) currently not supported"))
}
Expand Down Expand Up @@ -99,7 +99,7 @@ private func makeTimeHolder(_ timeType: org_apache_arrow_flatbuf_Time,

private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowBool), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<UInt8>.stride)
return .success(ArrowArrayHolder(BoolArray(arrowData)))
} catch let error as ArrowError {
Expand All @@ -109,9 +109,9 @@ private func makeBoolHolder(_ buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder
}
}

private func makeFixedHolder<T>(_: T.Type, buffers: [ArrowBuffer]) -> Result<ArrowArrayHolder, ArrowError> {
fileprivate func makeFixedHolder<T>(_: T.Type, buffers: [ArrowBuffer], arrowType: ArrowType.Info) -> Result<ArrowArrayHolder, ArrowError> {
do {
let arrowData = try ArrowData(ArrowType(ArrowType.ArrowInt32), buffers: buffers,
let arrowData = try ArrowData(ArrowType(arrowType), buffers: buffers,
nullCount: buffers[0].length, stride: MemoryLayout<T>.stride)
return .success(ArrowArrayHolder(FixedArray<T>(arrowData)))
} catch let error as ArrowError {
Expand All @@ -132,27 +132,27 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
let bitWidth = intType.bitWidth
if bitWidth == 8 {
if intType.isSigned {
return makeFixedHolder(Int8.self, buffers: buffers)
return makeFixedHolder(Int8.self, buffers: buffers, arrowType: ArrowType.ArrowInt8)
} else {
return makeFixedHolder(UInt8.self, buffers: buffers)
return makeFixedHolder(UInt8.self, buffers: buffers, arrowType: ArrowType.ArrowUInt8)
}
} else if bitWidth == 16 {
if intType.isSigned {
return makeFixedHolder(Int16.self, buffers: buffers)
return makeFixedHolder(Int16.self, buffers: buffers, arrowType: ArrowType.ArrowInt16)
} else {
return makeFixedHolder(UInt16.self, buffers: buffers)
return makeFixedHolder(UInt16.self, buffers: buffers, arrowType: ArrowType.ArrowUInt16)
}
} else if bitWidth == 32 {
if intType.isSigned {
return makeFixedHolder(Int32.self, buffers: buffers)
return makeFixedHolder(Int32.self, buffers: buffers, arrowType: ArrowType.ArrowInt32)
} else {
return makeFixedHolder(UInt32.self, buffers: buffers)
return makeFixedHolder(UInt32.self, buffers: buffers, arrowType: ArrowType.ArrowUInt32)
}
} else if bitWidth == 64 {
if intType.isSigned {
return makeFixedHolder(Int64.self, buffers: buffers)
return makeFixedHolder(Int64.self, buffers: buffers, arrowType: ArrowType.ArrowInt64)
} else {
return makeFixedHolder(UInt64.self, buffers: buffers)
return makeFixedHolder(UInt64.self, buffers: buffers, arrowType: ArrowType.ArrowUInt64)
}
}
return .failure(.unknownType("Int width \(bitWidth) currently not supported"))
Expand Down
11 changes: 11 additions & 0 deletions swift/Arrow/Sources/Arrow/ArrowType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ public class ArrowType {
public init(_ info: ArrowType.Info) {
self.info = info
}

/// The `ArrowTypeId` carried by `info`, regardless of which kind of info it is.
public var id: ArrowTypeId {
    // Every case wraps an ArrowTypeId, so one multi-pattern arm covers them all.
    switch info {
    case .primitiveInfo(let typeId), .timeInfo(let typeId), .variableInfo(let typeId):
        return typeId
    }
}

public enum Info {
case primitiveInfo(ArrowTypeId)
Expand Down
95 changes: 67 additions & 28 deletions swift/Arrow/Sources/Arrow/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ public class ArrowWriter {
self.data.append(data)
}
}

public class FileDataWriter: DataWriter {
public class FileDataWriter : DataWriter {
private var handle: FileHandle
private var currentSize: Int = 0
public var count: Int { return currentSize }
Expand All @@ -52,7 +52,7 @@ public class ArrowWriter {
self.currentSize += data.count
}
}

public class Info {
public let type: org_apache_arrow_flatbuf_MessageHeader
public let schema: ArrowSchema
Expand All @@ -62,7 +62,7 @@ public class ArrowWriter {
self.schema = schema
self.batches = batches
}

public convenience init(_ type: org_apache_arrow_flatbuf_MessageHeader, schema: ArrowSchema) {
self.init(type, schema: schema, batches: [RecordBatch]())
}
Expand Down Expand Up @@ -91,7 +91,7 @@ public class ArrowWriter {
return .failure(error)
}
}

private func writeSchema(_ fbb: inout FlatBufferBuilder, schema: ArrowSchema) -> Result<Offset, ArrowError> {
var fieldOffsets = [Offset]()
for field in schema.fields {
Expand All @@ -103,19 +103,19 @@ public class ArrowWriter {
}

}

let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
let schemaOffset =
org_apache_arrow_flatbuf_Schema.createSchema(&fbb,
endianness: .little,
fieldsVectorOffset: fieldsOffset)
return .success(schemaOffset)

}

private func writeRecordBatches(_ writer: inout DataWriter,
batches: [RecordBatch]
) -> Result<[org_apache_arrow_flatbuf_Block], ArrowError> {

private func writeRecordBatches(_ writer: inout DataWriter, batches: [RecordBatch]) -> Result<[org_apache_arrow_flatbuf_Block], ArrowError> {
var rbBlocks = [org_apache_arrow_flatbuf_Block]()

for batch in batches {
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
Expand All @@ -135,15 +135,14 @@ public class ArrowWriter {
return .failure(error)
}
}

return .success(rbBlocks)
}

private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
let schema = batch.schema
var output = Data()
var fbb = FlatBufferBuilder()

// write out field nodes
var fieldNodeOffsets = [Offset]()
fbb.startVector(schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
Expand All @@ -154,8 +153,9 @@ public class ArrowWriter {
nullCount: Int64(column.nullCount))
fieldNodeOffsets.append(fbb.create(struct: fieldNode))
}

let nodeOffset = fbb.endVector(len: schema.fields.count)

// write out buffers
var buffers = [org_apache_arrow_flatbuf_Buffer]()
var bufferOffset = Int(0)
Expand All @@ -174,21 +174,22 @@ public class ArrowWriter {
for buffer in buffers.reversed() {
fbb.create(struct: buffer)
}

let batchBuffersOffset = fbb.endVector(len: buffers.count)
let startRb = org_apache_arrow_flatbuf_RecordBatch.startRecordBatch(&fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(nodes: nodeOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(buffers: batchBuffersOffset, &fbb)
let startRb = org_apache_arrow_flatbuf_RecordBatch.startRecordBatch(&fbb);
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(nodes:nodeOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.addVectorOf(buffers:batchBuffersOffset, &fbb)
org_apache_arrow_flatbuf_RecordBatch.add(length: Int64(batch.length), &fbb)
let recordBatchOffset = org_apache_arrow_flatbuf_RecordBatch.endRecordBatch(&fbb, start: startRb)
let bodySize = Int64(bufferOffset)
let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(bodySize), &fbb)
org_apache_arrow_flatbuf_Message.add(headerType: .recordbatch, &fbb)
org_apache_arrow_flatbuf_Message.add(header: recordBatchOffset, &fbb)
let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
fbb.finish(offset: messageOffset)
output.append(fbb.data)
return .success((output, Offset(offset: UInt32(output.count))))
return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
}

private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<Bool, ArrowError> {
Expand All @@ -200,7 +201,7 @@ public class ArrowWriter {
writer.append(bufferData)
}
}

return .success(true)
}

Expand All @@ -224,10 +225,10 @@ public class ArrowWriter {
case .failure(let error):
return .failure(error)
}

return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
switch writeSchema(&fbb, schema: info.schema) {
Expand Down Expand Up @@ -256,10 +257,10 @@ public class ArrowWriter {
case .failure(let error):
return .failure(error)
}

return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
Expand All @@ -273,7 +274,7 @@ public class ArrowWriter {
return .failure(error)
}
}

public func toFile(_ fileName: URL, info: ArrowWriter.Info) -> Result<Bool, ArrowError> {
do {
try Data().write(to: fileName)
Expand All @@ -286,7 +287,7 @@ public class ArrowWriter {

var markerData = FILEMARKER.data(using: .utf8)!
addPadForAlignment(&markerData)

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
Expand All @@ -298,4 +299,42 @@ public class ArrowWriter {

return .success(true)
}

/// Serializes a record batch as two separate pieces of data.
///
/// - Parameter batch: The record batch to serialize.
/// - Returns: On success a two-element array: first the FlatBuffers message produced by
///   `writeRecordBatch` followed by alignment padding, then the data written by
///   `writeRecordBatchData`. On failure, the error from whichever step failed.
public func toMessage(_ batch: RecordBatch) -> Result<[Data], ArrowError> {
    // Build the message header first; bail out before touching the body on failure.
    let headerBytes: Data
    switch writeRecordBatch(batch: batch) {
    case .success(let message):
        headerBytes = message.0
    case .failure(let error):
        return .failure(error)
    }

    var headerWriter: any DataWriter = InMemDataWriter()
    headerWriter.append(headerBytes)
    addPadForAlignment(&headerWriter)

    // The body goes into its own in-memory writer.
    var bodyWriter: any DataWriter = InMemDataWriter()
    if case .failure(let error) = writeRecordBatchData(&bodyWriter, batch: batch) {
        return .failure(error)
    }

    // Both writers were created above as InMemDataWriter, so the casts cannot fail.
    let headerData = (headerWriter as! InMemDataWriter).data
    let bodyData = (bodyWriter as! InMemDataWriter).data
    return .success([headerData, bodyData])
}

/// Serializes a schema as a standalone FlatBuffers `Message` with header type `.schema`.
///
/// - Parameter schema: The Arrow schema to encode.
/// - Returns: The finished FlatBuffers bytes on success, or the error reported by
///   `writeSchema` on failure.
public func toMessage(_ schema: ArrowSchema) -> Result<Data, ArrowError> {
    var fbb = FlatBufferBuilder()
    // Keep the offset exactly as writeSchema returned it; converting it through Int32
    // and back is unnecessary and would trap for offsets above Int32.max.
    let schemaOffset: Offset
    switch writeSchema(&fbb, schema: schema) {
    case .success(let offset):
        schemaOffset = offset
    case .failure(let error):
        return .failure(error)
    }

    let startMessage = org_apache_arrow_flatbuf_Message.startMessage(&fbb)
    // This message is written with a body length of zero.
    org_apache_arrow_flatbuf_Message.add(bodyLength: Int64(0), &fbb)
    org_apache_arrow_flatbuf_Message.add(headerType: .schema, &fbb)
    org_apache_arrow_flatbuf_Message.add(header: schemaOffset, &fbb)
    org_apache_arrow_flatbuf_Message.add(version: .max, &fbb)
    let messageOffset = org_apache_arrow_flatbuf_Message.endMessage(&fbb, start: startMessage)
    fbb.finish(offset: messageOffset)
    return .success(fbb.data)
}
}
Loading

0 comments on commit eedd214

Please sign in to comment.