diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs index 4455a25bc0..a8b3089b3d 100644 --- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs +++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs @@ -17,10 +17,12 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Runtime.InteropServices; using Apache.Arrow.C; using Apache.Arrow.Ipc; + #if NETSTANDARD using Apache.Arrow.Adbc.Extensions; #endif @@ -663,7 +665,7 @@ private unsafe AdbcStatusCode NewConnection(CAdbcConnection* nativeConnection, C nativeConnection->private_data = (void*)GCHandle.ToIntPtr(handle); return AdbcStatusCode.Success; } - catch(Exception e) + catch (Exception e) { return SetError(error, e); } @@ -773,9 +775,31 @@ public unsafe void GetObjects(ref CAdbcConnection nativeConnection, int depth, b columnNamePattern = Marshal.PtrToStringUTF8((IntPtr)column_name); #endif - // TODO (GH-1694): Marshaling is incorrect - GCHandle gch = GCHandle.FromIntPtr((IntPtr)table_type); - List tableTypes = (List)gch.Target; + List tableTypes = null; + const int maxTableTypeCount = 100; + if (table_type != null) + { + int count = 0; + while (table_type[count] != null && count <= maxTableTypeCount) + { + count++; + } + + if (count > maxTableTypeCount) + { + throw new InvalidOperationException($"We do not expect to get more than {maxTableTypeCount} table types"); + } + + tableTypes = new List(count); + for (int i = 0; i < count; i++) + { +#if NETSTANDARD + tableTypes.Add(MarshalExtensions.PtrToStringUTF8((IntPtr)table_type[i])); +#else + tableTypes.Add(Marshal.PtrToStringUTF8((IntPtr)table_type[i])); +#endif + } + } AdbcConnection.GetObjectsDepth goDepth = (AdbcConnection.GetObjectsDepth)depth; @@ -812,20 +836,25 @@ public unsafe void GetTableTypes(CArrowArrayStream* cArrayStream) public unsafe void ReadPartition(byte* serializedPartition, int serialized_length, CArrowArrayStream* stream) { - // TODO (GH-1694): Marshaling is incorrect - GCHandle gch = GCHandle.FromIntPtr((IntPtr)serializedPartition); - PartitionDescriptor descriptor = (PartitionDescriptor)gch.Target; + byte[] partition = new byte[serialized_length]; + fixed (byte* partitionPtr = partition) + { + Buffer.MemoryCopy(serializedPartition, partitionPtr, serialized_length, serialized_length); + } - CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(descriptor), stream); + CArrowArrayStreamExporter.ExportArrayStream(connection.ReadPartition(new PartitionDescriptor(partition)), stream); } public unsafe void GetInfo(int* info_codes, int info_codes_length, CArrowArrayStream* stream) { - // TODO (GH-1694): Marshaling is incorrect - GCHandle gch = GCHandle.FromIntPtr((IntPtr)info_codes); - List codes = (List)gch.Target; + int[] infoCodes = new int[info_codes_length]; + fixed (int* infoCodesPtr = infoCodes) + { + long length = (long)info_codes_length * sizeof(int); + Buffer.MemoryCopy(info_codes, infoCodesPtr, length, length); + } - CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(codes), stream); + CArrowArrayStreamExporter.ExportArrayStream(connection.GetInfo(infoCodes.ToList()), stream); } public unsafe void InitConnection(ref CAdbcDatabase nativeDatabase)