Skip to content

Commit

Permalink
GH-40517: [C#] Fix writing sliced arrays to IPC format (#41197)
Browse files Browse the repository at this point in the history
### Rationale for this change

Fixes writing sliced arrays to IPC files or streams, so that they can be successfully read back in. Previously, writing such data would succeed but then couldn't be read.

### What changes are included in this PR?

* Fixes `BinaryViewArray.GetBytes` to account for the array offset
* Fixes `FixedSizeBinaryArray.GetBytes` to account for the array offset
* Updates `ArrowStreamWriter` so that it writes slices of buffers when required, and handles slicing bitmap arrays by creating a copy if the offset isn't a multiple of 8
* Refactors `ArrowStreamWriter`, making the `ArrowRecordBatchFlatBufferBuilder` class responsible for building a list of field nodes as well as buffers. This was required to avoid having to duplicate logic for handling array types with child data between the `ArrowRecordBatchFlatBufferBuilder` class and the `CreateSelfAndChildrenFieldNodes` method, which I've removed.

Note that after this change, we still write more data than required when writing a slice of a `ListArray`, `BinaryArray`, `ListViewArray`, `BinaryViewArray` or `DenseUnionArray`. When writing a `ListArray` for example, we write slices of the null bitmap and value offsets and write the full values array. Ideally we should write a slice of the values and adjust the value offsets so they start at zero. The C++ implementation for example handles this [here](https://github.com/apache/arrow/blob/18c74b0733c9ff473a211259cf10705b2c9be891/cpp/src/arrow/ipc/writer.cc#L316). I will make a follow-up issue for this once this PR is merged.

### Are these changes tested?

Yes, I've added new unit tests for this.

### Are there any user-facing changes?

Yes, this is a user-facing bug fix.
* GitHub Issue: #40517

Authored-by: Adam Reeve <adreeve@gmail.com>
Signed-off-by: Curt Hagenlocher <curt@hagenlocher.org>
  • Loading branch information
adamreeve authored Apr 15, 2024
1 parent fec096a commit a6cdcd0
Show file tree
Hide file tree
Showing 10 changed files with 421 additions and 176 deletions.
2 changes: 1 addition & 1 deletion csharp/src/Apache.Arrow/Arrays/BinaryViewArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ public ReadOnlySpan<byte> GetBytes(int index, out bool isNull)
BinaryView binaryView = Views[index];
if (binaryView.IsInline)
{
return ViewsBuffer.Span.Slice(16 * index + 4, binaryView.Length);
return ViewsBuffer.Span.Slice(16 * (Offset + index) + 4, binaryView.Length);
}

return DataBuffer(binaryView._bufferIndex).Span.Slice(binaryView._bufferOffset, binaryView.Length);
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Apache.Arrow/Arrays/FixedSizeBinaryArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public ReadOnlySpan<byte> GetBytes(int index)
}

int size = ((FixedSizeBinaryType)Data.DataType).ByteWidth;
return ValueBuffer.Span.Slice(index * size, size);
return ValueBuffer.Span.Slice((Offset + index) * size, size);
}

