Skip to content

Commit 7ff3364

Browse files
authored
feat(csharp/test/Drivers/Databricks): Add mandatory token exchange (#3192)
## Motivation

Databricks will eventually require that all non-inhouse OAuth tokens be exchanged for Databricks OAuth tokens before accessing resources. This change implements mandatory token exchange before sending Thrift requests. The check and exchange are performed in the background for now to reduce latency, but they will eventually need to be blocking if non-inhouse OAuth tokens fail to access Databricks resources in the future.

## Key Components

1. JWT Token Decoder - Decodes JWT tokens to inspect the issuer claim and determine if token exchange is necessary
2. MandatoryTokenExchangeDelegatingHandler - HTTP handler that intercepts requests and performs token exchange when required
3. TokenExchangeClient - Handles the token exchange logic with the same /oidc/v1/token endpoint as token refresh, with slightly different parameters

## Changes

- Added new connection string parameter: IdentityFederationClientId for service principal workload identity federation scenarios
- Implemented token exchange logic that checks JWT issuer against workspace host
- Introduced fallback behavior to maintain backward compatibility if token exchange fails

## Testing

`dotnet test --filter "FullyQualifiedName~MandatoryTokenExchangeDelegatingHandlerTests"`

```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit .NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Discovered: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.16] Starting: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:01.77] Finished: Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (2.6s)
Test summary: total: 11, failed: 0, succeeded: 11, skipped: 0, duration: 2.6s
```

`dotnet test --filter "FullyQualifiedName~TokenExchangeClientTests"`

```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit .NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.14] Discovered: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Starting: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.23] Finished: Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (0.8s)
Test summary: total: 19, failed: 0, succeeded: 19, skipped: 0, duration: 0.8s
```

`dotnet test --filter "FullyQualifiedName~JwtTokenDecoderTests"`

```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit .NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.14] Discovered: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Starting: Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.19] Finished: Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (0.8s)
Test summary: total: 10, failed: 0, succeeded: 10, skipped: 0, duration: 0.8s
```

Also tested E2E manually with AAD tokens for Azure Databricks workspaces, AAD tokens for AWS Databricks workspaces, and service principal workload identity federation tokens.
1 parent 17b6ca9 commit 7ff3364

11 files changed

+1326
-151
lines changed

csharp/src/Drivers/Databricks/Auth/JwtTokenDecoder.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,41 @@ public static bool TryGetExpirationTime(string token, out DateTime expiryTime)
6969
}
7070
}
7171

