Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(csharp/src/Drivers/Apache): extend capability of GetInfo for Spark driver #1863

76 changes: 24 additions & 52 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
Expand All @@ -35,17 +36,29 @@ public abstract class HiveServer2Connection : AdbcConnection
internal TTransport? transport;
internal TCLIService.Client? client;
internal TSessionHandle? sessionHandle;
private Lazy<string> _vendorVersion;
private Lazy<string> _vendorName;

/// <summary>
/// Initializes the connection with the supplied properties and sets up the
/// lazily-evaluated vendor name and vendor version lookups.
/// </summary>
/// <param name="properties">The connection properties to store on this instance.</param>
internal HiveServer2Connection(IReadOnlyDictionary<string, string> properties)
{
    // "PublicationOnly" gives thread-safe initialization in which the first thread to
    // produce a value wins. An exception thrown by the factory is not cached, so a
    // later access to Value runs the factory again until one call succeeds.
    // https://learn.microsoft.com/en-us/dotnet/framework/performance/lazy-initialization#exceptions-in-lazy-objects
    const LazyThreadSafetyMode mode = LazyThreadSafetyMode.PublicationOnly;

    _vendorName = new Lazy<string>(
        () => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_NAME),
        mode);
    _vendorVersion = new Lazy<string>(
        () => GetInfoTypeStringValue(TGetInfoType.CLI_DBMS_VER),
        mode);

    this.properties = properties;
}

/// <summary>
/// Gets the <c>TCLIService.Client</c> held by this connection.
/// </summary>
/// <exception cref="InvalidOperationException">
/// Thrown when the <c>client</c> field is still null.
/// </exception>
internal TCLIService.Client Client =>
    this.client ?? throw new InvalidOperationException("connection not open");

/// <summary>
/// Gets the vendor version string. The value is produced on first access by the
/// lazy initializer that requests <c>TGetInfoType.CLI_DBMS_VER</c>.
/// </summary>
protected string VendorVersion
{
    get => _vendorVersion.Value;
}

/// <summary>
/// Gets the vendor name string. The value is produced on first access by the
/// lazy initializer that requests <c>TGetInfoType.CLI_DBMS_NAME</c>.
/// </summary>
protected string VendorName
{
    get => _vendorName.Value;
}

internal async Task OpenAsync()
{
TProtocol protocol = await CreateProtocolAsync();
Expand Down Expand Up @@ -103,64 +116,23 @@ protected Schema GetSchema()
return SchemaParser.GetArrowSchema(response.Schema);
}

sealed class GetObjectsReader : IArrowArrayStream
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed unused code: the nested `sealed class GetObjectsReader : IArrowArrayStream`.

private string GetInfoTypeStringValue(TGetInfoType infoType)
{
HiveServer2Connection? connection;
Schema schema;
List<TSparkArrowBatch>? batches;
int index;
IArrowReader? reader;

public GetObjectsReader(HiveServer2Connection connection, Schema schema)
TGetInfoReq req = new()
{
this.connection = connection;
this.schema = schema;
}
SessionHandle = this.sessionHandle ?? throw new InvalidOperationException("session not created"),
InfoType = infoType,
};

public Schema Schema { get { return schema; } }

public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
TGetInfoResp getInfoResp = Client.GetInfo(req).Result;
if (getInfoResp.Status.StatusCode == TStatusCode.ERROR_STATUS)
{
while (true)
{
if (this.reader != null)
{
RecordBatch? next = await this.reader.ReadNextRecordBatchAsync(cancellationToken);
if (next != null)
{
return next;
}
this.reader = null;
}

if (this.batches != null && this.index < this.batches.Count)
{
this.reader = new ArrowStreamReader(new ChunkStream(this.schema, this.batches[this.index++].Batch));
continue;
}

this.batches = null;
this.index = 0;

if (this.connection == null)
{
return null;
}

TFetchResultsReq request = new TFetchResultsReq(this.connection.operationHandle, TFetchOrientation.FETCH_NEXT, 50000);
TFetchResultsResp response = await this.connection.Client.FetchResults(request, cancellationToken);
this.batches = response.Results.ArrowBatches;

if (!response.HasMoreRows)
{
this.connection = null;
}
}
throw new HiveServer2Exception(getInfoResp.Status.ErrorMessage)
.SetNativeError(getInfoResp.Status.ErrorCode)
.SetSqlState(getInfoResp.Status.SqlState);
}

public void Dispose()
{
}
return getInfoResp.InfoValue.StringValue;
}
}
}
46 changes: 37 additions & 9 deletions csharp/src/Drivers/Apache/Spark/SparkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ public class SparkConnection : HiveServer2Connection
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorName
AdbcInfoCode.VendorName,
AdbcInfoCode.VendorSql,
AdbcInfoCode.VendorVersion,
};

const string InfoDriverName = "ADBC Spark Driver";
// TODO: Make this dynamically return current version
const string InfoDriverVersion = "1.0.0";
const string InfoVendorName = "Spark";
const string InfoDriverArrowVersion = "1.0.0";
const bool InfoVendorSql = true;
const int DecimalPrecisionDefault = 10;
const int DecimalScaleDefault = 0;

