diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 4079fce26c..c4c7414c26 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -368,7 +368,7 @@ private async Task<QueryResult> ExecuteMetadataCommandQuery(CancellationToken ca
             };
         }
 
-        private async Task<QueryResult> GetCrossReferenceAsync(CancellationToken cancellationToken = default)
+        protected virtual async Task<QueryResult> GetCrossReferenceAsync(CancellationToken cancellationToken = default)
         {
             TGetCrossReferenceResp resp = await Connection.GetCrossReferenceAsync(
                 CatalogName,
@@ -383,7 +383,7 @@ private async Task<QueryResult> GetCrossReferenceAsync(CancellationToken cancell
             return await GetQueryResult(resp.DirectResults, cancellationToken);
         }
 
-        private async Task<QueryResult> GetPrimaryKeysAsync(CancellationToken cancellationToken = default)
+        protected virtual async Task<QueryResult> GetPrimaryKeysAsync(CancellationToken cancellationToken = default)
         {
             TGetPrimaryKeysResp resp = await Connection.GetPrimaryKeysAsync(
                 CatalogName,
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index 3cb0e544fd..1282d0c8aa 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -40,6 +40,7 @@ internal class DatabricksConnection : SparkHttpConnection
         private bool _applySSPWithQueries = false;
         private bool _enableDirectResults = true;
         private bool _enableMultipleCatalogSupport = true;
+        private bool _enablePKFK = true;
 
         internal static TSparkGetDirectResults defaultGetDirectResults = new()
         {
@@ -71,6 +72,18 @@ protected override TCLIService.IAsync CreateTCLIServiceClient(TProtocol protocol
 
         private void ValidateProperties()
         {
+            if (Properties.TryGetValue(DatabricksParameters.EnablePKFK, out string? enablePKFKStr))
+            {
+                if (bool.TryParse(enablePKFKStr, out bool enablePKFKValue))
+                {
+                    _enablePKFK = enablePKFKValue;
+                }
+                else
+                {
+                    throw new ArgumentException($"Parameter '{DatabricksParameters.EnablePKFK}' value '{enablePKFKStr}' could not be parsed. Valid values are 'true', 'false'.");
+                }
+            }
+
             if (Properties.TryGetValue(DatabricksParameters.EnableMultipleCatalogSupport, out string? enableMultipleCatalogSupportStr))
             {
                 if (bool.TryParse(enableMultipleCatalogSupportStr, out bool enableMultipleCatalogSupportValue))
@@ -204,6 +217,11 @@ private void ValidateProperties()
         /// </summary>
         internal bool EnableMultipleCatalogSupport => _enableMultipleCatalogSupport;
 
+        /// <summary>
+        /// Gets whether the PK/FK metadata calls are enabled.
+        /// </summary>
+        public bool EnablePKFK => _enablePKFK;
+
         /// <summary>
         /// Gets a value indicating whether to retry requests that receive a 503 response with a Retry-After header.
         /// </summary>
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index 85e33f62da..db62c04b25 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -155,6 +155,12 @@ public class DatabricksParameters : SparkParameters
         /// Default value is true if not specified.
         /// </summary>
        public const string EnableMultipleCatalogSupport = "adbc.databricks.enable_multiple_catalog_support";
+
+        /// <summary>
+        /// Whether to enable the primary key / foreign key (PK/FK) metadata calls.
+        /// Default value is true if not specified.
+        /// </summary>
+        public const string EnablePKFK = "adbc.databricks.enable_pk_fk";
     }
 
     /// <summary>
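The new option is consumed as an ordinary ADBC connection property. A minimal sketch of opting out of the PK/FK calls, assuming the dictionary-of-string-properties connection setup used elsewhere in the driver and its tests; only the parameter key and its accepted values come from this change, the surrounding setup is illustrative:

    // Hypothetical host-side setup: how the dictionary is handed to the driver
    // (AdbcDatabase, a NewConnection helper, etc.) depends on the application.
    var properties = new Dictionary<string, string>
    {
        // ... existing Databricks options (host, auth, catalog, ...) ...
        [DatabricksParameters.EnablePKFK] = "false" // "true" (default) or "false"; any other
                                                    // value makes ValidateProperties throw ArgumentException.
    };

When the flag is false, DatabricksStatement short-circuits GetPrimaryKeys/GetCrossReference, as shown in the next file.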
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index 1fdeee285b..cb92cdd5e4 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -16,7 +16,6 @@
  */
 
 using System;
-using System.Collections.Generic;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Adbc.Drivers.Apache;
@@ -37,6 +36,7 @@ internal class DatabricksStatement : SparkStatement, IHiveServer2Statement
         private bool canDecompressLz4;
         private long maxBytesPerFile;
         private bool enableMultipleCatalogSupport;
+        private bool enablePKFK;
 
         public DatabricksStatement(DatabricksConnection connection)
             : base(connection)
@@ -46,6 +46,7 @@ public DatabricksStatement(DatabricksConnection connection)
             canDecompressLz4 = connection.CanDecompressLz4;
             maxBytesPerFile = connection.MaxBytesPerFile;
             enableMultipleCatalogSupport = connection.EnableMultipleCatalogSupport;
+            enablePKFK = connection.EnablePKFK;
         }
 
         protected override void SetStatementProperties(TExecuteStatementReq statement)
@@ -386,5 +387,119 @@ protected override async Task<QueryResult> GetColumnsAsync(CancellationToken can
             // Call the base implementation with the potentially modified catalog name
             return await base.GetColumnsAsync(cancellationToken);
         }
+
+        /// <summary>
+        /// Determines whether PK/FK metadata queries (GetPrimaryKeys/GetCrossReference) should return an empty result set without hitting the server.
+        ///
+        /// Why:
+        /// - For certain catalog names (null, empty, "SPARK", "hive_metastore"), Databricks does not support PK/FK metadata,
+        ///   or these are legacy/synthesized catalogs that should gracefully return empty results for compatibility.
+        /// - The EnablePKFK flag allows the client to globally disable PK/FK metadata queries for performance or compatibility reasons.
+        ///
+        /// What it does:
+        /// - Returns true if PK/FK queries should return an empty result (and not hit the server), based on:
+        ///   - The EnablePKFK flag (if false, always return empty)
+        ///   - The catalog name (SPARK, hive_metastore, null, or empty string)
+        /// - Returns false if the query should proceed to the server (for valid, supported catalogs).
+        /// </summary>
+        internal bool ShouldReturnEmptyPkFkResult()
+        {
+            if (!enablePKFK)
+                return true;
+
+            // Handle special catalog cases
+            if (string.IsNullOrEmpty(CatalogName) ||
+                string.Equals(CatalogName, "SPARK", StringComparison.OrdinalIgnoreCase) ||
+                string.Equals(CatalogName, "hive_metastore", StringComparison.OrdinalIgnoreCase))
+            {
+                return true;
+            }
+
+            return false;
+        }
+
+        protected override async Task<QueryResult> GetPrimaryKeysAsync(CancellationToken cancellationToken = default)
+        {
+            if (ShouldReturnEmptyPkFkResult())
+                return EmptyPrimaryKeysResult();
+
+            return await base.GetPrimaryKeysAsync(cancellationToken);
+        }
+
+        private QueryResult EmptyPrimaryKeysResult()
+        {
+            var fields = new[]
+            {
+                new Field("TABLE_CAT", StringType.Default, true),
+                new Field("TABLE_SCHEM", StringType.Default, true),
+                new Field("TABLE_NAME", StringType.Default, true),
+                new Field("COLUMN_NAME", StringType.Default, true),
+                new Field("KEY_SEQ", Int32Type.Default, true),
+                new Field("PK_NAME", StringType.Default, true)
+            };
+            var schema = new Schema(fields, null);
+
+            var arrays = new IArrowArray[]
+            {
+                new StringArray.Builder().Build(), // TABLE_CAT
+                new StringArray.Builder().Build(), // TABLE_SCHEM
+                new StringArray.Builder().Build(), // TABLE_NAME
+                new StringArray.Builder().Build(), // COLUMN_NAME
+                new Int32Array.Builder().Build(),  // KEY_SEQ
+                new StringArray.Builder().Build()  // PK_NAME
+            };
+
+            return new QueryResult(0, new HiveServer2Connection.HiveInfoArrowStream(schema, arrays));
+        }
+
+        protected override async Task<QueryResult> GetCrossReferenceAsync(CancellationToken cancellationToken = default)
+        {
+            if (ShouldReturnEmptyPkFkResult())
+                return EmptyCrossReferenceResult();
+
+            return await base.GetCrossReferenceAsync(cancellationToken);
+        }
+
+        private QueryResult EmptyCrossReferenceResult()
+        {
+            var fields = new[]
+            {
+                new Field("PKTABLE_CAT", StringType.Default, true),
+                new Field("PKTABLE_SCHEM", StringType.Default, true),
+                new Field("PKTABLE_NAME", StringType.Default, true),
+                new Field("PKCOLUMN_NAME", StringType.Default, true),
+                new Field("FKTABLE_CAT", StringType.Default, true),
+                new Field("FKTABLE_SCHEM", StringType.Default, true),
+                new Field("FKTABLE_NAME", StringType.Default, true),
+                new Field("FKCOLUMN_NAME", StringType.Default, true),
+                new Field("KEY_SEQ", Int16Type.Default, true),
+                new Field("UPDATE_RULE", Int16Type.Default, true),
+                new Field("DELETE_RULE", Int16Type.Default, true),
+                new Field("FK_NAME", StringType.Default, true),
+                new Field("PK_NAME", StringType.Default, true),
+                new Field("DEFERRABILITY", Int16Type.Default, true)
+            };
+            var schema = new Schema(fields, null);
+
+            var arrays = new IArrowArray[]
+            {
+                new StringArray.Builder().Build(), // PKTABLE_CAT
+                new StringArray.Builder().Build(), // PKTABLE_SCHEM
+                new StringArray.Builder().Build(), // PKTABLE_NAME
+                new StringArray.Builder().Build(), // PKCOLUMN_NAME
+                new StringArray.Builder().Build(), // FKTABLE_CAT
+                new StringArray.Builder().Build(), // FKTABLE_SCHEM
+                new StringArray.Builder().Build(), // FKTABLE_NAME
+                new StringArray.Builder().Build(), // FKCOLUMN_NAME
+                new Int16Array.Builder().Build(),  // KEY_SEQ
+                new Int16Array.Builder().Build(),  // UPDATE_RULE
+                new Int16Array.Builder().Build(),  // DELETE_RULE
+                new StringArray.Builder().Build(), // FK_NAME
+                new StringArray.Builder().Build(), // PK_NAME
+                new Int16Array.Builder().Build()   // DEFERRABILITY
+            };
+
+            return new QueryResult(0, new HiveServer2Connection.HiveInfoArrowStream(schema, arrays));
+        }
     }
 }
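For context, these overrides are reached through the driver's metadata-command path, the same way the tests below exercise them. A minimal fragment against a legacy catalog, assuming an already-open Databricks statement; the schema and table names here are placeholders:

    statement.SetOption(ApacheParameters.IsMetadataCommand, "true");
    statement.SetOption(ApacheParameters.CatalogName, "hive_metastore"); // legacy catalog -> short-circuit
    statement.SetOption(ApacheParameters.SchemaName, "default");         // placeholder schema
    statement.SetOption(ApacheParameters.TableName, "my_table");         // placeholder table
    statement.SqlQuery = "GetPrimaryKeys";
    QueryResult result = await statement.ExecuteQueryAsync();
    // ShouldReturnEmptyPkFkResult() is true here, so the driver returns the empty
    // six-column primary-key schema without issuing a Thrift GetPrimaryKeys call.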
diff --git a/csharp/test/Drivers/Databricks/StatementTests.cs b/csharp/test/Drivers/Databricks/StatementTests.cs
index ef1a84b00e..368a5e0d73 100644
--- a/csharp/test/Drivers/Databricks/StatementTests.cs
+++ b/csharp/test/Drivers/Databricks/StatementTests.cs
@@ -652,5 +652,103 @@ private void AssertField(Schema schema, int index, string expectedName, IArrowTy
             Assert.True(expectedType.Equals(field.DataType), $"Field {index} type mismatch");
             Assert.True(expectedNullable == field.IsNullable, $"Field {index} nullability mismatch");
         }
+
+        [Theory]
+        [InlineData(false, "main", true)]
+        [InlineData(true, null, true)]
+        [InlineData(true, "", true)]
+        [InlineData(true, "SPARK", true)]
+        [InlineData(true, "hive_metastore", true)]
+        [InlineData(true, "main", false)]
+        public void ShouldReturnEmptyPkFkResult_WorksAsExpected(bool enablePKFK, string? catalogName, bool expected)
+        {
+            // Arrange: create test configuration and connection
+            var testConfig = (DatabricksTestConfiguration)TestConfiguration.Clone();
+            var connectionParams = new Dictionary<string, string>
+            {
+                [DatabricksParameters.EnablePKFK] = enablePKFK.ToString().ToLowerInvariant()
+            };
+            using var connection = NewConnection(testConfig, connectionParams);
+            var statement = connection.CreateStatement();
+
+            // Set CatalogName using SetOption
+            if (catalogName != null)
+            {
+                statement.SetOption(ApacheParameters.CatalogName, catalogName);
+            }
+
+            // Act
+            var result = ((DatabricksStatement)statement).ShouldReturnEmptyPkFkResult();
+
+            // Assert
+            Assert.Equal(expected, result);
+        }
+
+        [SkippableFact]
+        public async Task PKFK_EmptyResult_SchemaMatches_RealMetadataResponse()
+        {
+            // Arrange: create test configuration and connection
+            var testConfig = (DatabricksTestConfiguration)TestConfiguration.Clone();
+            var connectionParams = new Dictionary<string, string>
+            {
+                [DatabricksParameters.EnablePKFK] = "true"
+            };
+            using var connection = NewConnection(testConfig, connectionParams);
+            var statement = connection.CreateStatement();
+
+            // Get real PK metadata schema
+            statement.SetOption(ApacheParameters.IsMetadataCommand, "true");
+            statement.SetOption(ApacheParameters.CatalogName, TestConfiguration.Metadata.Catalog);
+            statement.SetOption(ApacheParameters.SchemaName, TestConfiguration.Metadata.Schema);
+            statement.SetOption(ApacheParameters.TableName, TestConfiguration.Metadata.Table);
+            statement.SqlQuery = "GetPrimaryKeys";
+            var realPkResult = await statement.ExecuteQueryAsync();
+            Assert.NotNull(realPkResult.Stream);
+            var realPkSchema = realPkResult.Stream.Schema;
+
+            // Get empty PK result schema (using SPARK catalog which should return empty)
+            statement.SetOption(ApacheParameters.CatalogName, "SPARK");
+            var emptyPkResult = await statement.ExecuteQueryAsync();
+            Assert.NotNull(emptyPkResult.Stream);
+            var emptyPkSchema = emptyPkResult.Stream.Schema;
+
+            // Verify PK schemas match
+            Assert.Equal(realPkSchema.FieldsList.Count, emptyPkSchema.FieldsList.Count);
+            for (int i = 0; i < realPkSchema.FieldsList.Count; i++)
+            {
+                var realField = realPkSchema.FieldsList[i];
+                var emptyField = emptyPkSchema.FieldsList[i];
+                AssertField(emptyField, realField.Name, realField.DataType, realField.IsNullable);
+            }
+
+            // Get real FK metadata schema
+            statement.SetOption(ApacheParameters.CatalogName, TestConfiguration.Metadata.Catalog);
+            statement.SqlQuery = "GetCrossReference";
+            var realFkResult = await statement.ExecuteQueryAsync();
+            Assert.NotNull(realFkResult.Stream);
+            var realFkSchema = realFkResult.Stream.Schema;
+
+            // Get empty FK result schema
+            statement.SetOption(ApacheParameters.CatalogName, "SPARK");
+            var emptyFkResult = await statement.ExecuteQueryAsync();
+            Assert.NotNull(emptyFkResult.Stream);
+            var emptyFkSchema = emptyFkResult.Stream.Schema;
+
+            // Verify FK schemas match
+            Assert.Equal(realFkSchema.FieldsList.Count, emptyFkSchema.FieldsList.Count);
+            for (int i = 0; i < realFkSchema.FieldsList.Count; i++)
+            {
+                var realField = realFkSchema.FieldsList[i];
+                var emptyField = emptyFkSchema.FieldsList[i];
+                AssertField(emptyField, realField.Name, realField.DataType, realField.IsNullable);
+            }
+        }
+
+        private void AssertField(Field field, string expectedName, IArrowType expectedType, bool expectedNullable)
+        {
+            Assert.True(expectedName.Equals(field.Name), $"Field name mismatch: expected {expectedName}, got {field.Name}");
+            Assert.True(expectedType.Equals(field.DataType), $"Field type mismatch: expected {expectedType}, got {field.DataType}");
+            Assert.True(expectedNullable == field.IsNullable, $"Field nullability mismatch: expected {expectedNullable}, got {field.IsNullable}");
+        }
     }
 }
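The empty results still carry a full schema, so consumers that bind columns by name keep working; only the rows are absent. A fragment of what a caller can rely on, assuming a statement configured as in the test above with CatalogName set to "SPARK"; whether the stream yields no batch or a single zero-length batch is an assumption about HiveInfoArrowStream, hence both cases are handled:

    statement.SqlQuery = "GetCrossReference";
    QueryResult crossRef = await statement.ExecuteQueryAsync();
    Schema schema = crossRef.Stream!.Schema;   // 14 JDBC-style columns, PKTABLE_CAT .. DEFERRABILITY
    RecordBatch? batch = await crossRef.Stream.ReadNextRecordBatchAsync();
    // batch is either null or has Length == 0: no PK/FK rows are produced for SPARK/hive_metastore,
    // and no Thrift GetCrossReference request is sent.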