72+
/// <summary>
/// Attempts to read the issuer ("iss") claim from the payload segment of a JWT.
/// </summary>
/// <param name="token">The JWT to inspect; it must consist of exactly three dot-separated segments.</param>
/// <param name="issuer">Receives the issuer value on success; otherwise an empty string.</param>
/// <returns>True when a non-empty issuer claim was found; false for any malformed or issuer-less token.</returns>
public static bool TryGetIssuer(string token, out string issuer)
{
    issuer = string.Empty;

    try
    {
        string[] segments = token.Split('.');
        if (segments.Length == 3)
        {
            // The middle segment carries the claims as base64url-encoded JSON.
            using JsonDocument claims = JsonDocument.Parse(DecodeBase64Url(segments[1]));

            if (claims.RootElement.TryGetProperty("iss", out JsonElement issuerClaim))
            {
                string claimValue = issuerClaim.GetString() ?? string.Empty;
                if (claimValue.Length > 0)
                {
                    issuer = claimValue;
                    return true;
                }
            }
        }
    }
    catch
    {
        // Any decoding or parsing failure means the issuer cannot be determined;
        // fall through to the shared failure result below.
    }

    return false;
}
106+
72107
/// <summary>
73108
/// Decodes a base64url encoded string to a regular string.
74109
/// </summary>
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Net.Http;
20+
using System.Net.Http.Headers;
21+
using System.Threading;
22+
using System.Threading.Tasks;
23+
24+
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
{
    /// <summary>
    /// HTTP message handler that performs mandatory token exchange for non-Databricks tokens.
    /// Uses a non-blocking approach: the exchange runs in the background and requests keep
    /// using the caller-supplied token until an exchanged token is available.
    /// </summary>
    internal class MandatoryTokenExchangeDelegatingHandler : DelegatingHandler
    {
        private const string BearerScheme = "Bearer";
        private const string TokenPathSuffix = "/v1/token";
        private static readonly TimeSpan s_disposeWaitTimeout = TimeSpan.FromSeconds(10);

        private readonly string? _identityFederationClientId;
        private readonly object _tokenLock = new object();
        private readonly ITokenExchangeClient _tokenExchangeClient;
        private string? _currentToken;
        private string? _lastSeenToken;

        // Written only while holding _tokenLock. Non-null while a background exchange is in flight.
        protected Task? _pendingTokenTask = null;

        /// <summary>
        /// Initializes a new instance of the <see cref="MandatoryTokenExchangeDelegatingHandler"/> class.
        /// </summary>
        /// <param name="innerHandler">The inner handler to delegate to.</param>
        /// <param name="tokenExchangeClient">The client for token exchange operations.</param>
        /// <param name="identityFederationClientId">Optional identity federation client ID.</param>
        /// <exception cref="ArgumentNullException"><paramref name="tokenExchangeClient"/> is null.</exception>
        public MandatoryTokenExchangeDelegatingHandler(
            HttpMessageHandler innerHandler,
            ITokenExchangeClient tokenExchangeClient,
            string? identityFederationClientId = null)
            : base(innerHandler)
        {
            _tokenExchangeClient = tokenExchangeClient ?? throw new ArgumentNullException(nameof(tokenExchangeClient));
            _identityFederationClientId = identityFederationClientId;
        }

        /// <summary>
        /// Determines if token exchange is needed by checking whether the token's issuer
        /// differs from the token exchange endpoint's base URL.
        /// Must be called while holding <see cref="_tokenLock"/>.
        /// </summary>
        /// <param name="bearerToken">The bearer token taken from the outgoing request.</param>
        /// <returns>True if token exchange is needed, false otherwise.</returns>
        private bool NeedsTokenExchange(string bearerToken)
        {
            // If we already started exchange for this token, no need to check again
            if (_lastSeenToken == bearerToken)
            {
                return false;
            }

            // If we already have a pending token task, don't start another exchange
            if (_pendingTokenTask != null)
            {
                return false;
            }

            // If we can't parse the token as JWT, default to use existing token
            if (!JwtTokenDecoder.TryGetIssuer(bearerToken, out string issuer))
            {
                return false;
            }

            // Compare the issuer with the exchange endpoint minus its "/v1/token" suffix.
            // If they match, the token was issued by the same host and is already a Databricks token.
            string normalizedHost = _tokenExchangeClient.TokenExchangeEndpoint.Replace(TokenPathSuffix, "");
            string normalizedIssuer = issuer.TrimEnd('/');

            return !string.Equals(normalizedIssuer, normalizedHost, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Starts token exchange in the background if needed.
        /// </summary>
        /// <param name="bearerToken">The bearer token to potentially exchange.</param>
        /// <param name="cancellationToken">A cancellation token; it is also passed to the background exchange.</param>
        private void StartTokenExchangeIfNeeded(string bearerToken, CancellationToken cancellationToken)
        {
            lock (_tokenLock)
            {
                bool needsExchange = NeedsTokenExchange(bearerToken);

                // Remember the token even when no exchange is started so it is not re-inspected on every request.
                // NOTE(review): a new token that arrives while another exchange is pending is marked as seen
                // without being exchanged - confirm this is intended.
                _lastSeenToken = bearerToken;

                if (!needsExchange)
                {
                    return;
                }

                // The task is published while the lock is held. The continuation below takes the same lock,
                // so it cannot clear the field before this assignment completes; otherwise a finished task
                // could be left behind and block every later exchange.
                _pendingTokenTask = Task.Run(async () =>
                {
                    try
                    {
                        TokenExchangeResponse response = await _tokenExchangeClient.ExchangeTokenAsync(
                            bearerToken,
                            _identityFederationClientId,
                            cancellationToken).ConfigureAwait(false);

                        lock (_tokenLock)
                        {
                            _currentToken = response.AccessToken;
                        }
                    }
                    catch (Exception ex)
                    {
                        // Best effort: on failure requests keep using the original token.
                        System.Diagnostics.Debug.WriteLine($"Mandatory token exchange failed: {ex.Message}");
                    }
                }, cancellationToken).ContinueWith(_ =>
                {
                    lock (_tokenLock)
                    {
                        _pendingTokenTask = null;
                    }
                }, TaskScheduler.Default);
            }
        }

        /// <summary>
        /// Sends an HTTP request, substituting the exchanged token for the request's bearer token
        /// once an exchanged token is available.
        /// </summary>
        /// <param name="request">The HTTP request message to send.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>The HTTP response message.</returns>
        protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            AuthenticationHeaderValue? authorization = request.Headers.Authorization;

            // Only bearer credentials are candidates for exchange; any other scheme is forwarded untouched
            // instead of being rewritten as a bearer header.
            if (authorization != null
                && string.Equals(authorization.Scheme, BearerScheme, StringComparison.OrdinalIgnoreCase)
                && authorization.Parameter is string bearerToken
                && bearerToken.Length > 0)
            {
                StartTokenExchangeIfNeeded(bearerToken, cancellationToken);

                string tokenToUse;
                lock (_tokenLock)
                {
                    tokenToUse = _currentToken ?? bearerToken;
                }

                request.Headers.Authorization = new AuthenticationHeaderValue(BearerScheme, tokenToUse);
            }

            return await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Waits a bounded amount of time for an in-flight token exchange before releasing the handler.
        /// </summary>
        /// <param name="disposing">True when called from <c>Dispose()</c>; false when called from a finalizer.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                // Snapshot under the lock so the field cannot become null between the check and the wait.
                Task? pendingTask;
                lock (_tokenLock)
                {
                    pendingTask = _pendingTokenTask;
                }

                // Wait outside the lock: the task's continuation needs the lock to finish.
                if (pendingTask != null)
                {
                    try
                    {
                        // Try to wait for the task to complete, but don't block indefinitely
                        pendingTask.Wait(s_disposeWaitTimeout);
                    }
                    catch (Exception ex)
                    {
                        // Log any exceptions during disposal
                        System.Diagnostics.Debug.WriteLine($"Exception during token task cleanup: {ex.Message}");
                    }
                }
            }

            base.Dispose(disposing);
        }
    }
}

