Skip to content

Commit e5c25e1

Browse files
authored
feat(csharp/test/Drivers/Databricks): Support token refresh to extend connection lifetime (#3177)
## Motivation In scenarios like PowerBI dataset refresh, if a query runs longer than the OAuth token's expiration time (typically 1 hour for AAD tokens), the connection fails. PowerBI only refreshes access tokens if they have less than 20 minutes of expiration time and never updates tokens after a connection is opened. This PR implements token refresh functionality in the Databricks ADBC driver using the Databricks token exchange API. When an OAuth token is about to expire within a configurable time limit, the driver automatically exchanges it for a new token with a longer expiration time. ## Key Components 1. **JWT Token Decoder**: Parses JWT tokens to extract expiration time 2. **Token Exchange Client**: Handles API calls to the Databricks token exchange endpoint 3. **Token Exchange Handler**: HTTP handler that intercepts requests and refreshes tokens when needed ## Changes - Added new connection string parameter `adbc.databricks.token_renew_limit` to control when token renewal happens - Implemented JWT token decoding to extract token expiration time - Created token exchange client to handle API calls to Databricks token exchange endpoint - Added HTTP handler to intercept requests and refresh tokens when needed - Updated connection handling to create and configure the token exchange components ## Testing - Unit tests for JWT token decoding, token exchange client, and token exchange handler - End-to-end tests that verify token refresh functionality with real tokens ``` dotnet test --filter "FullyQualifiedName~JwtTokenDecoderTests" dotnet test --filter "FullyQualifiedName~TokenExchangeClientTests" dotnet test --filter "FullyQualifiedName~TokenExchangeDelegatingHandlerTests" ```
1 parent 49585fd commit e5c25e1

11 files changed

+1638
-19
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Text;
20+
using System.Text.Json;
21+
22+
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
23+
{
24+
/// <summary>
25+
/// Utility class for decoding JWT tokens and extracting claims.
26+
/// </summary>
27+
internal static class JwtTokenDecoder
28+
{
29+
/// <summary>
30+
/// Tries to parse a JWT token and extract its expiration time.
31+
/// </summary>
32+
/// <param name="token">The JWT token to parse.</param>
33+
/// <param name="expiryTime">The extracted expiration time, if successful.</param>
34+
/// <returns>True if the expiration time was successfully extracted, false otherwise.</returns>
35+
public static bool TryGetExpirationTime(string token, out DateTime expiryTime)
36+
{
37+
expiryTime = DateTime.MinValue;
38+
39+
try
40+
{
41+
// JWT tokens have three parts separated by dots: header.payload.signature
42+
string[] parts = token.Split('.');
43+
if (parts.Length != 3)
44+
{
45+
return false;
46+
}
47+
48+
string payload = DecodeBase64Url(parts[1]);
49+
50+
using JsonDocument jsonDoc = JsonDocument.Parse(payload);
51+
52+
if (!jsonDoc.RootElement.TryGetProperty("exp", out JsonElement expElement))
53+
{
54+
return false;
55+
}
56+
57+
// The exp claim is a Unix timestamp (seconds since epoch)
58+
if (!expElement.TryGetInt64(out long expSeconds))
59+
{
60+
return false;
61+
}
62+
63+
expiryTime = DateTimeOffset.FromUnixTimeSeconds(expSeconds).UtcDateTime;
64+
return true;
65+
}
66+
catch
67+
{
68+
return false;
69+
}
70+
}
71+
72+
/// <summary>
73+
/// Decodes a base64url encoded string to a regular string.
74+
/// </summary>
75+
/// <param name="base64Url">The base64url encoded string.</param>
76+
/// <returns>The decoded string.</returns>
77+
private static string DecodeBase64Url(string base64Url)
78+
{
79+
// Convert base64url to base64
80+
string base64 = base64Url
81+
.Replace('-', '+')
82+
.Replace('_', '/');
83+
84+
// Add padding if needed
85+
switch (base64.Length % 4)
86+
{
87+
case 2: base64 += "=="; break;
88+
case 3: base64 += "="; break;
89+
}
90+
91+
byte[] bytes = Convert.FromBase64String(base64);
92+
93+
return Encoding.UTF8.GetString(bytes);
94+
}
95+
}
96+
}

csharp/src/Drivers/Databricks/Auth/OAuthClientCredentialsProvider.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,6 @@ public async Task<string> GetAccessTokenAsync(CancellationToken cancellationToke
228228
public void Dispose()
229229
{
230230
_tokenLock.Dispose();
231-
_httpClient.Dispose();
232231
}
233232

234233
public string? GetCachedTokenScope()
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
using System;
19+
using System.Collections.Generic;
20+
using System.Net.Http;
21+
using System.Text.Json;
22+
using System.Threading;
23+
using System.Threading.Tasks;
24+
25+
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
26+
{
27+
/// <summary>
28+
/// Response from the token exchange API.
29+
/// </summary>
30+
internal class TokenExchangeResponse
31+
{
32+
/// <summary>
33+
/// The new access token.
34+
/// </summary>
35+
public string AccessToken { get; set; } = string.Empty;
36+
37+
/// <summary>
38+
/// The token type (e.g., "Bearer").
39+
/// </summary>
40+
public string TokenType { get; set; } = string.Empty;
41+
42+
/// <summary>
43+
/// The number of seconds until the token expires.
44+
/// </summary>
45+
public int ExpiresIn { get; set; }
46+
47+
/// <summary>
48+
/// The calculated expiration time based on ExpiresIn.
49+
/// </summary>
50+
public DateTime ExpiryTime { get; set; }
51+
}
52+
53+
/// <summary>
54+
/// Interface for token exchange operations.
55+
/// </summary>
56+
internal interface ITokenExchangeClient
57+
{
58+
/// <summary>
59+
/// Exchanges the provided token for a new token.
60+
/// </summary>
61+
/// <param name="token">The token to exchange.</param>
62+
/// <param name="cancellationToken">A cancellation token.</param>
63+
/// <returns>The response from the token exchange API.</returns>
64+
Task<TokenExchangeResponse> ExchangeTokenAsync(string token, CancellationToken cancellationToken);
65+
}
66+
67+
/// <summary>
68+
/// Client for exchanging tokens using the Databricks token exchange API.
69+
/// </summary>
70+
internal class TokenExchangeClient : ITokenExchangeClient
71+
{
72+
private readonly HttpClient _httpClient;
73+
private readonly string _tokenExchangeEndpoint;
74+
75+
/// <summary>
76+
/// Initializes a new instance of the <see cref="TokenExchangeClient"/> class.
77+
/// </summary>
78+
/// <param name="httpClient">The HTTP client to use for requests.</param>
79+
/// <param name="host">The host of the Databricks workspace.</param>
80+
public TokenExchangeClient(HttpClient httpClient, string host)
81+
{
82+
_httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
83+
84+
if (string.IsNullOrEmpty(host))
85+
{
86+
throw new ArgumentNullException(nameof(host));
87+
}
88+
89+
// Ensure the host doesn't have a trailing slash
90+
host = host.TrimEnd('/');
91+
92+
_tokenExchangeEndpoint = $"https://{host}/oidc/v1/token";
93+
}
94+
95+
/// <summary>
96+
/// Exchanges the provided token for a new token.
97+
/// </summary>
98+
/// <param name="token">The token to exchange.</param>
99+
/// <param name="cancellationToken">A cancellation token.</param>
100+
/// <returns>The response from the token exchange API.</returns>
101+
public async Task<TokenExchangeResponse> ExchangeTokenAsync(string token, CancellationToken cancellationToken)
102+
{
103+
var content = new FormUrlEncodedContent(new[]
104+
{
105+
new KeyValuePair<string, string>("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
106+
new KeyValuePair<string, string>("assertion", token)
107+
});
108+
109+
var request = new HttpRequestMessage(HttpMethod.Post, _tokenExchangeEndpoint)
110+
{
111+
Content = content
112+
};
113+
request.Headers.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("*/*"));
114+
115+
HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
116+
117+
response.EnsureSuccessStatusCode();
118+
119+
string responseContent = await response.Content.ReadAsStringAsync();
120+
return ParseTokenResponse(responseContent);
121+
}
122+
123+
/// <summary>
124+
/// Parses the token exchange API response.
125+
/// </summary>
126+
/// <param name="responseContent">The response content to parse.</param>
127+
/// <returns>The parsed token exchange response.</returns>
128+
private TokenExchangeResponse ParseTokenResponse(string responseContent)
129+
{
130+
using JsonDocument jsonDoc = JsonDocument.Parse(responseContent);
131+
var root = jsonDoc.RootElement;
132+
133+
if (!root.TryGetProperty("access_token", out JsonElement accessTokenElement))
134+
{
135+
throw new DatabricksException("Token exchange response did not contain an access_token");
136+
}
137+
138+
string? accessToken = accessTokenElement.GetString();
139+
if (string.IsNullOrEmpty(accessToken))
140+
{
141+
throw new DatabricksException("Token exchange access_token was null or empty");
142+
}
143+
144+
if (!root.TryGetProperty("token_type", out JsonElement tokenTypeElement))
145+
{
146+
throw new DatabricksException("Token exchange response did not contain token_type");
147+
}
148+
149+
string? tokenType = tokenTypeElement.GetString();
150+
if (string.IsNullOrEmpty(tokenType))
151+
{
152+
throw new DatabricksException("Token exchange token_type was null or empty");
153+
}
154+
155+
if (!root.TryGetProperty("expires_in", out JsonElement expiresInElement))
156+
{
157+
throw new DatabricksException("Token exchange response did not contain expires_in");
158+
}
159+
160+
int expiresIn = expiresInElement.GetInt32();
161+
if (expiresIn <= 0)
162+
{
163+
throw new DatabricksException("Token exchange expires_in value must be positive");
164+
}
165+
166+
DateTime expiryTime = DateTime.UtcNow.AddSeconds(expiresIn);
167+
168+
return new TokenExchangeResponse
169+
{
170+
AccessToken = accessToken!,
171+
TokenType = tokenType!,
172+
ExpiresIn = expiresIn,
173+
ExpiryTime = expiryTime
174+
};
175+
}
176+
}
177+
}

0 commit comments

Comments
 (0)