Skip to content

Commit 446e97a

Browse files
committed
DatabricksStatement thread-safe ownership
1 parent 4623b48 commit 446e97a

File tree

5 files changed

+27
-33
lines changed

5 files changed

+27
-33
lines changed

csharp/src/Drivers/Databricks/Client/ThreadSafeClient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,4 @@ public void Dispose()
110110
}
111111
}
112112
}
113-
}
113+
}

csharp/src/Drivers/Databricks/DatabricksConnection.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,6 @@ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGe
298298

299299
internal override SchemaParser SchemaParser => new DatabricksSchemaParser();
300300

301-
internal ThreadSafeClient ThreadSafeClient => new ThreadSafeClient(base.Client);
302-
303301
public override AdbcStatement CreateStatement()
304302
{
305303
DatabricksStatement statement = new DatabricksStatement(this);

csharp/src/Drivers/Databricks/DatabricksStatement.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
2929
/// <summary>
3030
/// Databricks-specific implementation of <see cref="AdbcStatement"/>
3131
/// </summary>
32-
internal class DatabricksStatement : SparkStatement, IHiveServer2Statement, IDisposable
32+
internal class DatabricksStatement : SparkStatement, IHiveServer2Statement
3333
{
3434
private bool useCloudFetch;
3535
private bool canDecompressLz4;
3636
private long maxBytesPerFile;
37-
private ThreadSafeClient _threadSafeClient { get; }
37+
private ThreadSafeClient _threadSafeClient;
3838

3939
public DatabricksStatement(DatabricksConnection connection)
4040
: base(connection)
@@ -43,7 +43,7 @@ public DatabricksStatement(DatabricksConnection connection)
4343
useCloudFetch = connection.UseCloudFetch;
4444
canDecompressLz4 = connection.CanDecompressLz4;
4545
maxBytesPerFile = connection.MaxBytesPerFile;
46-
_threadSafeClient = connection.ThreadSafeClient;
46+
_threadSafeClient = new ThreadSafeClient(Connection.Client);
4747
}
4848

4949
protected override void SetStatementProperties(TExecuteStatementReq statement)
@@ -155,11 +155,5 @@ internal void SetMaxBytesPerFile(long maxBytesPerFile)
155155
{
156156
this.maxBytesPerFile = maxBytesPerFile;
157157
}
158-
159-
public override void Dispose()
160-
{
161-
_threadSafeClient?.Dispose();
162-
base.Dispose();
163-
}
164158
}
165159
}

csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using System.Collections.Generic;
2121
using System.Threading;
2222
using System.Threading.Tasks;
23+
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.Client;
2324
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
2425
using Apache.Arrow.Adbc.Drivers.Databricks;
2526
using Apache.Hive.Service.Rpc.Thrift;
@@ -46,13 +47,13 @@ public CloudFetchResultFetcherTest()
4647
public async Task StartAsync_CalledTwice_ThrowsException()
4748
{
4849
// Arrange
49-
var mockClient = new Mock<TCLIService.Client>();
50-
mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
50+
var mockClient = new Mock<ThreadSafeClient>();
51+
mockClient.Setup(c => c.FetchResultsAsync(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
5152
.ReturnsAsync(CreateFetchResultsResponse(new List<TSparkArrowResultLink>(), false));
5253

5354
var mockStatement = new Mock<DatabricksStatement>();
5455
mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle());
55-
mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
56+
mockStatement.Setup(s => s.ThreadSafeClient).Returns(mockClient.Object);
5657

5758
var fetcher = new CloudFetchResultFetcher(
5859
mockStatement.Object,
@@ -79,13 +80,13 @@ public async Task FetchResultsAsync_SuccessfullyFetchesResults()
7980
CreateTestResultLink(200, 100, "http://test.com/file3")
8081
};
8182

82-
var mockClient = new Mock<TCLIService.Client>();
83-
mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
83+
var mockClient = new Mock<ThreadSafeClient>();
84+
mockClient.Setup(c => c.FetchResultsAsync(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
8485
.ReturnsAsync(CreateFetchResultsResponse(resultLinks, false));
8586

8687
var mockStatement = new Mock<DatabricksStatement>();
8788
mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle());
88-
mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
89+
mockStatement.Setup(s => s.ThreadSafeClient).Returns(mockClient.Object);
8990

9091
var fetcher = new CloudFetchResultFetcher(
9192
mockStatement.Object,
@@ -153,14 +154,14 @@ public async Task FetchResultsAsync_WithMultipleBatches_FetchesAllResults()
153154
CreateTestResultLink(300, 100, "http://test.com/file4")
154155
};
155156

156-
var mockClient = new Mock<TCLIService.Client>();
157-
mockClient.SetupSequence(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
157+
var mockClient = new Mock<ThreadSafeClient>();
158+
mockClient.SetupSequence(c => c.FetchResultsAsync(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
158159
.ReturnsAsync(CreateFetchResultsResponse(firstBatchLinks, true))
159160
.ReturnsAsync(CreateFetchResultsResponse(secondBatchLinks, false));
160161

161162
var mockStatement = new Mock<DatabricksStatement>();
162163
mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle());
163-
mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
164+
mockStatement.Setup(s => s.ThreadSafeClient).Returns(mockClient.Object);
164165

165166
var fetcher = new CloudFetchResultFetcher(
166167
mockStatement.Object,
@@ -207,13 +208,13 @@ public async Task FetchResultsAsync_WithMultipleBatches_FetchesAllResults()
207208
public async Task FetchResultsAsync_WithEmptyResults_CompletesGracefully()
208209
{
209210
// Arrange
210-
var mockClient = new Mock<TCLIService.Client>();
211-
mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
211+
var mockClient = new Mock<ThreadSafeClient>();
212+
mockClient.Setup(c => c.FetchResultsAsync(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
212213
.ReturnsAsync(CreateFetchResultsResponse(new List<TSparkArrowResultLink>(), false));
213214

214215
var mockStatement = new Mock<DatabricksStatement>();
215216
mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle());
216-
mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
217+
mockStatement.Setup(s => s.ThreadSafeClient).Returns(mockClient.Object);
217218

218219
var fetcher = new CloudFetchResultFetcher(
219220
mockStatement.Object,
@@ -296,8 +297,8 @@ public async Task StopAsync_CancelsFetching()
296297
var fetchStarted = new TaskCompletionSource<bool>();
297298
var fetchCancelled = new TaskCompletionSource<bool>();
298299

299-
var mockClient = new Mock<TCLIService.Client>();
300-
mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
300+
var mockClient = new Mock<ThreadSafeClient>();
301+
mockClient.Setup(c => c.FetchResultsAsync(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
301302
.Returns(async (TFetchResultsReq req, CancellationToken token) =>
302303
{
303304
fetchStarted.TrySetResult(true);
@@ -319,7 +320,7 @@ public async Task StopAsync_CancelsFetching()
319320

320321
var mockStatement = new Mock<DatabricksStatement>();
321322
mockStatement.Setup(s => s.OperationHandle).Returns(CreateOperationHandle());
322-
mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
323+
mockStatement.Setup(s => s.ThreadSafeClient).Returns(mockClient.Object);
323324

324325
var fetcher = new CloudFetchResultFetcher(
325326
mockStatement.Object,

csharp/test/Drivers/Databricks/DatabricksOperationStatusPollerTests.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System;
1919
using System.Threading;
2020
using System.Threading.Tasks;
21+
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.Client;
2122
using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
2223
using Apache.Arrow.Adbc.Drivers.Databricks;
2324
using Apache.Hive.Service.Rpc.Thrift;
@@ -31,23 +32,23 @@ public class DatabricksOperationStatusPollerTests
3132
{
3233
private readonly ITestOutputHelper _outputHelper;
3334
private readonly Mock<DatabricksStatement> _mockStatement;
34-
private readonly Mock<TCLIService.Client> _mockClient;
35+
private readonly Mock<ThreadSafeClient> _mockClient;
3536
private readonly TOperationHandle _operationHandle;
3637

3738
private readonly int _heartbeatIntervalSeconds = 1000;
3839

3940
public DatabricksOperationStatusPollerTests(ITestOutputHelper outputHelper)
4041
{
4142
_outputHelper = outputHelper;
42-
_mockClient = new Mock<TCLIService.Client>();
43+
_mockClient = new Mock<ThreadSafeClient>();
4344
_mockStatement = new Mock<DatabricksStatement>();
4445
_operationHandle = new TOperationHandle
4546
{
4647
OperationId = new THandleIdentifier { Guid = new byte[] { 1, 2, 3, 4 } },
4748
OperationType = TOperationType.EXECUTE_STATEMENT
4849
};
4950

50-
_mockStatement.Setup(s => s.Client).Returns(_mockClient.Object);
51+
_mockStatement.Setup(s => s.ThreadSafeClient).Returns(_mockClient.Object);
5152
_mockStatement.Setup(s => s.OperationHandle).Returns(_operationHandle);
5253
}
5354

@@ -57,7 +58,7 @@ public async Task StartPollsOperationStatusAtInterval()
5758
// Arrange
5859
var poller = new DatabricksOperationStatusPoller(_mockStatement.Object, _heartbeatIntervalSeconds);
5960
var pollCount = 0;
60-
_mockClient.Setup(c => c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()))
61+
_mockClient.Setup(c => c.GetOperationStatusAsync(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()))
6162
.ReturnsAsync(new TGetOperationStatusResp())
6263
.Callback(() => pollCount++);
6364

@@ -67,7 +68,7 @@ public async Task StartPollsOperationStatusAtInterval()
6768

6869
// Assert
6970
Assert.True(pollCount > 0, "Should have polled at least once");
70-
_mockClient.Verify(c => c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()), Times.AtLeastOnce);
71+
_mockClient.Verify(c => c.GetOperationStatusAsync(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()), Times.AtLeastOnce);
7172
}
7273

7374
[Fact]
@@ -76,7 +77,7 @@ public async Task DisposeStopsPolling()
7677
// Arrange
7778
var poller = new DatabricksOperationStatusPoller(_mockStatement.Object, _heartbeatIntervalSeconds);
7879
var pollCount = 0;
79-
_mockClient.Setup(c => c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()))
80+
_mockClient.Setup(c => c.GetOperationStatusAsync(It.IsAny<TGetOperationStatusReq>(), It.IsAny<CancellationToken>()))
8081
.ReturnsAsync(new TGetOperationStatusResp())
8182
.Callback(() => pollCount++);
8283

0 commit comments

Comments
 (0)