40 changes: 2 additions & 38 deletions csharp/src/Drivers/Databricks/BaseDatabricksReader.cs
@@ -16,20 +16,18 @@
*/

using System;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Tracing;

namespace Apache.Arrow.Adbc.Drivers.Databricks
{
/// <summary>
/// Base class for Databricks readers that handles common functionality. Handles the operation status poller.
/// Base class for Databricks readers that handles common functionality of DatabricksReader and CloudFetchReader
/// </summary>
internal abstract class BaseDatabricksReader : TracingReader
{
protected DatabricksStatement statement;
protected readonly Schema schema;
protected readonly bool isLz4Compressed;
protected DatabricksOperationStatusPoller? operationStatusPoller;
protected bool hasNoMoreRows = false;
private bool isDisposed;

@@ -39,48 +37,14 @@ protected BaseDatabricksReader(DatabricksStatement statement, Schema schema, boo
this.schema = schema;
this.isLz4Compressed = isLz4Compressed;
this.statement = statement;
if (statement.DirectResults?.ResultSet != null && !statement.DirectResults.ResultSet.HasMoreRows)
{
return;
}
operationStatusPoller = new DatabricksOperationStatusPoller(statement);
operationStatusPoller.Start();
}

public override Schema Schema { get { return schema; } }

protected void StopOperationStatusPoller()
{
operationStatusPoller?.Stop();
}

protected override void Dispose(bool disposing)
{
if (!isDisposed)
{
if (disposing)
{
DisposeOperationStatusPoller();
DisposeResources();
}
isDisposed = true;
}

base.Dispose(disposing);
}

protected virtual void DisposeResources()
{
}

protected void DisposeOperationStatusPoller()
{
if (operationStatusPoller != null)
{
operationStatusPoller.Stop();
operationStatusPoller.Dispose();
operationStatusPoller = null;
}
isDisposed = true;
}

protected void ThrowIfDisposed()
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
@@ -155,14 +155,13 @@ public CloudFetchReader(DatabricksStatement statement, Schema schema, TFetchResu
}
}

StopOperationStatusPoller();
// If we get here, there are no more files
return null;
}
});
}

protected override void DisposeResources()
protected override void Dispose(bool disposing)
{
if (this.currentReader != null)
{
@@ -181,6 +180,7 @@ protected override void DisposeResources()
this.downloadManager.Dispose();
this.downloadManager = null;
}
base.Dispose(disposing);
}
}
}
csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
@@ -22,6 +22,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Hive.Service.Rpc.Thrift;

namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
@@ -160,8 +161,11 @@ public async Task StopAsync()

request.StartRowOffset = offset;

// Cancelling mid-request breaks the client; Dispose() should not break the underlying client
CancellationToken expiringToken = ApacheUtility.GetCancellationToken(_statement.QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds);

// Fetch results
TFetchResultsResp response = await _statement.Client.FetchResults(request, cancellationToken);
TFetchResultsResp response = await _statement.Client.FetchResults(request, expiringToken);

// Process the results
if (response.Status.StatusCode == TStatusCode.SUCCESS_STATUS &&
@@ -257,7 +261,7 @@ private async Task FetchResultsAsync(CancellationToken cancellationToken)
// Add the end of results guard to the queue even in case of error
try
{
_downloadQueue.Add(EndOfResultsGuard.Instance, CancellationToken.None);
_downloadQueue.TryAdd(EndOfResultsGuard.Instance, 0);
Contributor: Why not use a cancellation token here?

toddmeng-db (author): I don't want Dispose to stall; what do you think? If the caller disposes the statement while the download queue is full, we would have to wait for however long the cancellation token is configured to take to expire.
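To make the trade-off concrete: on a bounded BlockingCollection, Add blocks until space frees up, while TryAdd with a zero-millisecond timeout returns immediately. A minimal standalone sketch (not the driver's actual queue wiring; names are illustrative):

```csharp
using System;
using System.Collections.Concurrent;

class QueueShutdownDemo
{
    static void Main()
    {
        // Bounded queue with capacity 1, then fill it so it is full.
        var queue = new BlockingCollection<string>(boundedCapacity: 1);
        queue.Add("pending-download");

        // queue.Add(item, CancellationToken.None) would now block until a
        // consumer drains the queue, which never happens during disposal.

        // TryAdd with a 0 ms timeout returns immediately instead.
        bool added = queue.TryAdd("end-of-results-guard", 0);
        Console.WriteLine($"Guard enqueued: {added}"); // false: queue full, but no stall
    }
}
```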

}
catch (Exception)
{
@@ -282,7 +286,9 @@ private async Task FetchNextResultBatchAsync(long? offset, CancellationToken can
TFetchResultsResp response;
try
{
response = await _statement.Client.FetchResults(request, cancellationToken).ConfigureAwait(false);
// Use the statement's configured query timeout
CancellationToken expiringToken = ApacheUtility.GetCancellationToken(_statement.QueryTimeoutSeconds, ApacheUtility.TimeUnit.Seconds);
response = await _statement.Client.FetchResults(request, expiringToken).ConfigureAwait(false);
}
catch (Exception ex)
{
csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
@@ -44,5 +44,10 @@ internal interface IHiveServer2Statement
/// </summary>
/// <returns>True if direct results are available and contain result data, false otherwise.</returns>
bool HasDirectResults { get; }

/// <summary>
/// Gets the query timeout in seconds.
/// </summary>
int QueryTimeoutSeconds { get; }
}
}
45 changes: 38 additions & 7 deletions csharp/src/Drivers/Databricks/DatabricksCompositeReader.cs
@@ -16,17 +16,12 @@
*/

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow;
using Apache.Arrow.Adbc;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Arrow.Adbc.Tracing;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;

namespace Apache.Arrow.Adbc.Drivers.Databricks
@@ -50,6 +45,8 @@ internal sealed class DatabricksCompositeReader : TracingReader
private readonly TlsProperties _tlsOptions;
private readonly HiveServer2ProxyConfigurator _proxyConfigurator;

private DatabricksOperationStatusPoller? operationStatusPoller;

/// <summary>
/// Initializes a new instance of the <see cref="DatabricksCompositeReader"/> class.
/// </summary>
@@ -66,10 +63,15 @@ internal DatabricksCompositeReader(DatabricksStatement statement, Schema schema,
_proxyConfigurator = proxyConfigurator;

// use direct results if available
if (_statement.HasDirectResults && _statement.DirectResults != null && _statement.DirectResults.__isset.resultSet)
if (_statement.HasDirectResults && _statement.DirectResults != null && _statement.DirectResults.__isset.resultSet && statement.DirectResults?.ResultSet != null)
Contributor: Can this be simplified to `if (_statement.HasDirectResults)`? It looks like that member is performing the same checks.

toddmeng-db (author): I think it is a bit helpful for the linter.
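For context on the linter point: with nullable reference types enabled, C#'s flow analysis does not look inside an ordinary bool property, so the redundant-looking inline null check is what suppresses a possible-null-dereference warning. A minimal sketch with hypothetical stand-in types (not the driver's real classes):

```csharp
#nullable enable
using System.Diagnostics.CodeAnalysis;

class DirectResultsInfo { public object? ResultSet; }

class Statement
{
    public DirectResultsInfo? DirectResults;

    // Flow analysis cannot see through an ordinary bool property, so
    // `if (s.HasDirectResults)` does not narrow `s.DirectResults`.
    public bool HasDirectResults => DirectResults != null;

    // An attribute can teach the compiler the postcondition instead.
    [MemberNotNullWhen(true, nameof(DirectResults))]
    public bool HasDirectResultsAnnotated
    {
        get
        {
            if (DirectResults == null) return false;
            return true;
        }
    }
}

class Demo
{
    static void Use(Statement s)
    {
        if (s.HasDirectResults)
        {
            // object? r = s.DirectResults.ResultSet; // warning CS8602 here
        }
        if (s.HasDirectResults && s.DirectResults != null)
        {
            object? r = s.DirectResults.ResultSet;    // OK: inline check narrows
        }
        if (s.HasDirectResultsAnnotated)
        {
            object? r = s.DirectResults.ResultSet;    // OK: attribute narrows
        }
    }
}
```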

{
_activeReader = DetermineReader(_statement.DirectResults.ResultSet);
}
if (_statement.DirectResults?.ResultSet.HasMoreRows ?? true)
{
operationStatusPoller = new DatabricksOperationStatusPoller(statement);
operationStatusPoller.Start();
}
}

private BaseDatabricksReader DetermineReader(TFetchResultsResp initialResults)
@@ -93,7 +95,7 @@ private BaseDatabricksReader DetermineReader(TFetchResultsResp initialResults)
/// </summary>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The next record batch, or null if there are no more batches.</returns>
public override async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
private async ValueTask<RecordBatch?> ReadNextRecordBatchInternalAsync(CancellationToken cancellationToken = default)
{
// Initialize the active reader if not already done
if (_activeReader == null)
@@ -108,5 +110,34 @@ private BaseDatabricksReader DetermineReader(TFetchResultsResp initialResults)

return await _activeReader.ReadNextRecordBatchAsync(cancellationToken);
}

public override async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
var result = await ReadNextRecordBatchInternalAsync(cancellationToken);
// Stop the poller when we've reached the end of results
if (result == null)
{
StopOperationStatusPoller();
}
return result;
}

protected override void Dispose(bool disposing)
{
if (disposing)
{
_activeReader?.Dispose();
Contributor: Set to null after dispose?
StopOperationStatusPoller();
}
_activeReader = null;
base.Dispose(disposing);
}

private void StopOperationStatusPoller()
{
operationStatusPoller?.Stop();
Contributor: Consider setting it to null here instead of in the DisposeOperationStatusPoller method, to avoid duplicate calls.
operationStatusPoller?.Dispose();
operationStatusPoller = null;
}
}
}
57 changes: 34 additions & 23 deletions csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs
@@ -18,6 +18,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;

@@ -31,14 +32,19 @@ internal class DatabricksOperationStatusPoller : IDisposable
{
private readonly IHiveServer2Statement _statement;
private readonly int _heartbeatIntervalSeconds;
private readonly int _requestTimeoutSeconds;
// internal cancellation token source - won't affect the external token
private CancellationTokenSource? _internalCts;
private Task? _operationStatusPollingTask;

public DatabricksOperationStatusPoller(IHiveServer2Statement statement, int heartbeatIntervalSeconds = DatabricksConstants.DefaultOperationStatusPollingIntervalSeconds)
public DatabricksOperationStatusPoller(
IHiveServer2Statement statement,
int heartbeatIntervalSeconds = DatabricksConstants.DefaultOperationStatusPollingIntervalSeconds,
int requestTimeoutSeconds = DatabricksConstants.DefaultOperationStatusRequestTimeoutSeconds)
{
_statement = statement ?? throw new ArgumentNullException(nameof(statement));
_heartbeatIntervalSeconds = heartbeatIntervalSeconds;
_requestTimeoutSeconds = requestTimeoutSeconds;
}

public bool IsStarted => _operationStatusPollingTask != null;
@@ -62,29 +68,27 @@ public void Start(CancellationToken externalToken = default)

private async Task PollOperationStatus(CancellationToken cancellationToken)
{
try
while (!cancellationToken.IsCancellationRequested)
{
while (!cancellationToken.IsCancellationRequested)
{
var operationHandle = _statement.OperationHandle;
if (operationHandle == null) break;
var operationHandle = _statement.OperationHandle;
if (operationHandle == null) break;

toddmeng-db (author), Jul 30, 2025: We need to use a timeout token here, instead of cancelling when the external cancellation token is triggered; if an interrupt is triggered prematurely, the TCLI client may still have unsent/unconsumed results in its buffers, affecting subsequent calls with that client (which is any future call in the same session).

Contributor: Are you able to repro this? Should we do this for all the Thrift RPC calls in the driver?

toddmeng-db (author), Aug 1, 2025: I think it is because in THttpTransport (used by SparkHttpConnection -> DatabricksHttpConnection), a new stream is created when the request is flushed. If cancellation happens before this, that stream doesn't get discarded:
https://github.com/apache/thrift/blob/master/lib/netstd/Thrift/Transport/Client/THttpTransport.cs#L281

Yes, I got some errors during testing. In the proxy logs, I remember seeing requests sent out with both GetOperationStatus and CloseOperationStatus (in the same request) while testing another PR.

I think we are safe in HiveServer2Statement, but we might need to adjust the CancellationToken in DatabricksReader, CloudFetchResultFetcher, and DatabricksCompositeReader.

toddmeng-db (author): Actually, I think this also depends on how the CancellationToken could be used by PBI. @CurtHagenlocher, will mashup ever trigger the cancellation tokens passed into IArrowStreamReader.ReadNextRecordBatchAsync? Do we need to ensure that the connection remains usable for subsequent statements?

toddmeng-db (author), Aug 2, 2025: At least for now, I think we can operate this way:

  1. If the user cancels the token passed into ReadNextRecordBatchAsync, we should not break the client.
  2. Dispose() should not break the client either.

Contributor: This is currently unimplemented, but we'll need to implement it before GA for parity with the ODBC implementation. What is probably most important for cancellation is query execution, and unless we manage to push forward the proposed ADBC 1.1 API, currently the only way to cancel a running query is to call AdbcStatement.Cancel. There is currently no implementation of this method for any of the C#-implemented drivers :(.

Contributor: From a Power BI perspective, the most important use of cancellation is Direct Query, because users can generate a lot of queries simply by clicking around in a visual, and in-progress queries will need to be cancelled if their output is no longer needed. DQ output tends to be relatively small, so being able to cancel in the middle of reading the output is arguably less important than being able to cancel before the results start coming back.
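To make the agreed behavior concrete: each RPC gets a token that expires on its own timeout and is deliberately not linked to the caller's token, so neither user cancellation nor Dispose() can abort a Thrift call mid-request. A sketch of what a helper like ApacheUtility.GetCancellationToken is assumed to do (the real helper takes a value plus a time unit; this simplified version is illustrative only):

```csharp
using System;
using System.Threading;

static class TimeoutTokens
{
    // Returns a token that cancels itself after `seconds`, independent of any
    // caller-supplied token. A production helper would also track and dispose
    // the CancellationTokenSource once the RPC completes.
    public static CancellationToken After(int seconds)
    {
        var cts = new CancellationTokenSource(TimeSpan.FromSeconds(seconds));
        return cts.Token;
    }
}

// Usage, mirroring the poller below:
// var token = TimeoutTokens.After(requestTimeoutSeconds);
// var response = await client.GetOperationStatus(request, token);
```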

CancellationToken GetOperationStatusTimeoutToken = ApacheUtility.GetCancellationToken(_requestTimeoutSeconds, ApacheUtility.TimeUnit.Seconds);

var request = new TGetOperationStatusReq(operationHandle);
var response = await _statement.Client.GetOperationStatus(request, cancellationToken);
await Task.Delay(TimeSpan.FromSeconds(_heartbeatIntervalSeconds), cancellationToken);
var request = new TGetOperationStatusReq(operationHandle);
var response = await _statement.Client.GetOperationStatus(request, GetOperationStatusTimeoutToken);
await Task.Delay(TimeSpan.FromSeconds(_heartbeatIntervalSeconds), cancellationToken);

// end the heartbeat if the command has terminated
if (response.OperationState == TOperationState.CANCELED_STATE ||
response.OperationState == TOperationState.ERROR_STATE)
{
break;
}
// end the heartbeat if the command has terminated
if (response.OperationState == TOperationState.CANCELED_STATE ||
response.OperationState == TOperationState.ERROR_STATE ||
response.OperationState == TOperationState.CLOSED_STATE ||
response.OperationState == TOperationState.TIMEDOUT_STATE ||
response.OperationState == TOperationState.UKNOWN_STATE)
{
break;
}
}
catch (TaskCanceledException)
{
// ignore
}
}

public void Stop()
@@ -94,12 +98,19 @@ public void Stop()

public void Dispose()
{
if (_internalCts != null)
_internalCts?.Cancel();
try
{
_internalCts.Cancel();
_operationStatusPollingTask?.Wait();
_internalCts.Dispose();
_operationStatusPollingTask?.GetAwaiter().GetResult();
}
catch (OperationCanceledException)
{
// Expected, no-op
}

_internalCts?.Dispose();
_internalCts = null;
_operationStatusPollingTask = null;
}
}
}
7 changes: 6 additions & 1 deletion csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -226,10 +226,15 @@ public class DatabricksParameters : SparkParameters
public class DatabricksConstants
{
/// <summary>
/// Default heartbeat interval in seconds for long-running operations
/// Default heartbeat interval in seconds for long-running operations. TODO: make this user-configurable
/// </summary>
public const int DefaultOperationStatusPollingIntervalSeconds = 60;

/// <summary>
/// Default timeout in seconds for operation status polling requests. TODO: make this user-configurable
/// </summary>
public const int DefaultOperationStatusRequestTimeoutSeconds = 30;

/// <summary>
/// OAuth grant type constants
/// </summary>
3 changes: 1 addition & 2 deletions csharp/src/Drivers/Databricks/DatabricksReader.cs
@@ -77,10 +77,9 @@ public DatabricksReader(DatabricksStatement statement, Schema schema, TFetchResu

if (this.hasNoMoreRows)
{
StopOperationStatusPoller();
return null;
}

// TODO: use an expiring cancellation token
TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);

3 changes: 3 additions & 0 deletions csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -141,6 +141,9 @@ public TSparkDirectResults? DirectResults
// Cast the Client to IAsync for CloudFetch compatibility
TCLIService.IAsync IHiveServer2Statement.Client => Connection.Client;

// Expose QueryTimeoutSeconds for IHiveServer2Statement
int IHiveServer2Statement.QueryTimeoutSeconds => base.QueryTimeoutSeconds;

public override void SetOption(string key, string value)
{
switch (key)
8 changes: 6 additions & 2 deletions csharp/test/Drivers/Databricks/E2E/StatementTests.cs
@@ -187,10 +187,14 @@ public async Task AllStatementTypesDisposeWithoutErrors(string statementType, st
var batch = await queryResult.Stream.ReadNextRecordBatchAsync();
// Note: batch might be null for empty results, that's OK

// test disposing the stream does not throw
var streamException = Record.Exception(() => queryResult.Stream.Dispose());
Assert.Null(streamException);

// The critical test: disposal should not throw any exceptions
// This specifically tests the fix for the GetColumns bug where _directResults wasn't set
var exception = Record.Exception(() => statement.Dispose());
Assert.Null(exception);
var statementException = Record.Exception(() => statement.Dispose());
Assert.Null(statementException);
}
catch (Exception ex)
{