int IReadOnlyCollection<byte[]>.Count => Length;
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ private ArrayData LoadField(

if (fieldNullCount < 0)
{
throw new InvalidDataException("Null count length must be >= 0"); // TODO:Localize exception message
throw new InvalidDataException("Null count must be >= 0"); // TODO:Localize exception message
}

int buffers;
Expand Down
309 changes: 172 additions & 137 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions csharp/test/Apache.Arrow.Tests/ArrowArrayConcatenatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public void TestStandardCases()
{
foreach ((List<IArrowArray> testTargetArrayList, IArrowArray expectedArray) in GenerateTestData())
{
if (expectedArray is UnionArray)
{
// Union array concatenation is incorrect. See https://github.com/apache/arrow/issues/41198
continue;
}

IArrowArray actualArray = ArrowArrayConcatenator.Concatenate(testTargetArrayList);
ArrowReaderVerifier.CompareArrays(expectedArray, actualArray);
}
Expand Down
32 changes: 30 additions & 2 deletions csharp/test/Apache.Arrow.Tests/ArrowFileWriterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

using Apache.Arrow.Ipc;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Types;
using Xunit;

namespace Apache.Arrow.Tests
Expand Down Expand Up @@ -106,13 +109,38 @@ public async Task WritesFooterAlignedMultipleOf8Async()
await ValidateRecordBatchFile(stream, originalBatch);
}

private async Task ValidateRecordBatchFile(Stream stream, RecordBatch recordBatch)
[Theory]
[InlineData(0, 45)]
[InlineData(3, 45)]
[InlineData(16, 45)]
public async Task WriteSlicedArrays(int sliceOffset, int sliceLength)
{
var originalBatch = TestData.CreateSampleRecordBatch(length: 100);
var slicedArrays = originalBatch.Arrays
.Select(array => ArrowArrayFactory.Slice(array, sliceOffset, sliceLength))
.ToList();
var slicedBatch = new RecordBatch(originalBatch.Schema, slicedArrays, sliceLength);

var stream = new MemoryStream();
var writer = new ArrowFileWriter(stream, slicedBatch.Schema, leaveOpen: true);

await writer.WriteRecordBatchAsync(slicedBatch);
await writer.WriteEndAsync();

stream.Position = 0;

// Disable strict comparison because we don't expect buffers to match exactly
// due to writing slices of buffers, and instead need to compare array values
await ValidateRecordBatchFile(stream, slicedBatch, strictCompare: false);
}

private async Task ValidateRecordBatchFile(Stream stream, RecordBatch recordBatch, bool strictCompare = true)
{
var reader = new ArrowFileReader(stream);
int count = await reader.RecordBatchCountAsync();
Assert.Equal(1, count);
RecordBatch readBatch = await reader.ReadRecordBatchAsync(0);
ArrowReaderVerifier.CompareBatches(recordBatch, readBatch);
ArrowReaderVerifier.CompareBatches(recordBatch, readBatch, strictCompare);
}

/// <summary>
Expand Down
106 changes: 80 additions & 26 deletions csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public void Visit(StructArray array)

Assert.Equal(expectedArray.Length, array.Length);
Assert.Equal(expectedArray.NullCount, array.NullCount);
Assert.Equal(expectedArray.Offset, array.Offset);
Assert.Equal(0, array.Offset);
Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length);
Assert.Equal(expectedArray.Fields.Count, array.Fields.Count);

Expand All @@ -178,10 +178,42 @@ public void Visit(UnionArray array)
Assert.Equal(expectedArray.Mode, array.Mode);
Assert.Equal(expectedArray.Length, array.Length);
Assert.Equal(expectedArray.NullCount, array.NullCount);
Assert.Equal(expectedArray.Offset, array.Offset);
Assert.Equal(0, array.Offset);
Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length);
Assert.Equal(expectedArray.Fields.Count, array.Fields.Count);

if (_strictCompare)
{
Assert.True(expectedArray.TypeBuffer.Span.SequenceEqual(array.TypeBuffer.Span));
}
else
{
for (int i = 0; i < expectedArray.Length; i++)
{
Assert.Equal(expectedArray.TypeIds[i], array.TypeIds[i]);
}
}

if (_expectedArray is DenseUnionArray expectedDenseArray)
{
Assert.IsAssignableFrom<DenseUnionArray>(array);
var denseArray = array as DenseUnionArray;
Assert.NotNull(denseArray);

if (_strictCompare)
{
Assert.True(expectedDenseArray.ValueOffsetBuffer.Span.SequenceEqual(denseArray.ValueOffsetBuffer.Span));
}
else
{
for (int i = 0; i < expectedDenseArray.Length; i++)
{
Assert.Equal(
expectedDenseArray.ValueOffsets[i], denseArray.ValueOffsets[i]);
}
}
}

for (int i = 0; i < array.Fields.Count; i++)
{
array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i], _strictCompare));
Expand Down Expand Up @@ -220,9 +252,9 @@ private void CompareBinaryArrays<T>(BinaryArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Expand Down Expand Up @@ -252,9 +284,9 @@ private void CompareVariadicArrays<T>(BinaryViewArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

Assert.True(expectedArray.Views.SequenceEqual(actualArray.Views));

Expand All @@ -277,9 +309,9 @@ private void CompareArrays(FixedSizeBinaryArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Expand All @@ -306,9 +338,9 @@ private void CompareArrays<T>(PrimitiveArray<T> actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Expand Down Expand Up @@ -338,9 +370,9 @@ private void CompareArrays(BooleanArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Expand All @@ -365,18 +397,19 @@ private void CompareArrays(ListArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Assert.True(expectedArray.ValueOffsetsBuffer.Span.SequenceEqual(actualArray.ValueOffsetsBuffer.Span));
}
else
{
int offsetsStart = (expectedArray.Offset) * sizeof(int);
int offsetsLength = (expectedArray.Length + 1) * sizeof(int);
Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength)));
Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(offsetsStart, offsetsLength).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, offsetsLength)));
}

actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare));
Expand All @@ -391,9 +424,9 @@ private void CompareArrays(ListViewArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

if (_strictCompare)
{
Expand All @@ -402,9 +435,10 @@ private void CompareArrays(ListViewArray actualArray)
}
else
{
int start = expectedArray.Offset * sizeof(int);
int length = expectedArray.Length * sizeof(int);
Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(0, length).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, length)));
Assert.True(expectedArray.SizesBuffer.Span.Slice(0, length).SequenceEqual(actualArray.SizesBuffer.Span.Slice(0, length)));
Assert.True(expectedArray.ValueOffsetsBuffer.Span.Slice(start, length).SequenceEqual(actualArray.ValueOffsetsBuffer.Span.Slice(0, length)));
Assert.True(expectedArray.SizesBuffer.Span.Slice(start, length).SequenceEqual(actualArray.SizesBuffer.Span.Slice(0, length)));
}

actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare));
Expand All @@ -419,23 +453,31 @@ private void CompareArrays(FixedSizeListArray actualArray)

Assert.Equal(expectedArray.Length, actualArray.Length);
Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
Assert.Equal(expectedArray.Offset, actualArray.Offset);
Assert.Equal(0, actualArray.Offset);

CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, actualArray.NullBitmapBuffer);
CompareValidityBuffer(expectedArray.NullCount, _expectedArray.Length, expectedArray.NullBitmapBuffer, expectedArray.Offset, actualArray.NullBitmapBuffer);

actualArray.Values.Accept(new ArrayComparer(expectedArray.Values, _strictCompare));
var listSize = ((FixedSizeListType)expectedArray.Data.DataType).ListSize;
var expectedValuesSlice = ArrowArrayFactory.Slice(
expectedArray.Values, expectedArray.Offset * listSize, expectedArray.Length * listSize);
actualArray.Values.Accept(new ArrayComparer(expectedValuesSlice, _strictCompare));
}

private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer expectedValidityBuffer, ArrowBuffer actualValidityBuffer)
private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer expectedValidityBuffer, int expectedBufferOffset, ArrowBuffer actualValidityBuffer)
{
if (_strictCompare)
{
Assert.True(expectedValidityBuffer.Span.SequenceEqual(actualValidityBuffer.Span));
}
else if (nullCount != 0 && arrayLength > 0)
else if (actualValidityBuffer.IsEmpty)
{
Assert.True(nullCount == 0 || arrayLength == 0);
}
else if (expectedBufferOffset % 8 == 0)
{
int validityBitmapByteCount = BitUtility.ByteCount(arrayLength);
ReadOnlySpan<byte> expectedSpanPartial = expectedValidityBuffer.Span.Slice(0, validityBitmapByteCount - 1);
int byteOffset = BitUtility.ByteCount(expectedBufferOffset);
ReadOnlySpan<byte> expectedSpanPartial = expectedValidityBuffer.Span.Slice(byteOffset, validityBitmapByteCount - 1);
ReadOnlySpan<byte> actualSpanPartial = actualValidityBuffer.Span.Slice(0, validityBitmapByteCount - 1);

// Compare the first validityBitmapByteCount - 1 bytes
Expand All @@ -445,7 +487,7 @@ private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer e

// Compare the last byte bitwise (because there is no guarantee about the value of
// bits outside the range [0, arrayLength])
ReadOnlySpan<byte> expectedSpanFull = expectedValidityBuffer.Span.Slice(0, validityBitmapByteCount);
ReadOnlySpan<byte> expectedSpanFull = expectedValidityBuffer.Span.Slice(byteOffset, validityBitmapByteCount);
ReadOnlySpan<byte> actualSpanFull = actualValidityBuffer.Span.Slice(0, validityBitmapByteCount);
for (int i = 8 * (validityBitmapByteCount - 1); i < arrayLength; i++)
{
Expand All @@ -454,6 +496,18 @@ private void CompareValidityBuffer(int nullCount, int arrayLength, ArrowBuffer e
string.Format("Bit at index {0}/{1} is not equal", i, arrayLength));
}
}
else
{
// Have to compare all values bitwise
var expectedSpan = expectedValidityBuffer.Span;
var actualSpan = actualValidityBuffer.Span;
for (int i = 0; i < arrayLength; i++)
{
Assert.True(
BitUtility.GetBit(expectedSpan, expectedBufferOffset + i) == BitUtility.GetBit(actualSpan, i),
string.Format("Bit at index {0}/{1} is not equal", i, arrayLength));
}
}
}
}
}
Expand Down
Loading

0 comments on commit a6cdcd0

Please sign in to comment.