4 changes: 2 additions & 2 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -1139,11 +1139,11 @@ internal async Task<TGetCrossReferenceResp> GetCrossReferenceAsync(
{
req.ForeignCatalogName = foreignCatalogName!;
}
- if (schemaName != null)
+ if (foreignSchemaName != null)
{
req.ForeignSchemaName = foreignSchemaName!;
}
- if (tableName != null)
+ if (foreignTableName != null)
{
req.ForeignTableName = foreignTableName!;
}
236 changes: 235 additions & 1 deletion csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -35,13 +35,15 @@ internal class HiveServer2Statement : AdbcStatement
private const string GetSchemasCommandName = "getschemas";
private const string GetTablesCommandName = "gettables";
private const string GetColumnsCommandName = "getcolumns";
private const string GetColumnsExtendedCommandName = "getcolumnsextended";
private const string SupportedMetadataCommands =
GetCatalogsCommandName + "," +
GetSchemasCommandName + "," +
GetTablesCommandName + "," +
GetColumnsCommandName + "," +
GetPrimaryKeysCommandName + "," +
- GetCrossReferenceCommandName;
+ GetCrossReferenceCommandName + "," +
+ GetColumnsExtendedCommandName;
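// For reference, the command list this constant builds (assuming the two
// constants declared above this hunk follow the same lowercase pattern as
// the others) reads:
// "getcatalogs,getschemas,gettables,getcolumns,getprimarykeys,getcrossreference,getcolumnsextended"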

internal HiveServer2Statement(HiveServer2Connection connection)
{
@@ -360,6 +362,7 @@ private async Task<QueryResult> ExecuteMetadataCommandQuery(CancellationToken ca
GetColumnsCommandName => await GetColumnsAsync(cancellationToken),
GetPrimaryKeysCommandName => await GetPrimaryKeysAsync(cancellationToken),
GetCrossReferenceCommandName => await GetCrossReferenceAsync(cancellationToken),
GetColumnsExtendedCommandName => await GetColumnsExtendedAsync(cancellationToken),
null or "" => throw new ArgumentNullException(nameof(SqlQuery), $"Metadata command for property 'SqlQuery' must not be empty or null. Supported metadata commands: {SupportedMetadataCommands}"),
_ => throw new NotSupportedException($"Metadata command '{SqlQuery}' is not supported. Supported metadata commands: {SupportedMetadataCommands}"),
};
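
A minimal usage sketch for the new command (the option names and query string mirror the CanGetColumnsExtended test below; the catalog/schema/table values are placeholders):

var statement = connection.CreateStatement();
statement.SetOption(ApacheParameters.IsMetadataCommand, "true");
statement.SetOption(ApacheParameters.CatalogName, "my_catalog");
statement.SetOption(ApacheParameters.SchemaName, "my_schema");
statement.SetOption(ApacheParameters.TableName, "my_table");
statement.SqlQuery = "GetColumnsExtended";
QueryResult result = await statement.ExecuteQueryAsync();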
@@ -587,5 +590,236 @@ protected internal QueryResult EnhanceGetColumnsResult(Schema originalSchema, IR

return new QueryResult(rowCount, new HiveServer2Connection.HiveInfoArrowStream(enhancedSchema, enhancedData));
}

// Helper method to read all batches from a stream
private async Task<(List<RecordBatch> Batches, Schema Schema, int TotalRows)> ReadAllBatchesAsync(
IArrowArrayStream stream, CancellationToken cancellationToken)
{
List<RecordBatch> batches = new List<RecordBatch>();
int totalRows = 0;
Schema schema = stream.Schema;

// Read all batches
while (true)
{
var batch = await stream.ReadNextRecordBatchAsync(cancellationToken);
if (batch == null) break;

if (batch.Length > 0)
{
batches.Add(batch);
totalRows += batch.Length;
}
else
{
batch.Dispose();
}
}

return (batches, schema, totalRows);
}

private async Task<QueryResult> GetColumnsExtendedAsync(CancellationToken cancellationToken = default)
{
// 1. Run the three metadata queries: columns, primary keys, and foreign keys
var columnsResult = await GetColumnsAsync(cancellationToken);
if (columnsResult.Stream == null) return columnsResult;

var pkResult = await GetPrimaryKeysAsync(cancellationToken);

// For FK lookup, we need to pass in the current catalog/schema/table as the foreign table
var resp = await Connection.GetCrossReferenceAsync(
null,
null,
null,
CatalogName,
SchemaName,
TableName,
cancellationToken);

var fkResult = await GetQueryResult(resp.DirectResults, cancellationToken);

// 2. Read all batches into memory
List<RecordBatch> columnsBatches;
int totalRows;
Schema columnsSchema;
StringArray? columnNames = null;
int colNameIndex = -1;

// Extract column data. Note: the stream is deliberately not wrapped in a
// `using` block -- the early-return paths hand columnsResult (and thus this
// same stream) back to the caller, so it must not be disposed here.
var stream = columnsResult.Stream;
colNameIndex = stream.Schema.GetFieldIndex("COLUMN_NAME");
if (colNameIndex < 0) return columnsResult; // Can't match without column names

var batchResult = await ReadAllBatchesAsync(stream, cancellationToken);
columnsBatches = batchResult.Batches;
columnsSchema = batchResult.Schema;
totalRows = batchResult.TotalRows;

if (columnsBatches.Count == 0) return columnsResult;

// The stream is fully consumed and no longer returned; dispose it now
stream.Dispose();

// Create column names array from all batches using ArrayDataConcatenator.Concatenate
List<ArrayData> columnNameArrayDataList = columnsBatches.Select(batch =>
batch.Column(colNameIndex).Data).ToList();
ArrayData? concatenatedColumnNames = ArrayDataConcatenator.Concatenate(columnNameArrayDataList);
columnNames = (StringArray)ArrowArrayFactory.BuildArray(concatenatedColumnNames!);

// 3. Create combined schema and prepare data
var allFields = new List<Field>(columnsSchema.FieldsList);
var combinedData = new List<IArrowArray>();

// 4. Add all columns data by combining all batches
// (we returned early above if there were no batches)
for (int colIdx = 0; colIdx < columnsSchema.FieldsList.Count; colIdx++)
{
// Collect this column's arrays from every batch and concatenate them
List<ArrayData> arrayDataList = columnsBatches.Select(batch =>
batch.Column(colIdx).Data).ToList();
ArrayData? concatenatedData = ArrayDataConcatenator.Concatenate(arrayDataList);
combinedData.Add(ArrowArrayFactory.BuildArray(concatenatedData!));
}

// 5. Process PK and FK data using helper methods with selected fields
await ProcessRelationshipDataSafe(pkResult, "PK_", "COLUMN_NAME",
new[] { "COLUMN_NAME" }, // Selected PK fields
columnNames, totalRows,
allFields, combinedData, cancellationToken);

await ProcessRelationshipDataSafe(fkResult, "FK_", "FKCOLUMN_NAME",
new[] { "PKCOLUMN_NAME", "PKTABLE_CAT", "PKTABLE_SCHEM", "PKTABLE_NAME", "FKCOLUMN_NAME" }, // Selected FK fields
columnNames, totalRows,
allFields, combinedData, cancellationToken);

// 6. Return the combined result
var combinedSchema = new Schema(allFields, columnsSchema.Metadata);

return new QueryResult(totalRows, new HiveServer2Connection.HiveInfoArrowStream(combinedSchema, combinedData));
}
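
// Illustrative output shape (hypothetical table whose "id" column is its
// primary key and whose "DOLocationId" column references a parent column
// "LocationId", mirroring the examples in ProcessRelationshipDataSafe below):
//
// COLUMN_NAME PK_COLUMN_NAME FK_PKCOLUMN_NAME
// "id" "id" null
// "DOLocationId" null "LocationId"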

/// <summary>
/// Processes relationship data (primary/foreign keys) from a query result and appends it to the output.
/// Correlates rows from the PK/FK metadata queries with the main column data.
///
/// How it works:
/// 1. Add relationship columns to the schema (PK/FK fields with prefixed names)
/// 2. Read relationship data from the source records
/// 3. Build a mapping of column names to their relationship values
/// 4. Create arrays for each field, aligning values with the main column result
/// </summary>
private async Task ProcessRelationshipDataSafe(QueryResult result, string prefix, string relationColNameField,
string[] includeFields, StringArray colNames, int rowCount,
List<Field> allFields, List<IArrowArray> combinedData, CancellationToken cancellationToken)
{
// STEP 1: Add relationship fields to the output schema
// Each field name is prefixed (e.g., "PK_" for primary keys, "FK_" for foreign keys)
foreach (var fieldName in includeFields)
{
allFields.Add(new Field(prefix + fieldName, StringType.Default, true));
}

// STEP 2: Create a dictionary to map column names to their relationship values
// Structure: Dictionary<fieldName, Dictionary<columnName, relationshipValue>>
// For primary keys - only columns that are PKs are stored:
// {"COLUMN_NAME": {"id": "id"}}
// For foreign keys - only columns that are FKs are stored, keyed by the FK
// column and mapping to the referenced field's value, e.g. the parent column:
// {"PKCOLUMN_NAME": {"DOLocationId": "LocationId"}}
var relationData = new Dictionary<string, Dictionary<string, string>>(StringComparer.OrdinalIgnoreCase);

// STEP 3: Extract relationship data from the query result
if (result.Stream != null)
{
using (var stream = result.Stream)
{
// Find the column index that contains our key values (e.g., COLUMN_NAME for PK or FKCOLUMN_NAME for FK)
int keyColIndex = stream.Schema.GetFieldIndex(relationColNameField);
if (keyColIndex >= 0)
{
// STEP 3.1: Map field names to their column indices for quick lookup
// (the stream's schema is fixed, so build the map once, not per batch)
Dictionary<string, int> fieldIndices = new Dictionary<string, int>();
foreach (var fieldName in includeFields)
{
int index = stream.Schema.GetFieldIndex(fieldName);
if (index >= 0) fieldIndices[fieldName] = index;
}

// STEP 3.2: Process each record batch from the relationship data source
while (true)
{
var batch = await stream.ReadNextRecordBatchAsync(cancellationToken);
if (batch == null) break;

// STEP 3.3: Process each row in the batch
// The key column holds the column name each relationship row applies to
StringArray keyCol = (StringArray)batch.Column(keyColIndex);
for (int i = 0; i < batch.Length; i++)
{
if (keyCol.IsNull(i)) continue;

string keyValue = keyCol.GetString(i);
if (string.IsNullOrEmpty(keyValue)) continue;

// STEP 3.4: For each included field, extract its value and store in our map
foreach (var pair in fieldIndices)
{
// Ensure we have an entry for this field
if (!relationData.TryGetValue(pair.Key, out var fieldData))
{
fieldData = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
relationData[pair.Key] = fieldData;
}
StringArray fieldArray = (StringArray)batch.Column(pair.Value);
// Store the relationship value: columnName -> value (skip null entries
// so the non-nullable dictionary never receives a null string)
if (!fieldArray.IsNull(i))
{
fieldData[keyValue] = fieldArray.GetString(i);
}
}
}
}
}
}
}

// STEP 4: Build Arrow arrays for each relationship field
// These arrays align with the main column results, so each row contains
// the appropriate relationship value for its column
foreach (var fieldName in includeFields)
{
// Create a string array builder
var builder = new StringArray.Builder();
relationData.TryGetValue(fieldName, out var fieldData);

// Process each column name in the main result
for (int i = 0; i < colNames.Length; i++)
{
string? colName = colNames.GetString(i);
string? value = null;

// Look up relationship value for this column
if (!string.IsNullOrEmpty(colName) &&
fieldData != null &&
fieldData.TryGetValue(colName!, out var fieldValue))
{
value = fieldValue;
}

// Add to the array (null if no relationship exists)
builder.Append(value);
}

// Add the completed array to our output data
combinedData.Add(builder.Build());
}
}
}
}
97 changes: 97 additions & 0 deletions csharp/test/Drivers/Databricks/StatementTests.cs
@@ -330,6 +330,103 @@ public async Task CanGetColumnsWithBaseTypeName()
Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, actualBatchLength);
}

[SkippableFact]
public async Task CanGetColumnsExtended()
{
// Get the runtime version using GetInfo
var infoCodes = new List<AdbcInfoCode> { AdbcInfoCode.VendorVersion };
var infoValues = Connection.GetInfo(infoCodes);

// Set up statement for GetColumnsExtended
var statement = Connection.CreateStatement();
statement.SetOption(ApacheParameters.IsMetadataCommand, "true");
statement.SetOption(ApacheParameters.CatalogName, TestConfiguration.Metadata.Catalog);
statement.SetOption(ApacheParameters.SchemaName, TestConfiguration.Metadata.Schema);
statement.SetOption(ApacheParameters.TableName, TestConfiguration.Metadata.Table);
statement.SqlQuery = "GetColumnsExtended";

QueryResult queryResult = await statement.ExecuteQueryAsync();
Assert.NotNull(queryResult.Stream);

// Verify schema has more fields than the regular GetColumns result (which has 24 fields)
// We expect additional PK and FK fields
OutputHelper?.WriteLine($"Column count in result schema: {queryResult.Stream.Schema.FieldsList.Count}");
Assert.True(queryResult.Stream.Schema.FieldsList.Count > 24,
"GetColumnsExtended should return more fields than the 24 returned by GetColumns");

// Verify that key fields from each original metadata call are present
bool hasColumnName = false;
bool hasPkColumnName = false;
bool hasFkTableName = false;

foreach (var field in queryResult.Stream.Schema.FieldsList)
{
OutputHelper?.WriteLine($"Field in schema: {field.Name} ({field.DataType})");

if (field.Name.Equals("COLUMN_NAME", StringComparison.OrdinalIgnoreCase))
hasColumnName = true;
else if (field.Name.Equals("PK_COLUMN_NAME", StringComparison.OrdinalIgnoreCase))
hasPkColumnName = true;
else if (field.Name.Equals("FK_PKTABLE_NAME", StringComparison.OrdinalIgnoreCase))
hasFkTableName = true;
}

Assert.True(hasColumnName, "Schema should contain COLUMN_NAME field from GetColumns");
Assert.True(hasPkColumnName, "Schema should contain PK_COLUMN_NAME field from GetPrimaryKeys");
Assert.True(hasFkTableName, "Schema should contain FK_PKTABLE_NAME field from GetCrossReference");

// Read and verify data
int rowCount = 0;
while (queryResult.Stream != null)
{
RecordBatch? batch = await queryResult.Stream.ReadNextRecordBatchAsync();
if (batch == null) break;

rowCount += batch.Length;

// Output rows for debugging (limit to first 10)
if (batch.Length > 0)
{
int rowsToPrint = Math.Min(batch.Length, 10); // Limit to 10 rows
OutputHelper?.WriteLine($"Found {batch.Length} rows, showing first {rowsToPrint}:");

for (int rowIndex = 0; rowIndex < rowsToPrint; rowIndex++)
{
OutputHelper?.WriteLine($"Row {rowIndex}:");
for (int i = 0; i < batch.ColumnCount; i++)
{
string fieldName = queryResult.Stream.Schema.FieldsList[i].Name;
string fieldValue = GetStringValue(batch.Column(i), rowIndex);
OutputHelper?.WriteLine($" {fieldName}: {fieldValue}");
}
OutputHelper?.WriteLine(""); // Add blank line between rows
}
}
}

// Verify we got rows matching the expected column count
Assert.Equal(TestConfiguration.Metadata.ExpectedColumnCount, rowCount);
OutputHelper?.WriteLine($"Successfully retrieved {rowCount} columns with extended information");
}

// Helper method to get string representation of array values
private string GetStringValue(IArrowArray array, int index)
{
if (array == null || index >= array.Length || array.IsNull(index))
return "null";

if (array is StringArray strArray)
return strArray.GetString(index) ?? "null";
else if (array is Int32Array int32Array)
return int32Array.GetValue(index).ToString() ?? "null";
else if (array is Int16Array int16Array)
return int16Array.GetValue(index).ToString() ?? "null";
else if (array is BooleanArray boolArray)
return boolArray.GetValue(index).ToString() ?? "null";

return "unknown";
}

protected override void PrepareCreateTableWithPrimaryKeys(out string sqlUpdate, out string tableNameParent, out string fullTableNameParent, out IReadOnlyList<string> primaryKeys)
{
CreateNewTableName(out tableNameParent, out fullTableNameParent);