diff --git a/csharp/src/Apache.Arrow/Arrays/BinaryViewArray.cs b/csharp/src/Apache.Arrow/Arrays/BinaryViewArray.cs index 4f62dffd1ddeb..b7c9b07336a5a 100644 --- a/csharp/src/Apache.Arrow/Arrays/BinaryViewArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/BinaryViewArray.cs @@ -322,7 +322,7 @@ public ReadOnlySpan GetBytes(int index, out bool isNull) BinaryView binaryView = Views[index]; if (binaryView.IsInline) { - return ViewsBuffer.Span.Slice(16 * index + 4, binaryView.Length); + return ViewsBuffer.Span.Slice(16 * (Offset + index) + 4, binaryView.Length); } return DataBuffer(binaryView._bufferIndex).Span.Slice(binaryView._bufferOffset, binaryView.Length); diff --git a/csharp/src/Apache.Arrow/Arrays/FixedSizeBinaryArray.cs b/csharp/src/Apache.Arrow/Arrays/FixedSizeBinaryArray.cs index 0fa7954724f38..9d597ef1624ea 100644 --- a/csharp/src/Apache.Arrow/Arrays/FixedSizeBinaryArray.cs +++ b/csharp/src/Apache.Arrow/Arrays/FixedSizeBinaryArray.cs @@ -68,7 +68,7 @@ public ReadOnlySpan GetBytes(int index) } int size = ((FixedSizeBinaryType)Data.DataType).ByteWidth; - return ValueBuffer.Span.Slice(index * size, size); + return ValueBuffer.Span.Slice((Offset + index) * size, size); } int IReadOnlyCollection.Count => Length; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs index 4e273dbde5690..a37c501072f4b 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs @@ -261,7 +261,7 @@ private ArrayData LoadField( if (fieldNullCount < 0) { - throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message + throw new InvalidDataException("Null count must be >= 0"); // TODO:Localize exception message } int buffers; diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs index 7b319b03d790c..6127c5a662dfe 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Arrays; @@ -69,23 +70,37 @@ private class ArrowRecordBatchFlatBufferBuilder : IArrowArrayVisitor, IArrowArrayVisitor { + public readonly struct FieldNode + { + public readonly int Length; + public readonly int NullCount; + + public FieldNode(int length, int nullCount) + { + Length = length; + NullCount = nullCount; + } + } + public readonly struct Buffer { - public readonly ArrowBuffer DataBuffer; + public readonly ReadOnlyMemory DataBuffer; public readonly int Offset; - public Buffer(ArrowBuffer buffer, int offset) + public Buffer(ReadOnlyMemory buffer, int offset) { DataBuffer = buffer; Offset = offset; } } + private readonly List _fieldNodes; private readonly List _buffers; private readonly ICompressionCodec _compressionCodec; private readonly MemoryAllocator _allocator; private readonly MemoryStream _compressionStream; + public IReadOnlyList FieldNodes => _fieldNodes; public IReadOnlyList Buffers => _buffers; public List VariadicCounts { get; private set; } @@ -97,56 +112,80 @@ public ArrowRecordBatchFlatBufferBuilder( _compressionCodec = compressionCodec; _compressionStream = compressionStream; _allocator = allocator; + _fieldNodes = new List(); _buffers = new List(); TotalLength = 0; } - public void Visit(Int8Array array) => CreateBuffers(array); - public void Visit(Int16Array array) => CreateBuffers(array); - public void Visit(Int32Array array) => CreateBuffers(array); - public void Visit(Int64Array array) => CreateBuffers(array); - public void Visit(UInt8Array array) => CreateBuffers(array); - public void Visit(UInt16Array array) => CreateBuffers(array); - public void Visit(UInt32Array array) => CreateBuffers(array); - public void Visit(UInt64Array array) => CreateBuffers(array); + public void VisitArray(IArrowArray array) + { + _fieldNodes.Add(new FieldNode(array.Length, array.NullCount)); + + array.Accept(this); + } + + public void Visit(Int8Array array) => VisitPrimitiveArray(array); + public void Visit(Int16Array array) => VisitPrimitiveArray(array); + public void Visit(Int32Array array) => VisitPrimitiveArray(array); + public void Visit(Int64Array array) => VisitPrimitiveArray(array); + public void Visit(UInt8Array array) => VisitPrimitiveArray(array); + public void Visit(UInt16Array array) => VisitPrimitiveArray(array); + public void Visit(UInt32Array array) => VisitPrimitiveArray(array); + public void Visit(UInt64Array array) => VisitPrimitiveArray(array); #if NET5_0_OR_GREATER - public void Visit(HalfFloatArray array) => CreateBuffers(array); + public void Visit(HalfFloatArray array) => VisitPrimitiveArray(array); #endif - public void Visit(FloatArray array) => CreateBuffers(array); - public void Visit(DoubleArray array) => CreateBuffers(array); - public void Visit(TimestampArray array) => CreateBuffers(array); - public void Visit(BooleanArray array) => CreateBuffers(array); - public void Visit(Date32Array array) => CreateBuffers(array); - public void Visit(Date64Array array) => CreateBuffers(array); - public void Visit(Time32Array array) => CreateBuffers(array); - public void Visit(Time64Array array) => CreateBuffers(array); - public void Visit(DurationArray array) => CreateBuffers(array); - public void Visit(YearMonthIntervalArray array) => CreateBuffers(array); - public void Visit(DayTimeIntervalArray array) => CreateBuffers(array); - public void Visit(MonthDayNanosecondIntervalArray array) => CreateBuffers(array); + public void Visit(FloatArray array) => VisitPrimitiveArray(array); + public void Visit(DoubleArray array) => VisitPrimitiveArray(array); + public void Visit(TimestampArray array) => VisitPrimitiveArray(array); + public void Visit(Date32Array array) => VisitPrimitiveArray(array); + public void Visit(Date64Array array) => VisitPrimitiveArray(array); + public void Visit(Time32Array array) => VisitPrimitiveArray(array); + public void Visit(Time64Array array) => VisitPrimitiveArray(array); + public void Visit(DurationArray array) => VisitPrimitiveArray(array); + public void Visit(YearMonthIntervalArray array) => VisitPrimitiveArray(array); + public void Visit(DayTimeIntervalArray array) => VisitPrimitiveArray(array); + public void Visit(MonthDayNanosecondIntervalArray array) => VisitPrimitiveArray(array); + + private void VisitPrimitiveArray(PrimitiveArray array) + where T : struct + { + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ValueBuffer, array.Offset, array.Length)); + } + + public void Visit(BooleanArray array) + { + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateBitmapBuffer(array.ValueBuffer, array.Offset, array.Length)); + } public void Visit(ListArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueOffsetsBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ValueOffsetsBuffer, array.Offset, array.Length + 1)); - array.Values.Accept(this); + VisitArray(array.Values); } public void Visit(ListViewArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueOffsetsBuffer)); - _buffers.Add(CreateBuffer(array.SizesBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ValueOffsetsBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.SizesBuffer, array.Offset, array.Length)); - array.Values.Accept(this); + VisitArray(array.Values); } public void Visit(FixedSizeListArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); - array.Values.Accept(this); + var listSize = ((FixedSizeListType)array.Data.DataType).ListSize; + var valuesSlice = + ArrowArrayFactory.Slice(array.Values, array.Offset * listSize, array.Length * listSize); + + VisitArray(valuesSlice); } public void Visit(StringArray array) => Visit(array as BinaryArray); @@ -155,15 +194,15 @@ public void Visit(FixedSizeListArray array) public void Visit(BinaryArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueOffsetsBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ValueOffsetsBuffer, array.Offset, array.Length + 1)); _buffers.Add(CreateBuffer(array.ValueBuffer)); } public void Visit(BinaryViewArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ViewsBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ViewsBuffer, array.Offset, array.Length)); for (int i = 0; i < array.DataBufferCount; i++) { _buffers.Add(CreateBuffer(array.DataBuffer(i))); @@ -174,45 +213,40 @@ public void Visit(BinaryViewArray array) public void Visit(FixedSizeBinaryArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueBuffer)); + var itemSize = ((FixedSizeBinaryType)array.Data.DataType).ByteWidth; + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); + _buffers.Add(CreateSlicedBuffer(array.ValueBuffer, itemSize, array.Offset, array.Length)); } - public void Visit(Decimal128Array array) - { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueBuffer)); - } + public void Visit(Decimal128Array array) => Visit(array as FixedSizeBinaryArray); - public void Visit(Decimal256Array array) - { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueBuffer)); - } + public void Visit(Decimal256Array array) => Visit(array as FixedSizeBinaryArray); public void Visit(StructArray array) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); + _buffers.Add(CreateBitmapBuffer(array.NullBitmapBuffer, array.Offset, array.Length)); for (int i = 0; i < array.Fields.Count; i++) { - array.Fields[i].Accept(this); + // Fields property accessor handles slicing field arrays if required + VisitArray(array.Fields[i]); } } public void Visit(UnionArray array) { - _buffers.Add(CreateBuffer(array.TypeBuffer)); + _buffers.Add(CreateSlicedBuffer(array.TypeBuffer, array.Offset, array.Length)); ArrowBuffer? offsets = (array as DenseUnionArray)?.ValueOffsetBuffer; if (offsets != null) { - _buffers.Add(CreateBuffer(offsets.Value)); + _buffers.Add(CreateSlicedBuffer(offsets.Value, array.Offset, array.Length)); } for (int i = 0; i < array.Fields.Count; i++) { - array.Fields[i].Accept(this); + // Fields property accessor handles slicing field arrays for sparse union arrays if required + VisitArray(array.Fields[i]); } } @@ -221,8 +255,7 @@ public void Visit(DictionaryArray array) // Dictionary is serialized separately in Dictionary serialization. // We are only interested in indices at this context. - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.IndicesBuffer)); + array.Indices.Accept(this); } public void Visit(NullArray array) @@ -230,25 +263,67 @@ public void Visit(NullArray array) // There are no buffers for a NullArray } - private void CreateBuffers(BooleanArray array) + private Buffer CreateBitmapBuffer(ArrowBuffer buffer, int offset, int length) { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueBuffer)); + if (buffer.IsEmpty) + { + return CreateBuffer(buffer.Memory); + } + + var paddedLength = CalculatePaddedBufferLength(BitUtility.ByteCount(length)); + if (offset % 8 == 0) + { + var byteOffset = offset / 8; + var sliceLength = Math.Min(paddedLength, buffer.Length - byteOffset); + + return CreateBuffer(buffer.Memory.Slice(byteOffset, sliceLength)); + } + else + { + // Need to copy bitmap so the first bit is aligned with the first byte + var memoryOwner = _allocator.Allocate(paddedLength); + var outputSpan = memoryOwner.Memory.Span; + var inputSpan = buffer.Span; + for (var i = 0; i < length; ++i) + { + BitUtility.SetBit(outputSpan, i, BitUtility.GetBit(inputSpan, offset + i)); + } + + return CreateBuffer(memoryOwner.Memory); + } } - private void CreateBuffers(PrimitiveArray array) + private Buffer CreateSlicedBuffer(ArrowBuffer buffer, int offset, int length) where T : struct { - _buffers.Add(CreateBuffer(array.NullBitmapBuffer)); - _buffers.Add(CreateBuffer(array.ValueBuffer)); + return CreateSlicedBuffer(buffer, Unsafe.SizeOf(), offset, length); + } + + private Buffer CreateSlicedBuffer(ArrowBuffer buffer, int itemSize, int offset, int length) + { + var byteLength = length * itemSize; + var paddedLength = CalculatePaddedBufferLength(byteLength); + if (offset != 0 || paddedLength < buffer.Length) + { + var byteOffset = offset * itemSize; + var sliceLength = Math.Min(paddedLength, buffer.Length - byteOffset); + return CreateBuffer(buffer.Memory.Slice(byteOffset, sliceLength)); + } + + return CreateBuffer(buffer.Memory); } private Buffer CreateBuffer(ArrowBuffer buffer) + { + return CreateBuffer(buffer.Memory); + } + + private Buffer CreateBuffer(ReadOnlyMemory buffer) { int offset = TotalLength; const int UncompressedLengthSize = 8; - ArrowBuffer bufferToWrite; + ReadOnlyMemory bufferToWrite; if (_compressionCodec == null) { bufferToWrite = buffer; @@ -258,7 +333,7 @@ private Buffer CreateBuffer(ArrowBuffer buffer) // Write zero length and skip compression var uncompressedLengthBytes = _allocator.Allocate(UncompressedLengthSize); BinaryPrimitives.WriteInt64LittleEndian(uncompressedLengthBytes.Memory.Span, 0); - bufferToWrite = new ArrowBuffer(uncompressedLengthBytes); + bufferToWrite = uncompressedLengthBytes.Memory; } else { @@ -266,14 +341,14 @@ private Buffer CreateBuffer(ArrowBuffer buffer) // compressed buffers are stored. _compressionStream.Seek(0, SeekOrigin.Begin); _compressionStream.SetLength(0); - _compressionCodec.Compress(buffer.Memory, _compressionStream); + _compressionCodec.Compress(buffer, _compressionStream); if (_compressionStream.Length < buffer.Length) { var newBuffer = _allocator.Allocate((int) _compressionStream.Length + UncompressedLengthSize); BinaryPrimitives.WriteInt64LittleEndian(newBuffer.Memory.Span, buffer.Length); _compressionStream.Seek(0, SeekOrigin.Begin); _compressionStream.ReadFullBuffer(newBuffer.Memory.Slice(UncompressedLengthSize)); - bufferToWrite = new ArrowBuffer(newBuffer); + bufferToWrite = newBuffer.Memory; } else { @@ -281,8 +356,8 @@ private Buffer CreateBuffer(ArrowBuffer buffer) // buffer instead, and indicate this by setting the uncompressed length to -1 var newBuffer = _allocator.Allocate(buffer.Length + UncompressedLengthSize); BinaryPrimitives.WriteInt64LittleEndian(newBuffer.Memory.Span, -1); - buffer.Memory.CopyTo(newBuffer.Memory.Slice(UncompressedLengthSize)); - bufferToWrite = new ArrowBuffer(newBuffer); + buffer.CopyTo(newBuffer.Memory.Slice(UncompressedLengthSize)); + bufferToWrite = newBuffer.Memory; } } @@ -366,29 +441,6 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp } } - private void CreateSelfAndChildrenFieldNodes(ArrayData data) - { - if (data.DataType is NestedType) - { - // flatbuffer struct vectors have to be created in reverse order - for (int i = data.Children.Length - 1; i >= 0; i--) - { - CreateSelfAndChildrenFieldNodes(data.Children[i]); - } - } - Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.GetNullCount()); - } - - private static int CountAllNodes(IReadOnlyList fields) - { - int count = 0; - foreach (Field arrowArray in fields) - { - CountSelfAndChildrenNodes(arrowArray.DataType, ref count); - } - return count; - } - private Offset GetBodyCompression() { if (_options.CompressionCodec == null) @@ -406,18 +458,6 @@ private static int CountAllNodes(IReadOnlyList fields) Builder, compressionType, Flatbuf.BodyCompressionMethod.BUFFER); } - private static void CountSelfAndChildrenNodes(IArrowType type, ref int count) - { - if (type is NestedType nestedType) - { - foreach (Field childField in nestedType.Fields) - { - CountSelfAndChildrenNodes(childField.DataType, ref count); - } - } - count++; - } - private protected void WriteRecordBatchInternal(RecordBatch recordBatch) { // TODO: Truncate buffers with extraneous padding / unused capacity @@ -461,8 +501,6 @@ private protected void WriteRecordBatchInternal(RecordBatch recordBatch) private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default) { - // TODO: Truncate buffers with extraneous padding / unused capacity - if (!HasWrittenSchema) { await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false); @@ -506,11 +544,11 @@ private long WriteBufferData(IReadOnlyList buffer = buffers[i].DataBuffer; if (buffer.IsEmpty) continue; - WriteBuffer(buffer); + BaseStream.Write(buffer); int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); int padding = paddedLength - buffer.Length; @@ -537,11 +575,11 @@ private async ValueTask WriteBufferDataAsync(IReadOnlyList buffer = buffers[i].DataBuffer; if (buffer.IsEmpty) continue; - await WriteBufferAsync(buffer, cancellationToken).ConfigureAwait(false); + await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length)); int padding = paddedLength - buffer.Length; @@ -571,22 +609,6 @@ private Tuple Pre { Builder.Clear(); - // Serialize field nodes - - int fieldCount = fields.Count; - - Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes(fields)); - - // flatbuffer struct vectors have to be created in reverse order - for (int i = fieldCount - 1; i >= 0; i--) - { - CreateSelfAndChildrenFieldNodes(arrays[i].Data); - } - - VectorOffset fieldNodesVectorOffset = Builder.EndVector(); - - // Serialize buffers - // CompressionCodec can be disposed after all data is visited by the builder, // and doesn't need to be alive for the full lifetime of the ArrowRecordBatchFlatBufferBuilder using var compressionCodec = _options.CompressionCodec.HasValue @@ -594,20 +616,34 @@ private Tuple Pre : null; var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder(compressionCodec, _allocator, _compressionStream); - for (int i = 0; i < fieldCount; i++) + + // Visit all arrays recursively + for (int i = 0; i < fields.Count; i++) { IArrowArray fieldArray = arrays[i]; - fieldArray.Accept(recordBatchBuilder); + recordBatchBuilder.VisitArray(fieldArray); + } + + // Serialize field nodes + IReadOnlyList fieldNodes = recordBatchBuilder.FieldNodes; + Flatbuf.RecordBatch.StartNodesVector(Builder, fieldNodes.Count); + + // flatbuffer struct vectors have to be created in reverse order + for (int i = fieldNodes.Count - 1; i >= 0; i--) + { + Flatbuf.FieldNode.CreateFieldNode(Builder, fieldNodes[i].Length, fieldNodes[i].NullCount); } + VectorOffset fieldNodesVectorOffset = Builder.EndVector(); + VectorOffset variadicCountOffset = default; if (recordBatchBuilder.VariadicCounts != null) { variadicCountOffset = Flatbuf.RecordBatch.CreateVariadicCountsVectorBlock(Builder, recordBatchBuilder.VariadicCounts.ToArray()); } + // Serialize buffers IReadOnlyList buffers = recordBatchBuilder.Buffers; - Flatbuf.RecordBatch.StartBuffersVector(Builder, buffers.Count); // flatbuffer struct vectors have to be created in reverse order @@ -783,16 +819,6 @@ public async Task WriteEndAsync(CancellationToken cancellationToken = default) } } - private void WriteBuffer(ArrowBuffer arrowBuffer) - { - BaseStream.Write(arrowBuffer.Memory); - } - - private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default) - { - return BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken); - } - private protected Offset SerializeSchema(Schema schema) { // Build metadata @@ -1056,6 +1082,15 @@ protected int CalculatePadding(long offset, int alignment = 8) } } + private static int CalculatePaddedBufferLength(int length) + { + long result = BitUtility.RoundUpToMultiplePowerOfTwo(length, MemoryAllocator.DefaultAlignment); + checked + { + return (int)result; + } + } + private protected void WritePadding(int length) { if (length > 0) diff --git a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs index 25ef289f0dc25..700de58adb8c1 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs @@ -29,6 +29,12 @@ public void TestStandardCases() { foreach ((List testTargetArrayList, IArrowArray expectedArray) in GenerateTestData()) { + if (expectedArray is UnionArray) + { + // Union array concatenation is incorrect. See https://github.com/apache/arrow/issues/41198 + continue; + } + IArrowArray actualArray = ArrowArrayConcatenator.Concatenate(testTargetArrayList); ArrowReaderVerifier.CompareArrays(expectedArray, actualArray); } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs index 69b8410d030f2..faf650973d64c 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs @@ -15,8 +15,11 @@ using Apache.Arrow.Ipc; using System; +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading.Tasks; +using Apache.Arrow.Types; using Xunit; namespace Apache.Arrow.Tests @@ -106,13 +109,38 @@ public async Task WritesFooterAlignedMultipleOf8Async() await ValidateRecordBatchFile(stream, originalBatch); } - private async Task ValidateRecordBatchFile(Stream stream, RecordBatch recordBatch) + [Theory] + [InlineData(0, 45)] + [InlineData(3, 45)] + [InlineData(16, 45)] + public async Task WriteSlicedArrays(int sliceOffset, int sliceLength) + { + var originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var slicedArrays = originalBatch.Arrays + .Select(array => ArrowArrayFactory.Slice(array, sliceOffset, sliceLength)) + .ToList(); + var slicedBatch = new RecordBatch(originalBatch.Schema, slicedArrays, sliceLength); + + var stream = new MemoryStream(); + var writer = new ArrowFileWriter(stream, slicedBatch.Schema, leaveOpen: true); + + await writer.WriteRecordBatchAsync(slicedBatch); + await writer.WriteEndAsync(); + + stream.Position = 0; + + // Disable strict comparison because we don't expect buffers to match exactly + // due to writing slices of buffers, and instead need to compare array values + await ValidateRecordBatchFile(stream, slicedBatch, strictCompare: false); + } + + private async Task ValidateRecordBatchFile(Stream stream, RecordBatch recordBatch, bool strictCompare = true) { var reader = new ArrowFileReader(stream); int count = await reader.RecordBatchCountAsync(); Assert.Equal(1, count); RecordBatch readBatch = await reader.ReadRecordBatchAsync(0); - ArrowReaderVerifier.CompareBatches(recordBatch, readBatch); + ArrowReaderVerifier.CompareBatches(recordBatch, readBatch, strictCompare); } /// diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs index ceeab92860e6f..07c8aa3f56b3b 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs @@ -160,7 +160,7 @@ public void Visit(StructArray array) Assert.Equal(expectedArray.Length, array.Length); Assert.Equal(expectedArray.NullCount, array.NullCount); - Assert.Equal(expectedArray.Offset, array.Offset); + Assert.Equal(0, array.Offset); Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length); Assert.Equal(expectedArray.Fields.Count, array.Fields.Count); @@ -178,10 +178,42 @@ public void Visit(UnionArray array) Assert.Equal(expectedArray.Mode, array.Mode); Assert.Equal(expectedArray.Length, array.Length); Assert.Equal(expectedArray.NullCount, array.NullCount); - Assert.Equal(expectedArray.Offset, array.Offset); + Assert.Equal(0, array.Offset); Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length); Assert.Equal(expectedArray.Fields.Count, array.Fields.Count); + if (_strictCompare) + { + Assert.True(expectedArray.TypeBuffer.Span.SequenceEqual(array.TypeBuffer.Span)); + } + else + { + for (int i = 0; i < expectedArray.Length; i++) + { + Assert.Equal(expectedArray.TypeIds[i], array.TypeIds[i]); + } + } + + if (_expectedArray is DenseUnionArray expectedDenseArray) + { + Assert.IsAssignableFrom(array); + var denseArray = array as DenseUnionArray; + Assert.NotNull(denseArray); + + if (_strictCompare) + { + Assert.True(expectedDenseArray.ValueOffsetBuffer.Span.SequenceEqual(denseArray.ValueOffsetBuffer.Span)); + } + else + { + for (int i = 0; i < expectedDenseArray.Length; i++) + { + Assert.Equal( + expectedDenseArray.ValueOffsets[i], denseArray.ValueOffsets[i]); + } + } + } + for (int i = 0; i < array.Fields.Count; i++) { array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i], _strictCompare)); @@ -220,9 +252,9 @@ private void CompareBinaryArrays(BinaryArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -252,9 +284,9 @@ private void CompareVariadicArrays(BinaryViewArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); Assert.True(expectedArray.Views.SequenceEqual(actualArray.Views)); @@ -277,9 +309,9 @@ private void CompareArrays(FixedSizeBinaryArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -306,9 +338,9 @@ private void CompareArrays(PrimitiveArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -338,9 +370,9 @@ private void CompareArrays(BooleanArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -365,9 +397,9 @@ private void CompareArrays(ListArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -375,8 +407,9 @@ private void CompareArrays(ListArray actualArray) } else { + int offsetsStart = (expectedArray.Offset) * sizeof(int); int offsetsLength = (expectedArray.Length + 1) * sizeof(int); - Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength))); + Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(offsetsStart, offsetsLength).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength))); } actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare)); @@ -391,9 +424,9 @@ private void CompareArrays(ListViewArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); if (_strictCompare) { @@ -402,9 +435,10 @@ private void CompareArrays(ListViewArray actualArray) } else { + int start = expectedArray.Offset * sizeof(int); int length = expectedArray.Length * sizeof(int); - Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(0, length).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, length))); - Assert.True(expectedArray.SizesBuffer.Span.Slice(0, length).SequenceEqual(actualArray.SizesBuffer.Span.Slice(0, length))); + Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(start, length).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, length))); + Assert.True(expectedArray.SizesBuffer.Span.Slice(start, length).SequenceEqual(actualArray.SizesBuffer.Span.Slice(0, length))); } actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare)); @@ -419,23 +453,31 @@ private void CompareArrays(FixedSizeListArray actualArray) Assert.Equal(expectedArray.Length, actualArray.Length); Assert.Equal(expectedArray.NullCount, actualArray.NullCount); - Assert.Equal(expectedArray.Offset, actualArray.Offset); + Assert.Equal(0, actualArray.Offset); - CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer); + CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer); - actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare)); + var listSize = ((FixedSizeListType)expectedArray.Data.DataType).ListSize; + var expectedValuesSlice = ArrowArrayFactory.Slice( + expectedArray.Values, expectedArray.Offset * listSize, expectedArray.Length * listSize); + actualArray.Values.Accept(new ArrayComparer(expectedValuesSlice, _strictCompare)); } - private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer expectedValidityBuffer, ArrowBuffer actualValidityBuffer) + private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer expectedValidityBuffer, int expectedBufferOffset, ArrowBuffer actualValidityBuffer) { if (_strictCompare) { Assert.True(expectedValidityBuffer.Span.SequenceEqual(actualValidityBuffer.Span)); } - else if (nullCount != 0 && arrayLength > 0) + else if (actualValidityBuffer.IsEmpty) + { + Assert.True(nullCount == 0 || arrayLength == 0); + } + else if (expectedBufferOffset % 8 == 0) { int validityBitmapByteCount = BitUtility.ByteCount(arrayLength); - ReadOnlySpan expectedSpanPartial = expectedValidityBuffer.Span.Slice(0, validityBitmapByteCount - 1); + int byteOffset = BitUtility.ByteCount(expectedBufferOffset); + ReadOnlySpan expectedSpanPartial = expectedValidityBuffer.Span.Slice(byteOffset, validityBitmapByteCount - 1); ReadOnlySpan actualSpanPartial = actualValidityBuffer.Span.Slice(0, validityBitmapByteCount - 1); // Compare the first validityBitmapByteCount - 1 bytes @@ -445,7 +487,7 @@ private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer e // Compare the last byte bitwise (because there is no guarantee about the value of // bits outside the range [0, arrayLength]) - ReadOnlySpan expectedSpanFull = expectedValidityBuffer.Span.Slice(0, validityBitmapByteCount); + ReadOnlySpan expectedSpanFull = expectedValidityBuffer.Span.Slice(byteOffset, validityBitmapByteCount); ReadOnlySpan actualSpanFull = actualValidityBuffer.Span.Slice(0, validityBitmapByteCount); for (int i = 8 * (validityBitmapByteCount - 1); i < arrayLength; i++) { @@ -454,6 +496,18 @@ private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer e string.Format("Bit at index {0}/{1} is not equal", i, arrayLength)); } } + else + { + // Have to compare all values bitwise + var expectedSpan = expectedValidityBuffer.Span; + var actualSpan = actualValidityBuffer.Span; + for (int i = 0; i < arrayLength; i++) + { + Assert.True( + BitUtility.GetBit(expectedSpan, expectedBufferOffset + i) == BitUtility.GetBit(actualSpan, i), + string.Format("Bit at index {0}/{1} is not equal", i, arrayLength)); + } + } } } } diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs index c4c0b6ec9ff21..db8369fa618e9 100644 --- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs +++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs @@ -203,7 +203,37 @@ public async Task WriteBatchWithNullsAsync() await TestRoundTripRecordBatchAsync(originalBatch); } - private static void TestRoundTripRecordBatches(List originalBatches, IpcOptions options = null) + [Theory] + [InlineData(0, 45)] + [InlineData(3, 45)] + [InlineData(16, 45)] + public void WriteSlicedArrays(int sliceOffset, int sliceLength) + { + var originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var slicedArrays = originalBatch.Arrays + .Select(array => ArrowArrayFactory.Slice(array, sliceOffset, sliceLength)) + .ToList(); + var slicedBatch = new RecordBatch(originalBatch.Schema, slicedArrays, sliceLength); + + TestRoundTripRecordBatch(slicedBatch, strictCompare: false); + } + + [Theory] + [InlineData(0, 45)] + [InlineData(3, 45)] + [InlineData(16, 45)] + public async Task WriteSlicedArraysAsync(int sliceOffset, int sliceLength) + { + var originalBatch = TestData.CreateSampleRecordBatch(length: 100); + var slicedArrays = originalBatch.Arrays + .Select(array => ArrowArrayFactory.Slice(array, sliceOffset, sliceLength)) + .ToList(); + var slicedBatch = new RecordBatch(originalBatch.Schema, slicedArrays, sliceLength); + + await TestRoundTripRecordBatchAsync(slicedBatch, strictCompare: false); + } + + private static void TestRoundTripRecordBatches(List originalBatches, IpcOptions options = null, bool strictCompare = true) { using (MemoryStream stream = new MemoryStream()) { @@ -223,13 +253,13 @@ private static void TestRoundTripRecordBatches(List originalBatches foreach (RecordBatch originalBatch in originalBatches) { RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch, strictCompare: strictCompare); } } } } - private static async Task TestRoundTripRecordBatchesAsync(List originalBatches, IpcOptions options = null) + private static async Task TestRoundTripRecordBatchesAsync(List originalBatches, IpcOptions options = null, bool strictCompare = true) { using (MemoryStream stream = new MemoryStream()) { @@ -249,20 +279,20 @@ private static async Task TestRoundTripRecordBatchesAsync(List orig foreach (RecordBatch originalBatch in originalBatches) { RecordBatch newBatch = reader.ReadNextRecordBatch(); - ArrowReaderVerifier.CompareBatches(originalBatch, newBatch); + ArrowReaderVerifier.CompareBatches(originalBatch, newBatch, strictCompare: strictCompare); } } } } - private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null) + private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null, bool strictCompare = true) { - TestRoundTripRecordBatches(new List { originalBatch }, options); + TestRoundTripRecordBatches(new List { originalBatch }, options, strictCompare: strictCompare); } - private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null) + private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null, bool strictCompare = true) { - await TestRoundTripRecordBatchesAsync(new List { originalBatch }, options); + await TestRoundTripRecordBatchesAsync(new List { originalBatch }, options, strictCompare: strictCompare); } [Fact] diff --git a/csharp/test/Apache.Arrow.Tests/BinaryViewArrayTests.cs b/csharp/test/Apache.Arrow.Tests/BinaryViewArrayTests.cs new file mode 100644 index 0000000000000..7c18a49e96944 --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/BinaryViewArrayTests.cs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Xunit; + +namespace Apache.Arrow.Tests; + +public class BinaryViewArrayTests +{ + [Fact] + public void SliceBinaryViewArray() + { + var array = new BinaryViewArray.Builder() + .Append(new byte[] { 0, 1, 2 }) + .Append(new byte[] { 3, 4 }) + .AppendNull() + .Append(new byte[] { 5, 6 }) + .Append(new byte[] { 7, 8 }) + .Build(); + + var slice = (BinaryViewArray)array.Slice(1, 3); + + Assert.Equal(3, slice.Length); + Assert.Equal(new byte[] {3, 4}, slice.GetBytes(0).ToArray()); + Assert.True(slice.GetBytes(1).IsEmpty); + Assert.Equal(new byte[] {5, 6}, slice.GetBytes(2).ToArray()); + } +} diff --git a/csharp/test/Apache.Arrow.Tests/FixedSizeBinaryArrayTests.cs b/csharp/test/Apache.Arrow.Tests/FixedSizeBinaryArrayTests.cs new file mode 100644 index 0000000000000..abc66d6ce9c9d --- /dev/null +++ b/csharp/test/Apache.Arrow.Tests/FixedSizeBinaryArrayTests.cs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Linq; +using Apache.Arrow.Arrays; +using Apache.Arrow.Types; +using Xunit; + +namespace Apache.Arrow.Tests; + +public class FixedSizeBinaryArrayTests +{ + [Fact] + public void SliceFixedSizeBinaryArray() + { + const int byteWidth = 2; + const int length = 5; + const int nullCount = 1; + + var validityBuffer = new ArrowBuffer.BitmapBuilder() + .AppendRange(true, 2) + .Append(false) + .AppendRange(true, 2) + .Build(); + var dataBuffer = new ArrowBuffer.Builder() + .AppendRange(Enumerable.Range(0, length * byteWidth).Select(i => (byte)i)) + .Build(); + var arrayData = new ArrayData( + new FixedSizeBinaryType(byteWidth), + length, nullCount, 0, new [] {validityBuffer, dataBuffer}); + var array = new FixedSizeBinaryArray(arrayData); + + var slice = (FixedSizeBinaryArray)array.Slice(1, 3); + + Assert.Equal(3, slice.Length); + Assert.Equal(new byte[] {2, 3}, slice.GetBytes(0).ToArray()); + Assert.True(slice.GetBytes(1).IsEmpty); + Assert.Equal(new byte[] {6, 7}, slice.GetBytes(2).ToArray()); + } +}