csharp/src/Drivers/Databricks/Auth/TokenExchangeClient.cs

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,26 @@ internal class TokenExchangeResponse
5656
internal interface ITokenExchangeClient
5757
{
5858
/// <summary>
59-
/// Exchanges the provided token for a new token.
59+
/// Gets the token exchange endpoint URL.
60+
/// </summary>
61+
string TokenExchangeEndpoint { get; }
62+
63+
/// <summary>
64+
/// Refreshes the provided token to extend the lifetime.
65+
/// </summary>
66+
/// <param name="token">The token to refresh.</param>
67+
/// <param name="cancellationToken">A cancellation token.</param>
68+
/// <returns>The response from the token exchange API.</returns>
69+
Task<TokenExchangeResponse> RefreshTokenAsync(string token, CancellationToken cancellationToken);
70+
71+
/// <summary>
72+
/// Exchanges the provided token for a Databricks OAuth token.
6073
/// </summary>
6174
/// <param name="token">The token to exchange.</param>
75+
/// <param name="identityFederationClientId">Optional identity federation client ID.</param>
6276
/// <param name="cancellationToken">A cancellation token.</param>
6377
/// <returns>The response from the token exchange API.</returns>
64-
Task<TokenExchangeResponse> ExchangeTokenAsync(string token, CancellationToken cancellationToken);
78+
Task<TokenExchangeResponse> ExchangeTokenAsync(string token, string? identityFederationClientId, CancellationToken cancellationToken);
6579
}
6680

6781
/// <summary>
@@ -72,6 +86,8 @@ internal class TokenExchangeClient : ITokenExchangeClient
7286
private readonly HttpClient _httpClient;
7387
private readonly string _tokenExchangeEndpoint;
7488

89+
public string TokenExchangeEndpoint => _tokenExchangeEndpoint;
90+
7591
/// <summary>
7692
/// Initializes a new instance of the <see cref="TokenExchangeClient"/> class.
7793
/// </summary>
@@ -93,12 +109,12 @@ public TokenExchangeClient(HttpClient httpClient, string host)
93109
}
94110

