Skip to content

Commit

Permalink
fix(csharp/src/Apache.Arrow.Adbc): Fix marshaling in three functions …
Browse files Browse the repository at this point in the history
…where it was broken (#1758)

Closes #1694.
  • Loading branch information
CurtHagenlocher committed Apr 24, 2024
1 parent 282819d commit 9287e9a
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Apache.Arrow.C;
using Apache.Arrow.Ipc;


#if NETSTANDARD
using Apache.Arrow.Adbc.Extensions;
#endif
Expand Down Expand Up @@ -663,7 +665,7 @@ private unsafe AdbcStatusCode NewConnection(CAdbcConnection* nativeConnection, C
nativeConnection->private_data = (void*)GCHandle.ToIntPtr(handle);
return AdbcStatusCode.Success;
}
catch(Exception e)
catch (Exception e)
{
return SetError(error, e);
}
Expand Down Expand Up @@ -773,9 +775,31 @@ public unsafe void GetObjects(ref CAdbcConnection nativeConnection, int depth, b
columnNamePattern = Marshal.PtrToStringUTF8((IntPtr)column_name);
#endif

// TODO (GH-1694): Marshaling is incorrect
GCHandle gch = GCHandle.FromIntPtr((IntPtr)table_type);
List<string> tableTypes = (List<string>)gch.Target;
List<string> tableTypes = null;
const int maxTableTypeCount = 100;
if (table_type != null)
{
int count = 0;
while (table_type[count] != null && count <= maxTableTypeCount)
{
count++;
}

if (count > maxTableTypeCount)
{
throw new InvalidOperationException($"We do not expect to get more than {maxTableTypeCount} table types");
}

tableTypes = new List<string>(count);
for (int i = 0; i < count; i++)
{
#if NETSTANDARD
tableTypes.Add(MarshalExtensions.PtrToStringUTF8((IntPtr)table_type[i]));
#else
tableTypes.Add(Marshal.PtrToStringUTF8((IntPtr)table_type[i]));
#endif
}
}

AdbcConnection.GetObjectsDepth goDepth = (AdbcConnection.GetObjectsDepth)depth;

Expand Down Expand Up @@ -812,20 +836,25 @@ public unsafe void GetTableTypes(CArrowArrayStream* cArrayStream)

public unsafe void ReadPartition(byte* serializedPartition, int serialized_length, CArrowArrayStream* stream)
{
// TODO (GH-1694): Marshaling is incorrect
GCHandle gch = GCHandle.FromIntPtr((IntPtr)serializedPartition);
PartitionDescriptor descriptor = (PartitionDescriptor)gch.Target;
byte[] partition = new byte[serialized_length];
fixed (byte* partitionPtr = partition)
{
Buffer.MemoryCopy(serializedPartition, partitionPtr, serialized_length, serialized_length);
}

CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(descriptor), stream);
CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(new PartitionDescriptor(partition)), stream);
}

public unsafe void GetInfo(int* info_codes, int info_codes_length, CArrowArrayStream* stream)
{
// TODO (GH-1694): Marshaling is incorrect
GCHandle gch = GCHandle.FromIntPtr((IntPtr)info_codes);
List<int> codes = (List<int>)gch.Target;
int[] infoCodes = new int[info_codes_length];
fixed (int* infoCodesPtr = infoCodes)
{
long length = (long)info_codes_length * sizeof(int);
Buffer.MemoryCopy(info_codes, infoCodesPtr, length, length);
}

CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(codes), stream);
CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(infoCodes.ToList()), stream);
}

public unsafe void InitConnection(ref CAdbcDatabase nativeDatabase)
Expand Down

0 comments on commit 9287e9a

Please sign in to comment.