Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve region support and add region telemetry #388

Merged
merged 5 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
<version>9.4</version>
<version>9.7</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,42 @@ private static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL aut
MsalRequest msalRequest,
ServiceBundle serviceBundle) {

String region = StringHelper.EMPTY_STRING;
IHttpResponse httpResponse = null;
String providedRegion = msalRequest.application().azureRegion();
String detectedRegion = null;
int regionOutcomeTelemetryValue = 0;
String regionToUse = null;

//If a region was provided by a developer or they set the autoDetectRegion parameter,
// attempt to discover the region and set telemetry info based on the outcome
if (providedRegion != null) {
detectedRegion = discoverRegion(msalRequest, serviceBundle);
Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved
regionToUse = providedRegion;
regionOutcomeTelemetryValue = determineRegionOutcome(detectedRegion, providedRegion, msalRequest.application().autoDetectRegion());
} else if (msalRequest.application().autoDetectRegion()) {
detectedRegion = discoverRegion(msalRequest, serviceBundle);

if (detectedRegion != null) {
regionToUse = detectedRegion;
}
Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved

//If the autoDetectRegion parameter in the request is set, attempt to discover the region
if (msalRequest.application().autoDetectRegion()) {
region = discoverRegion(msalRequest, serviceBundle);
regionOutcomeTelemetryValue = determineRegionOutcome(detectedRegion, providedRegion, msalRequest.application().autoDetectRegion());
}

//If the region is known, attempt to make instance discovery request with region endpoint
if (!region.isEmpty()) {
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpointWithRegion(authorityUrl.getAuthority(), region) +
if (regionToUse != null) {
String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpointWithRegion(authorityUrl.getAuthority(), regionToUse) +
formInstanceDiscoveryParameters(authorityUrl);

httpResponse = executeRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
try {
httpResponse = executeRequest(instanceDiscoveryRequestUrl, msalRequest.headers().getReadonlyHeaderMap(), msalRequest, serviceBundle);
} catch (MsalClientException ex) {
log.warn("Could not retrieve regional instance discovery metadata, falling back to global endpoint");
}
}

//If the region is unknown or the instance discovery failed at the region endpoint, try the global endpoint
if (region.isEmpty() || httpResponse == null || httpResponse.statusCode() != HTTPResponse.SC_OK) {
if (!region.isEmpty()) {
log.warn("Could not retrieve regional instance discovery metadata, falling back to global endpoint");
}
if ((detectedRegion == null && providedRegion == null) || httpResponse == null || httpResponse.statusCode() != HTTPResponse.SC_OK) {

String instanceDiscoveryRequestUrl = getInstanceDiscoveryEndpoint(authorityUrl.getAuthority()) +
formInstanceDiscoveryParameters(authorityUrl);
Expand All @@ -161,9 +176,33 @@ private static AadInstanceDiscoveryResponse sendInstanceDiscoveryRequest(URL aut
throw MsalServiceExceptionFactory.fromHttpResponse(httpResponse);
}

serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(regionOutcomeTelemetryValue);

return JsonHelper.convertJsonToObject(httpResponse.body(), AadInstanceDiscoveryResponse.class);
}

private static int determineRegionOutcome(String detectedRegion, String providedRegion, boolean autoDetect) {
int regionOutcomeTelemetryValue = 0;
if (providedRegion != null) {
if (detectedRegion == null) {
regionOutcomeTelemetryValue = RegionTelemetry.REGION_OUTCOME_DEVELOPER_AUTODETECT_FAILED.telemetryValue;
} else if (providedRegion.equals(detectedRegion)) {
regionOutcomeTelemetryValue = RegionTelemetry.REGION_OUTCOME_DEVELOPER_AUTODETECT_MATCH.telemetryValue;
} else {
regionOutcomeTelemetryValue = RegionTelemetry.REGION_OUTCOME_DEVELOPER_AUTODETECT_MISMATCH.telemetryValue;
}
}
else if (autoDetect) {
if (detectedRegion != null) {
regionOutcomeTelemetryValue = RegionTelemetry.REGION_OUTCOME_AUTODETECT_SUCCESS.telemetryValue;
} else {
regionOutcomeTelemetryValue = RegionTelemetry.REGION_OUTCOME_AUTODETECT_FAILED.telemetryValue;
}
}

return regionOutcomeTelemetryValue;
}

private static String formInstanceDiscoveryParameters(URL authorityUrl) {
return INSTANCE_DISCOVERY_REQUEST_PARAMETERS_TEMPLATE.replace("{authorizeEndpoint}",
getAuthorizeEndpoint(authorityUrl.getAuthority(),
Expand All @@ -184,9 +223,12 @@ private static IHttpResponse executeRequest(String requestUrl, Map<String, Strin

private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serviceBundle) {
Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved

CurrentRequest currentRequest = serviceBundle.getServerSideTelemetry().getCurrentRequest();

//Check if the REGION_NAME environment variable has a value for the region
if (System.getenv(REGION_NAME) != null) {
log.info("Region found in environment variable: " + System.getenv(REGION_NAME));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_ENV_VARIABLE.telemetryValue);

return System.getenv(REGION_NAME);
}
Expand All @@ -200,17 +242,21 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
//If call to IMDS endpoint was successful, return region from response body
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);

return httpResponse.body();
}

log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);

return StringHelper.EMPTY_STRING;
return null;
} catch (Exception e) {
//IMDS call failed, cannot find region
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
return StringHelper.EMPTY_STRING;
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);

return null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ abstract class AbstractClientApplicationBase implements IClientApplicationBase {
@Getter
private boolean autoDetectRegion;

@Accessors(fluent = true)
@Getter
private String azureRegion;

@Override
public CompletableFuture<IAuthenticationResult> acquireToken(AuthorizationCodeParameters parameters) {

Expand Down Expand Up @@ -327,6 +331,7 @@ abstract static class Builder<T extends Builder<T>> {
private AadInstanceDiscoveryResponse aadInstanceDiscoveryResponse;
private String clientCapabilities;
private boolean autoDetectRegion;
private String azureRegion;
private Integer connectTimeoutForDefaultHttpClient;
private Integer readTimeoutForDefaultHttpClient;

Expand Down Expand Up @@ -625,6 +630,21 @@ public T autoDetectRegion(boolean val) {
return self();
}

/**
* Indicates that the library should attempt to fetch the instance discovery metadata from the specified Azure region.
*
* If the region is valid, token requests will be sent to the regional ESTS endpoint rather than the global endpoint.
* If region information could not be verified, the library will fall back to using the global endpoint, which is also
* the default behavior if this value is not set.
*
* @param val boolean (default is false)
* @return instance of the Builder on which method was called
*/
public T azureRegion(String val) {
azureRegion = val;
return self();
}

Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved
abstract AbstractClientApplicationBase build();
}

Expand Down Expand Up @@ -652,6 +672,7 @@ public T autoDetectRegion(boolean val) {
aadAadInstanceDiscoveryResponse = builder.aadInstanceDiscoveryResponse;
clientCapabilities = builder.clientCapabilities;
autoDetectRegion = builder.autoDetectRegion;
azureRegion = builder.azureRegion;

if (aadAadInstanceDiscoveryResponse != null) {
AadInstanceDiscoveryProvider.cacheInstanceDiscoveryMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ AuthenticationResult execute() throws Exception {
res.refreshOn() < currTimeStampSec && res.expiresOn() >= currTimeStampSec;

if (silentRequest.parameters().forceRefresh() || afterRefreshOn || StringHelper.isBlank(res.accessToken())) {

//As of version 3 of the telemetry schema, there is a field for collecting data about why a token was refreshed,
// so here we set the telemetry value based on the cause of the refresh
if (silentRequest.parameters().forceRefresh()) {
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
} else if (afterRefreshOn) {
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_REFRESH_IN.telemetryValue);
} else if (res.expiresOn() < currTimeStampSec) {
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_ACCESS_TOKEN_EXPIRED.telemetryValue);
} else if (StringHelper.isBlank(res.accessToken())) {
clientApplication.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_NO_ACCESS_TOKEN.telemetryValue);
}

if (!StringHelper.isBlank(res.refreshToken())) {
RefreshTokenRequest refreshTokenRequest = new RefreshTokenRequest(
RefreshTokenParameters.builder(silentRequest.parameters().scopes(), res.refreshToken()).build(),
Expand Down
26 changes: 26 additions & 0 deletions src/main/java/com/microsoft/aad/msal4j/CacheTelemetry.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

/**
* Telemetry values covering the use of the cache in the library
*/
enum CacheTelemetry {
/**
* These values represent reasons why a token needed to be refreshed: either the flow does not use cached tokens (0),
* the force refresh parameter was set (1), there was no cached access token (2), the cached access token expired (3),
* or the cached token's refresh in time has passed (4)
*/
REFRESH_CACHE_NOT_USED(0),
REFRESH_FORCE_REFRESH(1),
REFRESH_NO_ACCESS_TOKEN(2),
REFRESH_ACCESS_TOKEN_EXPIRED(3),
REFRESH_REFRESH_IN(4);

final int telemetryValue;

CacheTelemetry(int telemetryValue){
this.telemetryValue = telemetryValue;
}
}
11 changes: 10 additions & 1 deletion src/main/java/com/microsoft/aad/msal4j/CurrentRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@ class CurrentRequest {
private final PublicApi publicApi;

@Setter
private boolean forceRefresh = false;
private int cacheInfo = -1;

@Setter
private String regionUsed = StringHelper.EMPTY_STRING;

@Setter
private int regionSource = 0;

@Setter
private int regionOutcome = 0;

CurrentRequest(PublicApi publicApi){
this.publicApi = publicApi;
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/com/microsoft/aad/msal4j/RegionTelemetry.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.aad.msal4j;

/**
* Telemetry values covering the use of regions in the library
*/
enum RegionTelemetry {
/**
* These values represent the source of the region info: either from the cache (2), environment variables (3), or the IMDS endpoint (4)
* One value is used for the failure cause when region autodetection was requested but a region could not be found (1)
*/
REGION_SOURCE_FAILED_AUTODETECT(1),
REGION_SOURCE_CACHE(2),
REGION_SOURCE_ENV_VARIABLE(3),
REGION_SOURCE_IMDS(4),
/**
* These values represent the result of the attempt to find region info
* Three values cover cases where developer provided a region and either it matches the autodetected region (1),
* autodetection failed (2), or the autodetected region does not match the developer provided region (3)
* Two values cover cases where developer just requested autodetection, and we either detected the region (4) or failed (5)
*/
REGION_OUTCOME_DEVELOPER_AUTODETECT_MATCH(1),
REGION_OUTCOME_DEVELOPER_AUTODETECT_FAILED(2),
REGION_OUTCOME_DEVELOPER_AUTODETECT_MISMATCH(3),
REGION_OUTCOME_AUTODETECT_SUCCESS(4),
REGION_OUTCOME_AUTODETECT_FAILED(5);

final int telemetryValue;

RegionTelemetry(int telemetryValue){
this.telemetryValue = telemetryValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.slf4j.LoggerFactory;

import java.lang.reflect.Array;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
Expand All @@ -17,7 +16,7 @@ class ServerSideTelemetry {

private final static Logger log = LoggerFactory.getLogger(ServerSideTelemetry.class);

private final static String SCHEMA_VERSION = "2";
private final static String SCHEMA_VERSION = "5";
Avery-Dunn marked this conversation as resolved.
Show resolved Hide resolved
private final static String SCHEMA_PIPE_DELIMITER = "|";
private final static String SCHEMA_COMMA_DELIMITER = ",";
private final static String CURRENT_REQUEST_HEADER_NAME = "x-client-current-telemetry";
Expand Down Expand Up @@ -68,7 +67,13 @@ private synchronized String buildCurrentRequestHeader() {
String currentRequestHeader = SCHEMA_VERSION + SCHEMA_PIPE_DELIMITER +
currentRequest.publicApi().getApiId() +
SCHEMA_COMMA_DELIMITER +
currentRequest.forceRefresh() +
(currentRequest.cacheInfo() == -1 ? "" : currentRequest.cacheInfo()) +
SCHEMA_COMMA_DELIMITER +
currentRequest.regionUsed() +
SCHEMA_COMMA_DELIMITER +
currentRequest.regionSource() +
SCHEMA_COMMA_DELIMITER +
currentRequest.regionOutcome() +
SCHEMA_PIPE_DELIMITER;

if (currentRequestHeader.getBytes(StandardCharsets.UTF_8).length > CURRENT_REQUEST_MAX_SIZE) {
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/com/microsoft/aad/msal4j/SilentRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class SilentRequest extends MsalRequest {
application.authenticationAuthority :
Authority.createAuthority(new URL(parameters.authorityUrl()));

application.getServiceBundle().getServerSideTelemetry().getCurrentRequest().forceRefresh(
parameters.forceRefresh());
if (parameters.forceRefresh()) {
application.getServiceBundle().getServerSideTelemetry().getCurrentRequest().cacheInfo(
CacheTelemetry.REFRESH_FORCE_REFRESH.telemetryValue);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public void aadInstanceDiscoveryTest_AutoDetectRegion_NoRegionDetected() throws
app.getServiceBundle());

//Region detection will have been performed in the expected discoverRegion method, but these tests (likely) aren't
// being run in an Azure VM and nstance discovery will fall back to the global endpoint (login.microsoftonline.com)
// being run in an Azure VM and instance discovery will fall back to the global endpoint (login.microsoftonline.com)
Assert.assertEquals(entry.preferredNetwork(), "login.microsoftonline.com");
Assert.assertEquals(entry.preferredCache(), "login.windows.net");
Assert.assertEquals(entry.aliases().size(), 4);
Expand Down
Loading