95111
/// <summary>
96-
/// Exchanges the provided token for a new token.
112+
/// Refreshes the provided token to extend the lifetime.
97113
/// </summary>
98-
/// <param name="token">The token to exchange.</param>
114+
/// <param name="token">The token to refresh.</param>
99115
/// <param name="cancellationToken">A cancellation token.</param>
100116
/// <returns>The response from the token exchange API.</returns>
101-
public async Task<TokenExchangeResponse> ExchangeTokenAsync(string token, CancellationToken cancellationToken)
117+
public async Task<TokenExchangeResponse> RefreshTokenAsync(string token, CancellationToken cancellationToken)
102118
{
103119
var content = new FormUrlEncodedContent(new[]
104120
{
@@ -120,6 +136,50 @@ public async Task<TokenExchangeResponse> ExchangeTokenAsync(string token, Cancel
120136
return ParseTokenResponse(responseContent);
121137
}
122138

139+
/// <summary>
140+
/// Exchanges the provided token for a Databricks OAuth token.
141+
/// </summary>
142+
/// <param name="token">The token to exchange.</param>
143+
/// <param name="identityFederationClientId">Optional identity federation client ID.</param>
144+
/// <param name="cancellationToken">A cancellation token.</param>
145+
/// <returns>The response from the token exchange API.</returns>
146+
public async Task<TokenExchangeResponse> ExchangeTokenAsync(
147+
string token,
148+
string? identityFederationClientId,
149+
CancellationToken cancellationToken)
150+
{
151+
var formData = new List<KeyValuePair<string, string>>
152+
{
153+
new KeyValuePair<string, string>("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
154+
new KeyValuePair<string, string>("assertion", token),
155+
new KeyValuePair<string, string>("scope", "sql")
156+
};
157+
158+
if (!string.IsNullOrEmpty(identityFederationClientId))
159+
{
160+
formData.Add(new KeyValuePair<string, string>("identity_federation_client_id", identityFederationClientId!));
161+
}
162+
else
163+
{
164+
formData.Add(new KeyValuePair<string, string>("return_original_token_if_authenticated", "true"));
165+
}
166+
167+
var content = new FormUrlEncodedContent(formData);
168+
169+
var request = new HttpRequestMessage(HttpMethod.Post, _tokenExchangeEndpoint)
170+
{
171+
Content = content
172+
};
173+
request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("*/*"));
174+
175+
HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
176+
177+
response.EnsureSuccessStatusCode();
178+
179+
string responseContent = await response.Content.ReadAsStringAsync();
180+
return ParseTokenResponse(responseContent);
181+
}
182+
123183
/// <summary>
124184
/// Parses the token exchange API response.
125185
/// </summary>

csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs renamed to csharp/src/Drivers/Databricks/Auth/TokenRefreshDelegatingHandler.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
2727
/// HTTP message handler that automatically refreshes OAuth tokens before they expire.
2828
/// Uses a non-blocking approach to refresh tokens in the background.
2929
/// </summary>
30-
internal class TokenExchangeDelegatingHandler : DelegatingHandler
30+
internal class TokenRefreshDelegatingHandler : DelegatingHandler
3131
{
3232
private readonly string _initialToken;
3333
private readonly int _tokenRenewLimitMinutes;
@@ -40,14 +40,14 @@ internal class TokenExchangeDelegatingHandler : DelegatingHandler
4040
private Task? _pendingTokenTask = null;
4141

4242
/// <summary>
43-
/// Initializes a new instance of the <see cref="TokenExchangeDelegatingHandler"/> class.
43+
/// Initializes a new instance of the <see cref="TokenRefreshDelegatingHandler"/> class.
4444
/// </summary>
4545
/// <param name="innerHandler">The inner handler to delegate to.</param>
4646
/// <param name="tokenExchangeClient">The client for token exchange operations.</param>
4747
/// <param name="initialToken">The initial token from the connection string.</param>
4848
/// <param name="tokenExpiryTime">The expiry time of the initial token.</param>
4949
/// <param name="tokenRenewLimitMinutes">The minutes before token expiration when we should start renewing the token.</param>
50-
public TokenExchangeDelegatingHandler(
50+
public TokenRefreshDelegatingHandler(
5151
HttpMessageHandler innerHandler,
5252
ITokenExchangeClient tokenExchangeClient,
5353
string initialToken,
@@ -111,7 +111,7 @@ private void StartTokenRenewalIfNeeded(CancellationToken cancellationToken)
111111
{
112112
try
113113
{
114-
TokenExchangeResponse response = await _tokenExchangeClient.ExchangeTokenAsync(_initialToken, cancellationToken);
114+
TokenExchangeResponse response = await _tokenExchangeClient.RefreshTokenAsync(_initialToken, cancellationToken);
115115

116116
// Update the token atomically when ready
117117
lock (_tokenLock)

0 commit comments

Comments
 (0)