diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs index b87862f249..f101c22566 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs @@ -41,7 +41,7 @@ internal abstract class HiveServer2Connection : AdbcConnection internal const int PollTimeMillisecondsDefault = 500; private const int ConnectTimeoutMillisecondsDefault = 30000; private TTransport? _transport; - private TCLIService.Client? _client; + private TCLIService.IAsync? _client; private readonly Lazy _vendorVersion; private readonly Lazy _vendorName; @@ -287,7 +287,7 @@ internal HiveServer2Connection(IReadOnlyDictionary properties) } } - internal TCLIService.Client Client + internal TCLIService.IAsync Client { get { return _client ?? throw new InvalidOperationException("connection not open"); } } @@ -308,7 +308,7 @@ internal async Task OpenAsync() TTransport transport = CreateTransport(); TProtocol protocol = await CreateProtocolAsync(transport, cancellationToken); _transport = protocol.Transport; - _client = new TCLIService.Client(protocol); + _client = CreateTCLIServiceClient(protocol); TOpenSessionReq request = CreateSessionRequest(); TOpenSessionResp? session = await Client.OpenSession(request, cancellationToken); @@ -338,6 +338,11 @@ internal async Task OpenAsync() } } + protected virtual TCLIService.IAsync CreateTCLIServiceClient(TProtocol protocol) + { + return new TCLIService.Client(protocol); + } + internal TSessionHandle? SessionHandle { get; private set; } protected internal DataTypeConversion DataTypeConversion { get; set; } = DataTypeConversion.None; @@ -696,7 +701,10 @@ public override void Dispose() TCloseSessionReq r6 = new(SessionHandle); _client.CloseSession(r6, cancellationToken).Wait(); _transport?.Close(); - _client.Dispose(); + if (_client is IDisposable disposableClient) + { + disposableClient.Dispose(); + } _transport = null; _client = null; } diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs index 5f908815a7..a7bb35d98e 100644 --- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs +++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs @@ -312,7 +312,7 @@ protected internal int QueryTimeoutSeconds public TOperationHandle? OperationHandle { get; private set; } // Keep the original Client property for internal use - public TCLIService.Client Client => Connection.Client; + public TCLIService.IAsync Client => Connection.Client; private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0 ? pollTimeMilliseconds diff --git a/csharp/src/Drivers/Apache/Hive2/ThreadSafeClient.cs b/csharp/src/Drivers/Apache/Hive2/ThreadSafeClient.cs new file mode 100644 index 0000000000..1e0b06bbf6 --- /dev/null +++ b/csharp/src/Drivers/Apache/Hive2/ThreadSafeClient.cs @@ -0,0 +1,241 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Apache.Hive.Service.Rpc.Thrift; + +namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2.Client +{ + /// + /// A thread-safe wrapper for TCLIService.IAsync client that ensures all operations + /// are properly synchronized to prevent concurrent access issues. + /// + internal class ThreadSafeClient : IDisposable, TCLIService.IAsync + { + private readonly TCLIService.IAsync _client; + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + private bool _disposed = false; + + /// + /// Initializes a new instance of the class. + /// + /// The TCLIService client to wrap. + internal ThreadSafeClient(TCLIService.IAsync client) + { + _client = client ?? throw new ArgumentNullException(nameof(client)); + } + + /// + /// Executes a client operation in a thread-safe manner. + /// + /// The type of the result. + /// The operation to execute. + /// The cancellation token. + /// A task representing the asynchronous operation with the result. + private async Task ExecuteOperationAsync( + Func> operation, + CancellationToken cancellationToken) + { + ThrowIfDisposed(); + await _semaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + return await operation(_client, cancellationToken).ConfigureAwait(false); + } + finally + { + _semaphore.Release(); + } + } + + #region TCLIService.IAsync Implementation + + /// + public Task OpenSession(TOpenSessionReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.OpenSession(req, token), cancellationToken); + } + + /// + public Task CloseSession(TCloseSessionReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.CloseSession(req, token), cancellationToken); + } + + /// + public Task GetInfo(TGetInfoReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetInfo(req, token), cancellationToken); + } + + /// + public Task ExecuteStatement(TExecuteStatementReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.ExecuteStatement(req, token), cancellationToken); + } + + /// + public Task GetTypeInfo(TGetTypeInfoReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetTypeInfo(req, token), cancellationToken); + } + + /// + public Task GetCatalogs(TGetCatalogsReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetCatalogs(req, token), cancellationToken); + } + + /// + public Task GetSchemas(TGetSchemasReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetSchemas(req, token), cancellationToken); + } + + /// + public Task GetTables(TGetTablesReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetTables(req, token), cancellationToken); + } + + /// + public Task GetTableTypes(TGetTableTypesReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetTableTypes(req, token), cancellationToken); + } + + /// + public Task GetColumns(TGetColumnsReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetColumns(req, token), cancellationToken); + } + + /// + public Task GetFunctions(TGetFunctionsReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetFunctions(req, token), cancellationToken); + } + + /// + public Task GetPrimaryKeys(TGetPrimaryKeysReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetPrimaryKeys(req, token), cancellationToken); + } + + /// + public Task GetCrossReference(TGetCrossReferenceReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetCrossReference(req, token), cancellationToken); + } + + /// + public Task GetOperationStatus(TGetOperationStatusReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetOperationStatus(req, token), cancellationToken); + } + + /// + public Task CancelOperation(TCancelOperationReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.CancelOperation(req, token), cancellationToken); + } + + /// + public Task CloseOperation(TCloseOperationReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.CloseOperation(req, token), cancellationToken); + } + + /// + public Task GetResultSetMetadata(TGetResultSetMetadataReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetResultSetMetadata(req, token), cancellationToken); + } + + /// + public Task FetchResults(TFetchResultsReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.FetchResults(req, token), cancellationToken); + } + + /// + public Task GetDelegationToken(TGetDelegationTokenReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetDelegationToken(req, token), cancellationToken); + } + + /// + public Task CancelDelegationToken(TCancelDelegationTokenReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.CancelDelegationToken(req, token), cancellationToken); + } + + /// + public Task RenewDelegationToken(TRenewDelegationTokenReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.RenewDelegationToken(req, token), cancellationToken); + } + + /// + public Task GetQueryId(TGetQueryIdReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.GetQueryId(req, token), cancellationToken); + } + + /// + public Task SetClientInfo(TSetClientInfoReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.SetClientInfo(req, token), cancellationToken); + } + + /// + public Task UploadData(TUploadDataReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.UploadData(req, token), cancellationToken); + } + + /// + public Task DownloadData(TDownloadDataReq req, CancellationToken cancellationToken = default) + { + return ExecuteOperationAsync((client, token) => client.DownloadData(req, token), cancellationToken); + } + + #endregion + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(ThreadSafeClient)); + } + } + + /// + /// Disposes the resources used by the client. + /// + public void Dispose() + { + if (!_disposed) + { + _semaphore.Dispose(); + _disposed = true; + } + } + } +} diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs b/csharp/src/Drivers/Databricks/DatabricksConnection.cs index b7a26e2da7..66b55e1aab 100644 --- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs +++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs @@ -25,11 +25,13 @@ using System.Threading.Tasks; using Apache.Arrow.Adbc.Drivers.Apache; using Apache.Arrow.Adbc.Drivers.Apache.Hive2; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2.Client; using Apache.Arrow.Adbc.Drivers.Apache.Spark; using Apache.Arrow.Adbc.Drivers.Databricks.Auth; using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch; using Apache.Arrow.Ipc; using Apache.Hive.Service.Rpc.Thrift; +using Thrift.Protocol; namespace Apache.Arrow.Adbc.Drivers.Databricks { @@ -61,6 +63,11 @@ public DatabricksConnection(IReadOnlyDictionary properties) : ba ValidateProperties(); } + protected override TCLIService.IAsync CreateTCLIServiceClient(TProtocol protocol) + { + return new ThreadSafeClient(new TCLIService.Client(protocol)); + } + private void ValidateProperties() { if (Properties.TryGetValue(DatabricksParameters.ApplySSPWithQueries, out string? applySSPWithQueriesStr))