2121import java .io .FileNotFoundException ;
2222import java .io .IOException ;
2323import java .io .InputStream ;
24+ import java .io .BufferedReader ;
25+ import java .io .FileReader ;
2426import java .net .HttpURLConnection ;
2527import java .net .MalformedURLException ;
2628import java .net .URL ;
@@ -169,11 +171,11 @@ public static AzureADToken getTokenUsingJWTAssertion(String authEndpoint,
169171 * @return {@link AzureADToken} obtained using the creds
170172 * @throws IOException throws IOException if there is a failure in obtaining the token
171173 */
172- public static AzureADToken getTokenFromMsi (final String authEndpoint ,
174+ public static AzureADToken getTokenFromMsi (final String authEndpoint , final String apiVersion ,
173175 final String tenantGuid , final String clientId , String authority ,
174176 boolean bypassCache ) throws IOException {
175177 QueryParams qp = new QueryParams ();
176- qp .add ("api-version" , "2018-02-01" );
178+ qp .add ("api-version" , apiVersion );
177179 qp .add ("resource" , RESOURCE_NAME );
178180
179181 if (tenantGuid != null && tenantGuid .length () > 0 ) {
@@ -194,7 +196,51 @@ public static AzureADToken getTokenFromMsi(final String authEndpoint,
194196 headers .put ("Metadata" , "true" );
195197
196198 LOG .debug ("AADToken: starting to fetch token using MSI" );
197- return getTokenCall (authEndpoint , qp .serialize (), headers , "GET" , true );
199+ return getTokenCall (authEndpoint , qp .serialize (), headers , "GET" , true , false );
200+ }
201+
202+ /**
203+ * Gets AAD token from the local virtual machine's ARC extension. This only works on
204+ * an Azure VM with MSI extension
205+ * enabled.
206+ *
207+ * @param authEndpoint the OAuth 2.0 token endpoint associated
208+ * with the user's directory (obtain from
209+ * Active Directory configuration)
210+ * @param tenantGuid (optional) The guid of the AAD tenant. Can be {@code null}.
211+ * @param clientId (optional) The clientId guid of the MSI service
212+ * principal to use. Can be {@code null}.
213+ * @param bypassCache {@code boolean} specifying whether a cached token is acceptable or a fresh token
214+ * request should me made to AAD
215+ * @return {@link AzureADToken} obtained using the creds
216+ * @throws IOException throws IOException if there is a failure in obtaining the token
217+ */
218+ public static AzureADToken getTokenFromArcMsi (final String authEndpoint , final String apiVersion ,
219+ final String tenantGuid , final String clientId , String authority ,
220+ boolean bypassCache ) throws IOException {
221+ QueryParams qp = new QueryParams ();
222+ qp .add ("api-version" , apiVersion );
223+ qp .add ("resource" , RESOURCE_NAME );
224+
225+ if (tenantGuid != null && tenantGuid .length () > 0 ) {
226+ authority = authority + tenantGuid ;
227+ LOG .debug ("MSI authority : {}" , authority );
228+ qp .add ("authority" , authority );
229+ }
230+
231+ if (clientId != null && clientId .length () > 0 ) {
232+ qp .add ("client_id" , clientId );
233+ }
234+
235+ if (bypassCache ) {
236+ qp .add ("bypass_cache" , "true" );
237+ }
238+
239+ Hashtable <String , String > headers = new Hashtable <>();
240+ headers .put ("Metadata" , "true" );
241+
242+ LOG .debug ("AADToken: starting to fetch token using MSI from ARC" );
243+ return getTokenCall (authEndpoint , qp .serialize (), headers , "GET" , true , true );
198244 }
199245
200246 /**
@@ -327,11 +373,11 @@ public UnexpectedResponseException(final int httpErrorCode,
327373
328374 private static AzureADToken getTokenCall (String authEndpoint , String body ,
329375 Hashtable <String , String > headers , String httpMethod ) throws IOException {
330- return getTokenCall (authEndpoint , body , headers , httpMethod , false );
376+ return getTokenCall (authEndpoint , body , headers , httpMethod , false , false );
331377 }
332378
333379 private static AzureADToken getTokenCall (String authEndpoint , String body ,
334- Hashtable <String , String > headers , String httpMethod , boolean isMsi )
380+ Hashtable <String , String > headers , String httpMethod , boolean isMsi , boolean isArc )
335381 throws IOException {
336382 AzureADToken token = null ;
337383
@@ -346,7 +392,7 @@ private static AzureADToken getTokenCall(String authEndpoint, String body,
346392 httperror = 0 ;
347393 ex = null ;
348394 try {
349- token = getTokenSingleCall (authEndpoint , body , headers , httpMethod , isMsi );
395+ token = getTokenSingleCall (authEndpoint , body , headers , httpMethod , isMsi , isArc );
350396 } catch (HttpException e ) {
351397 httperror = e .httpErrorCode ;
352398 ex = e ;
@@ -385,18 +431,83 @@ private static boolean isRecoverableFailure(IOException e) {
385431
386432 private static AzureADToken getTokenSingleCall (String authEndpoint ,
387433 String payload , Hashtable <String , String > headers , String httpMethod ,
388- boolean isMsi )
434+ boolean isMsi , boolean isArc )
389435 throws IOException {
390436
391437 AzureADToken token = null ;
392438 HttpURLConnection conn = null ;
393439 String urlString = authEndpoint ;
440+ String challengerToken = null ;
394441
395442 httpMethod = (httpMethod == null ) ? "POST" : httpMethod ;
396443 if (httpMethod .equals ("GET" )) {
397444 urlString = urlString + "?" + payload ;
398445 }
399446
447+ if (isArc ) {
448+ // Currently there is a known flow that ARC needs obtain a challenge token first
449+ // before and in order to get access_token from the same MSI endpoint
450+ try {
451+ LOG .debug ("Requesting a challenge token by {} to {}" ,
452+ httpMethod , authEndpoint );
453+ URL url = new URL (urlString );
454+ conn = (HttpURLConnection ) url .openConnection ();
455+ conn .setRequestMethod (httpMethod );
456+ conn .setReadTimeout (READ_TIMEOUT );
457+ conn .setConnectTimeout (CONNECT_TIMEOUT );
458+
459+ if (headers != null && headers .size () > 0 ) {
460+ for (Map .Entry <String , String > entry : headers .entrySet ()) {
461+ conn .setRequestProperty (entry .getKey (), entry .getValue ());
462+ }
463+ }
464+ conn .setRequestProperty ("Connection" , "close" );
465+ AbfsIoUtils .dumpHeadersToDebugLog ("Request Headers" ,
466+ conn .getRequestProperties ());
467+ if (httpMethod .equals ("POST" )) {
468+ conn .setDoOutput (true );
469+ conn .getOutputStream ().write (payload .getBytes (StandardCharsets .UTF_8 ));
470+ }
471+ AbfsIoUtils .dumpHeadersToDebugLog ("Response Headers" ,
472+ conn .getHeaderFields ());
473+
474+ int httpResponseCode = conn .getResponseCode ();
475+ String requestId = conn .getHeaderField ("x-ms-request-id" );
476+ String responseContentType = conn .getHeaderField ("Content-Type" );
477+ String operation = "Challenge Token: HTTP connection to " + authEndpoint
478+ + " failed for getting challenge token from ARC MSI endpoint." ;
479+ InputStream stream = conn .getErrorStream ();
480+ if (stream == null ) {
481+ // no error stream, try the original input stream
482+ stream = conn .getInputStream ();
483+ }
484+ String responseBody = consumeInputStream (stream , 1024 );
485+
486+ String authHeader = conn .getHeaderField ("Www-Authenticate" );
487+ if (authHeader != null ) {
488+ // Extract the challenge token path
489+ int index = authHeader .indexOf ('=' );
490+ if (index != -1 ) {
491+ String authHeaderPath = authHeader .substring (index + 1 ).trim ();
492+ try (BufferedReader reader = new BufferedReader (new FileReader (authHeaderPath ))) {
493+ challengerToken = reader .readLine ().trim ();
494+ }
495+ }
496+ } else {
497+ throw new HttpException (httpResponseCode ,
498+ requestId ,
499+ operation ,
500+ authEndpoint ,
501+ responseContentType ,
502+ responseBody );
503+ }
504+ } finally {
505+ if (conn != null ) {
506+ conn .disconnect ();
507+ }
508+ }
509+ }
510+
400511 try {
401512 LOG .debug ("Requesting an OAuth token by {} to {}" ,
402513 httpMethod , authEndpoint );
@@ -406,6 +517,10 @@ private static AzureADToken getTokenSingleCall(String authEndpoint,
406517 conn .setReadTimeout (READ_TIMEOUT );
407518 conn .setConnectTimeout (CONNECT_TIMEOUT );
408519
520+ if (isArc ) {
521+ conn .setRequestProperty ("Authorization" , "Basic " + challengerToken );
522+ }
523+
409524 if (headers != null && headers .size () > 0 ) {
410525 for (Map .Entry <String , String > entry : headers .entrySet ()) {
411526 conn .setRequestProperty (entry .getKey (), entry .getValue ());
0 commit comments