Expand Down Expand Up @@ -137,6 +140,7 @@ public override AdbcStatement CreateStatement()
public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
{
const int strValTypeID = 0;
const int boolValTypeId = 1;

UnionType infoUnionType = new UnionType(
new Field[]
Expand Down Expand Up @@ -178,8 +182,11 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
ArrowBuffer.Builder<byte> typeBuilder = new ArrowBuffer.Builder<byte>();
ArrowBuffer.Builder<int> offsetBuilder = new ArrowBuffer.Builder<int>();
StringArray.Builder stringInfoBuilder = new StringArray.Builder();
BooleanArray.Builder booleanInfoBuilder = new BooleanArray.Builder();

int nullCount = 0;
int arrayLength = codes.Count;
int offset = 0;

foreach (AdbcInfoCode code in codes)
{
Expand All @@ -188,32 +195,53 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
case AdbcInfoCode.DriverName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.DriverArrowVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.Append(InfoDriverArrowVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorName:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
stringInfoBuilder.Append(InfoVendorName);
offsetBuilder.Append(offset++);
string vendorName = VendorName;
stringInfoBuilder.Append(vendorName);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorVersion:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(offset++);
string? vendorVersion = VendorVersion;
stringInfoBuilder.Append(vendorVersion);
booleanInfoBuilder.AppendNull();
break;
case AdbcInfoCode.VendorSql:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(boolValTypeId);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.Append(InfoVendorSql);
break;
default:
infoNameBuilder.Append((UInt32)code);
typeBuilder.Append(strValTypeID);
offsetBuilder.Append(stringInfoBuilder.Length);
offsetBuilder.Append(offset++);
stringInfoBuilder.AppendNull();
booleanInfoBuilder.AppendNull();
nullCount++;
break;
}
Expand All @@ -231,7 +259,7 @@ public override IArrowArrayStream GetInfo(IReadOnlyList<AdbcInfoCode> codes)
IArrowArray[] childrenArrays = new IArrowArray[]
{
stringInfoBuilder.Build(),
new BooleanArray.Builder().Build(),
booleanInfoBuilder.Build(),
new Int64Array.Builder().Build(),
new Int32Array.Builder().Build(),
new ListArray.Builder(StringType.Default).Build(),
Expand Down
77 changes: 73 additions & 4 deletions csharp/test/Drivers/Apache/Spark/DriverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,30 @@ public async Task CanGetInfo()
{
AdbcConnection adbcConnection = NewConnection();

using IArrowArrayStream stream = adbcConnection.GetInfo(new List<AdbcInfoCode>() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName });
// Test the supported info codes
List<AdbcInfoCode> handledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.DriverName,
AdbcInfoCode.DriverVersion,
AdbcInfoCode.VendorName,
AdbcInfoCode.DriverArrowVersion,
AdbcInfoCode.VendorVersion,
AdbcInfoCode.VendorSql
};
using IArrowArrayStream stream = adbcConnection.GetInfo(handledCodes);

RecordBatch recordBatch = await stream.ReadNextRecordBatchAsync();
UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> expectedValues = new List<string>() { "DriverName", "DriverVersion", "VendorName" };
List<string> expectedValues = new List<string>()
{
"DriverName",
"DriverVersion",
"VendorName",
"DriverArrowVersion",
"VendorVersion",
"VendorSql"
};

for (int i = 0; i < infoNameArray.Length; i++)
{
Expand All @@ -98,8 +116,59 @@ public async Task CanGetInfo()

Assert.Contains(value.ToString(), expectedValues);

StringArray stringArray = (StringArray)valueArray.Fields[0];
Console.WriteLine($"{value}={stringArray.GetString(i)}");
switch (value)
{
case AdbcInfoCode.VendorSql:
// TODO: How does external developer know the second field is the boolean field?
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
bool? boolValue = booleanArray.GetValue(i);
OutputHelper?.WriteLine($"{value}={boolValue}");
Assert.True(boolValue);
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
string stringValue = stringArray.GetString(i);
OutputHelper?.WriteLine($"{value}={stringValue}");
Assert.NotNull(stringValue);
break;
}
}

// Test the unhandled info codes.
List<AdbcInfoCode> unhandledCodes = new List<AdbcInfoCode>()
{
AdbcInfoCode.VendorArrowVersion,
AdbcInfoCode.VendorSubstrait,
AdbcInfoCode.VendorSubstraitMaxVersion
};
using IArrowArrayStream stream2 = adbcConnection.GetInfo(unhandledCodes);

recordBatch = await stream2.ReadNextRecordBatchAsync();
infoNameArray = (UInt32Array)recordBatch.Column("info_name");

List<string> unexpectedValues = new List<string>()
{
"VendorArrowVersion",
"VendorSubstrait",
"VendorSubstraitMaxVersion"
};
for (int i = 0; i < infoNameArray.Length; i++)
{
AdbcInfoCode? value = (AdbcInfoCode?)infoNameArray.GetValue(i);
DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value");

Assert.Contains(value.ToString(), unexpectedValues);
switch (value)
{
case AdbcInfoCode.VendorSql:
BooleanArray booleanArray = (BooleanArray)valueArray.Fields[1];
Assert.Null(booleanArray.GetValue(i));
break;
default:
StringArray stringArray = (StringArray)valueArray.Fields[0];
Assert.Null(stringArray.GetString(i));
break;
}
}
}

Expand Down
Loading