Skip to content

Commit 3a5f541

Browse files
committed
adding ARC Managed Identity support to hadoop-azure ABFS
1 parent 4d78253 commit 3a5f541

File tree

7 files changed

+264
-15
lines changed

7 files changed

+264
-15
lines changed

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/AbfsConfiguration.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider;
5858
import org.apache.hadoop.fs.azurebfs.oauth2.CustomTokenProviderAdapter;
5959
import org.apache.hadoop.fs.azurebfs.oauth2.MsiTokenProvider;
60+
import org.apache.hadoop.fs.azurebfs.oauth2.ArcMsiTokenProvider;
6061
import org.apache.hadoop.fs.azurebfs.oauth2.RefreshTokenBasedTokenProvider;
6162
import org.apache.hadoop.fs.azurebfs.oauth2.UserPasswordTokenProvider;
6263
import org.apache.hadoop.fs.azurebfs.oauth2.WorkloadIdentityTokenProvider;
@@ -961,6 +962,9 @@ public AccessTokenProvider getTokenProvider() throws TokenAccessProviderExceptio
961962
String authEndpoint = getTrimmedPasswordString(
962963
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
963964
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
965+
String apiVersion = getTrimmedPasswordString(
966+
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION,
967+
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION);
964968
String tenantGuid =
965969
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
966970
String clientId =
@@ -969,9 +973,27 @@ public AccessTokenProvider getTokenProvider() throws TokenAccessProviderExceptio
969973
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
970974
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
971975
authority = appendSlashIfNeeded(authority);
972-
tokenProvider = new MsiTokenProvider(authEndpoint, tenantGuid,
976+
tokenProvider = new MsiTokenProvider(authEndpoint, apiVersion, tenantGuid,
973977
clientId, authority);
974978
LOG.trace("MsiTokenProvider initialized");
979+
} else if (tokenProviderClass == ArcMsiTokenProvider.class) {
980+
String authEndpoint = getTrimmedPasswordString(
981+
FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT,
982+
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT);
983+
String apiVersion = getTrimmedPasswordString(
984+
FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION,
985+
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION);
986+
String tenantGuid =
987+
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_MSI_TENANT);
988+
String clientId =
989+
getPasswordString(FS_AZURE_ACCOUNT_OAUTH_CLIENT_ID);
990+
String authority = getTrimmedPasswordString(
991+
FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY,
992+
AuthConfigurations.DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY);
993+
authority = appendSlashIfNeeded(authority);
994+
tokenProvider = new ArcMsiTokenProvider(authEndpoint, apiVersion, tenantGuid,
995+
clientId, authority);
996+
LOG.trace("ArcMsiTokenProvider initialized");
975997
} else if (tokenProviderClass == RefreshTokenBasedTokenProvider.class) {
976998
String authEndpoint = getTrimmedPasswordString(
977999
FS_AZURE_ACCOUNT_OAUTH_REFRESH_TOKEN_ENDPOINT,

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/constants/AuthConfigurations.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ public final class AuthConfigurations {
4343
public static final String
4444
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_TOKEN_FILE =
4545
"/var/run/secrets/azure/tokens/azure-identity-token";
46+
/** Default OAuth api-version end point for the MSI flow. */
47+
public static final String
48+
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION =
49+
"2018-02-01";
50+
/** Default OAuth api-version end point for the ARC MSI flow. */
51+
public static final String
52+
DEFAULT_FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION =
53+
"2021-02-01";
4654

4755
private AuthConfigurations() {
4856
}

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/constants/ConfigurationKeys.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,10 @@ public final class ConfigurationKeys {
258258
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT = "fs.azure.account.oauth2.msi.endpoint";
259259
/** Key for oauth msi Authority: {@value}. */
260260
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_AUTHORITY = "fs.azure.account.oauth2.msi.authority";
261+
/** Key for oauth msi endpoint api version: {@value}. */
262+
public static final String FS_AZURE_ACCOUNT_OAUTH_MSI_ENDPOINT_API_VERSION = "fs.azure.account.oauth2.msi.endpoint.api.version";
263+
/** Key for oauth arc msi endpoint api version: {@value}. */
264+
public static final String FS_AZURE_ACCOUNT_OAUTH_ARC_MSI_ENDPOINT_API_VERSION = "fs.azure.account.oauth2.arc.msi.endpoint.api.version";
261265
/** Key for oauth user name: {@value}. */
262266
public static final String FS_AZURE_ACCOUNT_OAUTH_USER_NAME = "fs.azure.account.oauth2.user.name";
263267
/** Key for oauth user password: {@value}. */
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/**
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.hadoop.fs.azurebfs.oauth2;
20+
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
24+
import java.io.IOException;
25+
26+
/**
27+
* Provides tokens based on Azure VM's Managed Service Identity.
28+
*/
29+
public class ArcMsiTokenProvider extends AccessTokenProvider {
30+
31+
private final String authEndpoint;
32+
33+
private final String apiVersion;
34+
35+
private final String authority;
36+
37+
private final String tenantGuid;
38+
39+
private final String clientId;
40+
41+
private long tokenFetchTime = -1;
42+
43+
private static final long ONE_HOUR = 3600 * 1000;
44+
45+
private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);
46+
47+
public ArcMsiTokenProvider(final String authEndpoint, final String apiVersion, final String tenantGuid,
48+
final String clientId, final String authority) {
49+
this.authEndpoint = authEndpoint;
50+
this.tenantGuid = tenantGuid;
51+
this.clientId = clientId;
52+
this.authority = authority;
53+
this.apiVersion = apiVersion;
54+
}
55+
56+
@Override
57+
protected AzureADToken refreshToken() throws IOException {
58+
LOG.debug("AADToken: refreshing token from ARC MSI");
59+
AzureADToken token = AzureADAuthenticator
60+
.getTokenFromArcMsi(authEndpoint, apiVersion, tenantGuid, clientId, authority, false);
61+
tokenFetchTime = System.currentTimeMillis();
62+
return token;
63+
}
64+
65+
/**
66+
* Checks if the token is about to expire as per base expiry logic.
67+
* Otherwise try to expire every 1 hour
68+
*
69+
* @return true if the token is expiring in next 1 hour or if a token has
70+
* never been fetched
71+
*/
72+
@Override
73+
protected boolean isTokenAboutToExpire() {
74+
if (tokenFetchTime == -1 || super.isTokenAboutToExpire()) {
75+
return true;
76+
}
77+
78+
boolean expiring = false;
79+
long elapsedTimeSinceLastTokenRefreshInMillis =
80+
System.currentTimeMillis() - tokenFetchTime;
81+
expiring = elapsedTimeSinceLastTokenRefreshInMillis >= ONE_HOUR
82+
|| elapsedTimeSinceLastTokenRefreshInMillis < 0;
83+
// In case of, Token is not refreshed for 1 hr or any clock skew issues,
84+
// refresh token.
85+
if (expiring) {
86+
LOG.debug("ARCMSIToken: token renewing. Time elapsed since last token fetch:"
87+
+ " {} milli seconds", elapsedTimeSinceLastTokenRefreshInMillis);
88+
}
89+
90+
return expiring;
91+
}
92+
93+
}

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/AzureADAuthenticator.java

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.io.FileNotFoundException;
2222
import java.io.IOException;
2323
import java.io.InputStream;
24+
import java.io.BufferedReader;
25+
import java.io.FileReader;
2426
import java.net.HttpURLConnection;
2527
import java.net.MalformedURLException;
2628
import java.net.URL;
@@ -169,11 +171,11 @@ public static AzureADToken getTokenUsingJWTAssertion(String authEndpoint,
169171
* @return {@link AzureADToken} obtained using the creds
170172
* @throws IOException throws IOException if there is a failure in obtaining the token
171173
*/
172-
public static AzureADToken getTokenFromMsi(final String authEndpoint,
174+
public static AzureADToken getTokenFromMsi(final String authEndpoint, final String apiVersion,
173175
final String tenantGuid, final String clientId, String authority,
174176
boolean bypassCache) throws IOException {
175177
QueryParams qp = new QueryParams();
176-
qp.add("api-version", "2018-02-01");
178+
qp.add("api-version", apiVersion);
177179
qp.add("resource", RESOURCE_NAME);
178180

179181
if (tenantGuid != null && tenantGuid.length() > 0) {
@@ -194,7 +196,51 @@ public static AzureADToken getTokenFromMsi(final String authEndpoint,
194196
headers.put("Metadata", "true");
195197

196198
LOG.debug("AADToken: starting to fetch token using MSI");
197-
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true);
199+
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true, false);
200+
}
201+
202+
/**
203+
* Gets AAD token from the local virtual machine's ARC extension. This only works on
204+
* an Azure VM with MSI extension
205+
* enabled.
206+
*
207+
* @param authEndpoint the OAuth 2.0 token endpoint associated
208+
* with the user's directory (obtain from
209+
* Active Directory configuration)
210+
* @param tenantGuid (optional) The guid of the AAD tenant. Can be {@code null}.
211+
* @param clientId (optional) The clientId guid of the MSI service
212+
* principal to use. Can be {@code null}.
213+
* @param bypassCache {@code boolean} specifying whether a cached token is acceptable or a fresh token
214+
* request should me made to AAD
215+
* @return {@link AzureADToken} obtained using the creds
216+
* @throws IOException throws IOException if there is a failure in obtaining the token
217+
*/
218+
public static AzureADToken getTokenFromArcMsi(final String authEndpoint, final String apiVersion,
219+
final String tenantGuid, final String clientId, String authority,
220+
boolean bypassCache) throws IOException {
221+
QueryParams qp = new QueryParams();
222+
qp.add("api-version", apiVersion);
223+
qp.add("resource", RESOURCE_NAME);
224+
225+
if (tenantGuid != null && tenantGuid.length() > 0) {
226+
authority = authority + tenantGuid;
227+
LOG.debug("MSI authority : {}", authority);
228+
qp.add("authority", authority);
229+
}
230+
231+
if (clientId != null && clientId.length() > 0) {
232+
qp.add("client_id", clientId);
233+
}
234+
235+
if (bypassCache) {
236+
qp.add("bypass_cache", "true");
237+
}
238+
239+
Hashtable<String, String> headers = new Hashtable<>();
240+
headers.put("Metadata", "true");
241+
242+
LOG.debug("AADToken: starting to fetch token using MSI from ARC");
243+
return getTokenCall(authEndpoint, qp.serialize(), headers, "GET", true, true);
198244
}
199245

200246
/**
@@ -327,11 +373,11 @@ public UnexpectedResponseException(final int httpErrorCode,
327373

328374
private static AzureADToken getTokenCall(String authEndpoint, String body,
329375
Hashtable<String, String> headers, String httpMethod) throws IOException {
330-
return getTokenCall(authEndpoint, body, headers, httpMethod, false);
376+
return getTokenCall(authEndpoint, body, headers, httpMethod, false, false);
331377
}
332378

333379
private static AzureADToken getTokenCall(String authEndpoint, String body,
334-
Hashtable<String, String> headers, String httpMethod, boolean isMsi)
380+
Hashtable<String, String> headers, String httpMethod, boolean isMsi, boolean isArc)
335381
throws IOException {
336382
AzureADToken token = null;
337383

@@ -346,7 +392,7 @@ private static AzureADToken getTokenCall(String authEndpoint, String body,
346392
httperror = 0;
347393
ex = null;
348394
try {
349-
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi);
395+
token = getTokenSingleCall(authEndpoint, body, headers, httpMethod, isMsi, isArc);
350396
} catch (HttpException e) {
351397
httperror = e.httpErrorCode;
352398
ex = e;
@@ -385,18 +431,83 @@ private static boolean isRecoverableFailure(IOException e) {
385431

386432
private static AzureADToken getTokenSingleCall(String authEndpoint,
387433
String payload, Hashtable<String, String> headers, String httpMethod,
388-
boolean isMsi)
434+
boolean isMsi, boolean isArc)
389435
throws IOException {
390436

391437
AzureADToken token = null;
392438
HttpURLConnection conn = null;
393439
String urlString = authEndpoint;
440+
String challengerToken = null;
394441

395442
httpMethod = (httpMethod == null) ? "POST" : httpMethod;
396443
if (httpMethod.equals("GET")) {
397444
urlString = urlString + "?" + payload;
398445
}
399446

447+
if (isArc) {
448+
// Currently there is a known flow that ARC needs obtain a challenge token first
449+
// before and in order to get access_token from the same MSI endpoint
450+
try{
451+
LOG.debug("Requesting a challenge token by {} to {}",
452+
httpMethod, authEndpoint);
453+
URL url = new URL(urlString);
454+
conn = (HttpURLConnection) url.openConnection();
455+
conn.setRequestMethod(httpMethod);
456+
conn.setReadTimeout(READ_TIMEOUT);
457+
conn.setConnectTimeout(CONNECT_TIMEOUT);
458+
459+
if (headers != null && headers.size() > 0) {
460+
for (Map.Entry<String, String> entry : headers.entrySet()) {
461+
conn.setRequestProperty(entry.getKey(), entry.getValue());
462+
}
463+
}
464+
conn.setRequestProperty("Connection", "close");
465+
AbfsIoUtils.dumpHeadersToDebugLog("Request Headers",
466+
conn.getRequestProperties());
467+
if (httpMethod.equals("POST")) {
468+
conn.setDoOutput(true);
469+
conn.getOutputStream().write(payload.getBytes(StandardCharsets.UTF_8));
470+
}
471+
AbfsIoUtils.dumpHeadersToDebugLog("Response Headers",
472+
conn.getHeaderFields());
473+
474+
int httpResponseCode = conn.getResponseCode();
475+
String requestId = conn.getHeaderField("x-ms-request-id");
476+
String responseContentType = conn.getHeaderField("Content-Type");
477+
String operation = "Challenge Token: HTTP connection to " + authEndpoint
478+
+ " failed for getting challenge token from ARC MSI endpoint.";
479+
InputStream stream = conn.getErrorStream();
480+
if (stream == null) {
481+
// no error stream, try the original input stream
482+
stream = conn.getInputStream();
483+
}
484+
String responseBody = consumeInputStream(stream, 1024);
485+
486+
String authHeader = conn.getHeaderField("Www-Authenticate");
487+
if (authHeader != null) {
488+
// Extract the challenge token path
489+
int index = authHeader.indexOf('=');
490+
if (index != -1) {
491+
String authHeaderPath = authHeader.substring(index + 1).trim();
492+
try (BufferedReader reader = new BufferedReader(new FileReader(authHeaderPath))) {
493+
challengerToken = reader.readLine().trim();
494+
}
495+
}
496+
} else {
497+
throw new HttpException(httpResponseCode,
498+
requestId,
499+
operation,
500+
authEndpoint,
501+
responseContentType,
502+
responseBody);
503+
}
504+
} finally {
505+
if (conn != null) {
506+
conn.disconnect();
507+
}
508+
}
509+
}
510+
400511
try {
401512
LOG.debug("Requesting an OAuth token by {} to {}",
402513
httpMethod, authEndpoint);
@@ -406,6 +517,10 @@ private static AzureADToken getTokenSingleCall(String authEndpoint,
406517
conn.setReadTimeout(READ_TIMEOUT);
407518
conn.setConnectTimeout(CONNECT_TIMEOUT);
408519

520+
if (isArc) {
521+
conn.setRequestProperty("Authorization", "Basic " + challengerToken);
522+
}
523+
409524
if (headers != null && headers.size() > 0) {
410525
for (Map.Entry<String, String> entry : headers.entrySet()) {
411526
conn.setRequestProperty(entry.getKey(), entry.getValue());

hadoop-tools/hadoop-azure/src/main/java/org/apache/hadoop/fs/azurebfs/oauth2/MsiTokenProvider.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ public class MsiTokenProvider extends AccessTokenProvider {
3030

3131
private final String authEndpoint;
3232

33+
private final String apiVersion;
34+
3335
private final String authority;
3436

3537
private final String tenantGuid;
@@ -42,19 +44,20 @@ public class MsiTokenProvider extends AccessTokenProvider {
4244

4345
private static final Logger LOG = LoggerFactory.getLogger(AccessTokenProvider.class);
4446

45-
public MsiTokenProvider(final String authEndpoint, final String tenantGuid,
47+
public MsiTokenProvider(final String authEndpoint, final String apiVersion, final String tenantGuid,
4648
final String clientId, final String authority) {
4749
this.authEndpoint = authEndpoint;
4850
this.tenantGuid = tenantGuid;
4951
this.clientId = clientId;
5052
this.authority = authority;
53+
this.apiVersion = apiVersion;
5154
}
5255

5356
@Override
5457
protected AzureADToken refreshToken() throws IOException {
5558
LOG.debug("AADToken: refreshing token from MSI");
5659
AzureADToken token = AzureADAuthenticator
57-
.getTokenFromMsi(authEndpoint, tenantGuid, clientId, authority, false);
60+
.getTokenFromMsi(authEndpoint, apiVersion, tenantGuid, clientId, authority, false);
5861
tokenFetchTime = System.currentTimeMillis();
5962
return token;
6063
}

0 commit comments

Comments
 (0)