Skip to content

Commit fe1fb11

Browse files
committed
SemaphoreSlim lock to prevent simultaneous fetchResults + poll status
1 parent e08f484 commit fe1fb11

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
using System.Diagnostics;
2222
using System.Threading;
2323
using System.Threading.Tasks;
24+
using Apache.Arrow.Adbc.Drivers.Databricks;
2425
using Apache.Hive.Service.Rpc.Thrift;
2526

2627
namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
@@ -30,7 +31,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
3031
/// </summary>
3132
internal sealed class CloudFetchResultFetcher : ICloudFetchResultFetcher
3233
{
33-
private readonly IHiveServer2Statement _statement;
34+
private readonly DatabricksStatement _statement;
3435
private readonly ICloudFetchMemoryBufferManager _memoryManager;
3536
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
3637
private long _startOffset;
@@ -49,7 +50,7 @@ internal sealed class CloudFetchResultFetcher : ICloudFetchResultFetcher
4950
/// <param name="downloadQueue">The queue to add download tasks to.</param>
5051
/// <param name="prefetchCount">The number of result chunks to prefetch.</param>
5152
public CloudFetchResultFetcher(
52-
IHiveServer2Statement statement,
53+
DatabricksStatement statement,
5354
ICloudFetchMemoryBufferManager memoryManager,
5455
BlockingCollection<IDownloadResult> downloadQueue,
5556
long batchSize)
@@ -197,7 +198,8 @@ private async Task FetchNextResultBatchAsync(CancellationToken cancellationToken
197198
TFetchResultsResp response;
198199
try
199200
{
200-
response = await _statement.Client.FetchResults(request, cancellationToken).ConfigureAwait(false);
201+
// Use thread-safe method to fetch results
202+
response = await _statement.FetchResultsAsync(request, cancellationToken).ConfigureAwait(false);
201203
}
202204
catch (Exception ex)
203205
{

csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
2929
/// </summary>
3030
internal class DatabricksOperationStatusPoller : IDisposable
3131
{
32-
private readonly IHiveServer2Statement _statement;
32+
private readonly DatabricksStatement _statement;
3333
private readonly int _heartbeatIntervalSeconds;
3434
// internal cancellation token source - won't affect the external token
3535
private CancellationTokenSource? _internalCts;
3636
private Task? _operationStatusPollingTask;
3737

38-
public DatabricksOperationStatusPoller(IHiveServer2Statement statement, int heartbeatIntervalSeconds = DatabricksConstants.DefaultOperationStatusPollingIntervalSeconds)
38+
public DatabricksOperationStatusPoller(DatabricksStatement statement, int heartbeatIntervalSeconds = DatabricksConstants.DefaultOperationStatusPollingIntervalSeconds)
3939
{
4040
_statement = statement ?? throw new ArgumentNullException(nameof(statement));
4141
_heartbeatIntervalSeconds = heartbeatIntervalSeconds;
@@ -70,7 +70,8 @@ private async Task PollOperationStatus(CancellationToken cancellationToken)
7070
if (operationHandle == null) break;
7171

7272
var request = new TGetOperationStatusReq(operationHandle);
73-
var response = await _statement.Client.GetOperationStatus(request, cancellationToken);
73+
74+
var response = await _statement.GetOperationStatusAsync(request, cancellationToken);
7475
await Task.Delay(TimeSpan.FromSeconds(_heartbeatIntervalSeconds), cancellationToken);
7576

7677
// end the heartbeat if the command has terminated

csharp/src/Drivers/Databricks/DatabricksReader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public DatabricksReader(DatabricksStatement statement, Schema schema, bool isLz4
7878
}
7979

8080
TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
81-
TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
81+
TFetchResultsResp response = await this.statement.FetchResultsAsync(request, cancellationToken);
8282

8383
// Make sure we get the arrowBatches
8484
this.batches = response.Results.ArrowBatches;

csharp/src/Drivers/Databricks/DatabricksStatement.cs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717

1818
using System;
19+
using System.Threading;
20+
using System.Threading.Tasks;
1921
using Apache.Arrow.Adbc.Drivers.Apache;
2022
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
2123
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
@@ -26,11 +28,14 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
2628
/// <summary>
2729
/// Databricks-specific implementation of <see cref="AdbcStatement"/>
2830
/// </summary>
29-
internal class DatabricksStatement : SparkStatement, IHiveServer2Statement
31+
internal class DatabricksStatement : SparkStatement, IHiveServer2Statement, IDisposable
3032
{
3133
private bool useCloudFetch;
3234
private bool canDecompressLz4;
3335
private long maxBytesPerFile;
36+
37+
// Semaphore lock to ensure that polling and fetching results do not both use transport at the same time
38+
private readonly SemaphoreSlim _clientSemaphore = new SemaphoreSlim(1, 1);
3439

3540
public DatabricksStatement(DatabricksConnection connection)
3641
: base(connection)
@@ -70,6 +75,58 @@ public TSparkDirectResults? DirectResults
7075
// Cast the Client to IAsync for CloudFetch compatibility
7176
TCLIService.IAsync IHiveServer2Statement.Client => Connection.Client;
7277

78+
/// <summary>
/// Runs <paramref name="operation"/> against the connection's client while holding
/// this statement's client semaphore, so that calls routed through this method
/// do not overlap on the same statement.
/// </summary>
/// <typeparam name="TResult">The type produced by the operation.</typeparam>
/// <param name="operation">Delegate that receives the client and the cancellation token and performs the call.</param>
/// <param name="cancellationToken">Token observed while waiting for the semaphore and passed on to <paramref name="operation"/>.</param>
/// <returns>A task that completes with the value produced by <paramref name="operation"/>.</returns>
public async Task<TResult> ExecuteClientOperationAsync<TResult>(
    Func<TCLIService.IAsync, CancellationToken, Task<TResult>> operation,
    CancellationToken cancellationToken)
{
    // Acquire outside the try block: if the wait itself fails or is cancelled,
    // the semaphore was never taken and must not be released.
    await _clientSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);

    try
    {
        Task<TResult> pending = operation(Connection.Client, cancellationToken);
        TResult result = await pending.ConfigureAwait(false);
        return result;
    }
    finally
    {
        // Always hand the semaphore back, whether the operation succeeded or threw.
        _clientSemaphore.Release();
    }
}
99+
100+
/// <summary>
/// Issues a GetOperationStatus call on the client, routed through
/// <see cref="ExecuteClientOperationAsync{TResult}"/> so it is serialized with
/// other client calls made through that method.
/// </summary>
/// <param name="request">The operation status request to send.</param>
/// <param name="cancellationToken">Token forwarded to the semaphore wait and to the client call.</param>
/// <returns>A task that completes with the operation status response.</returns>
public Task<TGetOperationStatusResp> GetOperationStatusAsync(
    TGetOperationStatusReq request,
    CancellationToken cancellationToken)
    => ExecuteClientOperationAsync(
        (thriftClient, ct) => thriftClient.GetOperationStatus(request, ct),
        cancellationToken);
114+
115+
/// <summary>
/// Issues a FetchResults call on the client, routed through
/// <see cref="ExecuteClientOperationAsync{TResult}"/> so it is serialized with
/// other client calls made through that method.
/// </summary>
/// <param name="request">The fetch results request to send.</param>
/// <param name="cancellationToken">Token forwarded to the semaphore wait and to the client call.</param>
/// <returns>A task that completes with the fetch results response.</returns>
public Task<TFetchResultsResp> FetchResultsAsync(
    TFetchResultsReq request,
    CancellationToken cancellationToken)
{
    // Name the delegate explicitly so the call below reads as "run this fetch under the lock".
    Func<TCLIService.IAsync, CancellationToken, Task<TFetchResultsResp>> fetch =
        (thriftClient, ct) => thriftClient.FetchResults(request, ct);

    return ExecuteClientOperationAsync(fetch, cancellationToken);
}
129+
73130
public override void SetOption(string key, string value)
74131
{
75132
switch (key)
@@ -151,5 +208,18 @@ internal void SetMaxBytesPerFile(long maxBytesPerFile)
151208
{
152209
this.maxBytesPerFile = maxBytesPerFile;
153210
}
211+
212+
/// <summary>
/// Releases the statement's client semaphore when disposing explicitly, then
/// delegates to the base class.
/// </summary>
/// <param name="disposing">
/// <c>true</c> when called from an explicit dispose; the semaphore is only
/// disposed in that case.
/// </param>
protected override void Dispose(bool disposing)
{
    if (!disposing)
    {
        // Nothing of our own to release on this path; let the base class handle it.
        base.Dispose(false);
        return;
    }

    // NOTE(review): SemaphoreSlim.Release/WaitAsync throw ObjectDisposedException once the
    // semaphore is disposed - confirm no poll/fetch call can still be in flight here.
    _clientSemaphore.Dispose();
    base.Dispose(true);
}
154224
}
155225
}

0 commit comments

Comments
 (0)