Fix Flaky Core Tests (#23600)
alzimmermsft authored Aug 17, 2021
1 parent 14c3a07 commit acdfb61
Showing 5 changed files with 121 additions and 180 deletions.
@@ -21,6 +21,7 @@
import java.time.temporal.ChronoUnit;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.azure.core.util.CoreUtils.isNullOrEmpty;

@@ -155,7 +156,7 @@ static Duration determineDelayDuration(HttpResponse response, int tryCount, Retr
String retryAfterHeader, ChronoUnit retryAfterTimeUnit) {
// If the retry after header hasn't been configured, attempt to look up the well-known headers.
if (isNullOrEmpty(retryAfterHeader)) {
return getWellKnownRetryDelay(response.getHeaders(), tryCount, retryStrategy);
return getWellKnownRetryDelay(response.getHeaders(), tryCount, retryStrategy, OffsetDateTime::now);
}

String retryHeaderValue = response.getHeaderValue(retryAfterHeader);
@@ -172,7 +173,8 @@ static Duration determineDelayDuration(HttpResponse response, int tryCount, Retr
/*
* Determines the delay duration that should be waited before retrying using the well-known retry headers.
*/
static Duration getWellKnownRetryDelay(HttpHeaders responseHeaders, int tryCount, RetryStrategy retryStrategy) {
static Duration getWellKnownRetryDelay(HttpHeaders responseHeaders, int tryCount, RetryStrategy retryStrategy,
Supplier<OffsetDateTime> nowSupplier) {
// Found 'x-ms-retry-after-ms' header, use a Duration of milliseconds based on the value.
Duration retryDelay = tryGetRetryDelay(responseHeaders, X_MS_RETRY_AFTER_MS_HEADER,
RetryPolicy::tryGetDelayMillis);
@@ -188,7 +190,8 @@ static Duration getWellKnownRetryDelay(HttpHeaders responseHeaders, int tryCount

// Found 'Retry-After' header. First, attempt to resolve it as a Duration of seconds. If that fails, then
// attempt to resolve it as an HTTP date (RFC1123).
retryDelay = tryGetRetryDelay(responseHeaders, RETRY_AFTER_HEADER, RetryPolicy::tryParseLongOrDateTime);
retryDelay = tryGetRetryDelay(responseHeaders, RETRY_AFTER_HEADER,
headerValue -> tryParseLongOrDateTime(headerValue, nowSupplier));
if (retryDelay != null) {
return retryDelay;
}
@@ -209,12 +212,12 @@ private static Duration tryGetDelayMillis(String value) {
return (delayMillis >= 0) ? Duration.ofMillis(delayMillis) : null;
}

private static Duration tryParseLongOrDateTime(String value) {
private static Duration tryParseLongOrDateTime(String value, Supplier<OffsetDateTime> nowSupplier) {
long delaySeconds;
try {
OffsetDateTime retryAfter = new DateTimeRfc1123(value).getDateTime();

delaySeconds = OffsetDateTime.now().until(retryAfter, ChronoUnit.SECONDS);
delaySeconds = nowSupplier.get().until(retryAfter, ChronoUnit.SECONDS);
} catch (DateTimeException ex) {
delaySeconds = tryParseLong(value);
}
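The RetryPolicy change above threads a Supplier<OffsetDateTime> through getWellKnownRetryDelay and tryParseLongOrDateTime so that the Retry-After date math runs against an injectable clock: the production call site passes OffsetDateTime::now, while tests can pin "now" and get a deterministic delay. Below is a minimal, self-contained sketch of the same technique; RetryAfterSketch and delayUntil are hypothetical names used for illustration, not the azure-core implementation.

import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.function.Supplier;

public final class RetryAfterSketch {
    // Parses an RFC1123 'Retry-After' value into a delay measured against the supplied clock.
    static Duration delayUntil(String rfc1123Value, Supplier<OffsetDateTime> nowSupplier) {
        OffsetDateTime retryAfter = OffsetDateTime.parse(rfc1123Value, DateTimeFormatter.RFC_1123_DATE_TIME);
        long delaySeconds = nowSupplier.get().until(retryAfter, ChronoUnit.SECONDS);
        return (delaySeconds >= 0) ? Duration.ofSeconds(delaySeconds) : Duration.ZERO;
    }

    public static void main(String[] args) {
        OffsetDateTime fixedNow = OffsetDateTime.of(2021, 8, 17, 12, 0, 0, 0, ZoneOffset.UTC);
        String header = DateTimeFormatter.RFC_1123_DATE_TIME.format(fixedNow.plusSeconds(30));

        // With a pinned clock the result is always exactly 30 seconds.
        System.out.println(delayUntil(header, () -> fixedNow));
        // The production-style call depends on the wall clock and can be off by a second or two.
        System.out.println(delayUntil(header, OffsetDateTime::now));
    }
}

Keeping the clock as a trailing parameter leaves production behavior unchanged apart from the extra OffsetDateTime::now argument, which is what the reworked retryAfterDateTime test further down relies on.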
@@ -17,39 +17,35 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class TokenCacheTests {
private static final Random RANDOM = new Random();

@Test
public void testOnlyOneThreadRefreshesToken() throws Exception {
AtomicLong refreshes = new AtomicLong(0);

// Token acquisition time grows by 1 sec, 2 sec, ... to make sure only one token acquisition is run
SimpleTokenCache cache = new SimpleTokenCache(() -> incrementalRemoteGetTokenAsync(new AtomicInteger(1)));
SimpleTokenCache cache = new SimpleTokenCache(() -> {
refreshes.incrementAndGet();
return incrementalRemoteGetTokenAsync(new AtomicInteger(1));
});

CountDownLatch latch = new CountDownLatch(1);
AtomicLong maxMillis = new AtomicLong(0);

Flux.range(1, 10)
.flatMap(i -> Mono.just(OffsetDateTime.now())
// Runs cache.getToken() on 10 different threads
.publishOn(Schedulers.parallel())
.flatMap(start -> cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
if (millis > maxMillis.get()) {
maxMillis.set(millis);
}
// System.out.format("Thread: %s\tDuration: %smillis%n",
// Thread.currentThread().getName(), Duration.between(start, OffsetDateTime.now()).toMillis());
})))

Flux.range(1, 10).flatMap(ignored -> Mono.just(OffsetDateTime.now()))
.parallel(10)
// Runs cache.getToken() on 10 different threads
.runOn(Schedulers.boundedElastic())
.flatMap(start -> cache.getToken())
.doOnComplete(latch::countDown)
.subscribe();

latch.await();
long maxMs = maxMillis.get();
Assertions.assertTrue(maxMs > 1000, () -> String.format("maxMillis was less than 1000ms. Was %d.", maxMs));

// Big enough for any latency, small enough to make sure no get token is called twice
Assertions.assertTrue(maxMs < 2000, () -> String.format("maxMillis was greater than 2000ms. Was %d.", maxMs));
// Ensure that only one refresh attempt is made.
assertEquals(1, refreshes.get());
}
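The rewritten testOnlyOneThreadRefreshesToken above no longer asserts that the slowest observed getToken call lands between 1000 ms and 2000 ms; it wraps the token supplier so it can count refreshes and then asserts that exactly one happened. A standalone sketch of that pattern follows, using a cached Mono as a stand-in for SimpleTokenCache; the class and variable names are illustrative only.

import java.time.Duration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

public class CountingRefreshSketch {
    public static void main(String[] args) throws InterruptedException {
        AtomicInteger refreshes = new AtomicInteger();

        // Stand-in for SimpleTokenCache: cache() subscribes upstream at most once,
        // so the counter records how many "remote" token acquisitions actually happen.
        Mono<String> cachedToken = Mono.delay(Duration.ofSeconds(1))
            .map(ignored -> "token-" + refreshes.incrementAndGet())
            .cache();

        CountDownLatch latch = new CountDownLatch(1);

        // Ten parallel subscribers, mirroring the parallel(10)/runOn(...) shape of the test.
        Flux.range(1, 10)
            .parallel(10)
            .runOn(Schedulers.boundedElastic())
            .flatMap(ignored -> cachedToken)
            .sequential()
            .doOnComplete(latch::countDown)
            .subscribe();

        latch.await();

        // Assert on behavior (one refresh), not on how long the refresh happened to take.
        if (refreshes.get() != 1) {
            throw new AssertionError("Expected exactly one refresh, got " + refreshes.get());
        }
    }
}

Wall-clock windows like the old 1000 to 2000 ms bound are exactly the kind of assertion that flakes under CI load; counting an observable side effect holds regardless of scheduler latency.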

@Test
@@ -59,23 +55,15 @@ public void testLongRunningWontOverflow() throws Exception {
// token expires on creation. Run this 100 times to simulate running the application a long time
SimpleTokenCache cache = new SimpleTokenCache(() -> {
refreshes.incrementAndGet();
return remoteGetTokenThatExpiresSoonAsync(1000, 0);
return remoteGetTokenThatExpiresSoonAsync();
});

VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.create();
CountDownLatch latch = new CountDownLatch(1);

Flux.interval(Duration.ofMillis(100), virtualTimeScheduler)
.take(100)
.flatMap(i -> Mono.just(OffsetDateTime.now())
// Runs cache.getToken() on 10 different threads
.subscribeOn(Schedulers.parallel())
.flatMap(start -> cache.getToken()
.map(t -> Duration.between(start, OffsetDateTime.now()).toMillis())
.doOnNext(millis -> {
// System.out.format("Thread: %s\tDuration: %smillis%n",
// Thread.currentThread().getName(), Duration.between(start, OffsetDateTime.now()).toMillis());
})))
.flatMap(i -> cache.getToken())
.doOnComplete(latch::countDown)
.subscribe();

@@ -86,14 +74,8 @@
Assertions.assertTrue(refreshes.get() <= 11);
}

private Mono<AccessToken> remoteGetTokenAsync(long delayInMillis) {
return Mono.delay(Duration.ofMillis(delayInMillis))
.map(l -> new Token(Integer.toString(RANDOM.nextInt(100))));
}

private Mono<AccessToken> remoteGetTokenThatExpiresSoonAsync(long delayInMillis, long validityInMillis) {
return Mono.delay(Duration.ofMillis(delayInMillis))
.map(l -> new Token(Integer.toString(RANDOM.nextInt(100)), validityInMillis));
private Mono<AccessToken> remoteGetTokenThatExpiresSoonAsync() {
return Mono.delay(Duration.ofMillis(1000)).map(l -> new Token(Integer.toString(RANDOM.nextInt(100)), 0));
}

// First token takes latency seconds, and adds 1 sec every subsequent call
@@ -103,32 +85,12 @@ private Mono<AccessToken> incrementalRemoteGetTokenAsync(AtomicInteger latency)
}

private static class Token extends AccessToken {
private String token;
private OffsetDateTime expiry;

@Override
public String getToken() {
return token;
}

Token(String token) {
this(token, 5000);
}

Token(String token, long validityInMillis) {
super(token, OffsetDateTime.now().plus(Duration.ofMillis(validityInMillis)));
this.token = token;
this.expiry = OffsetDateTime.now().plus(Duration.ofMillis(validityInMillis));
}

@Override
public OffsetDateTime getExpiresAt() {
return expiry;
}

@Override
public boolean isExpired() {
return OffsetDateTime.now().isAfter(expiry);
}
}
}
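The other timing-sensitive test in TokenCacheTests, testLongRunningWontOverflow, now drives its Flux.interval off a VirtualTimeScheduler, so simulating 100 iterations at 100 ms intervals does not require sleeping in real time. The portion of the test that advances the scheduler is elided above; as a standalone illustration of the virtual-time idea (hypothetical class name, not the test itself):

import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;

import reactor.core.publisher.Flux;
import reactor.test.scheduler.VirtualTimeScheduler;

public class VirtualTimeSketch {
    public static void main(String[] args) {
        VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.create();
        AtomicInteger ticks = new AtomicInteger();

        // Ticks every 100 ms of *virtual* time; nothing runs until the clock is advanced.
        Flux.interval(Duration.ofMillis(100), virtualTimeScheduler)
            .take(100)
            .subscribe(ignored -> ticks.incrementAndGet());

        // Advance the virtual clock past the last tick; no real sleeping is involved.
        virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(11));

        System.out.println("ticks = " + ticks.get()); // 100
    }
}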
@@ -18,7 +18,11 @@
import com.azure.core.util.logging.LogLevel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.junit.jupiter.api.parallel.Isolated;
import org.junit.jupiter.api.parallel.ResourceLock;
import org.junit.jupiter.api.parallel.Resources;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
@@ -28,8 +32,8 @@
import reactor.test.StepVerifier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URL;
@@ -50,6 +54,9 @@
/**
* This class contains tests for {@link HttpLoggingPolicy}.
*/
@Execution(ExecutionMode.SAME_THREAD)
@Isolated
@ResourceLock(Resources.SYSTEM_OUT)
public class HttpLoggingPolicyTests {
private static final String REDACTED = "REDACTED";
private static final Context CONTEXT = new Context("caller-method", HttpLoggingPolicyTests.class.getName());
@@ -61,9 +68,7 @@ public class HttpLoggingPolicyTests {
@BeforeEach
public void prepareForTest() {
// Set the log level to information for the test.
originalLogLevel = Configuration.getGlobalConfiguration().get(PROPERTY_AZURE_LOG_LEVEL);
Configuration.getGlobalConfiguration().put(PROPERTY_AZURE_LOG_LEVEL,
String.valueOf(LogLevel.INFORMATIONAL.getLogLevel()));
setupLogLevel(LogLevel.INFORMATIONAL.getLogLevel());

/*
* DefaultLogger uses System.out to log. Inject a custom PrintStream to log into for the duration of the test to
@@ -75,25 +80,19 @@ public void prepareForTest() {
}

@AfterEach
public void cleanupAfterTest() throws IOException {
public void cleanupAfterTest() {
// Reset or clear the log level after the test completes.
if (CoreUtils.isNullOrEmpty(originalLogLevel)) {
Configuration.getGlobalConfiguration().remove(PROPERTY_AZURE_LOG_LEVEL);
} else {
Configuration.getGlobalConfiguration().put(PROPERTY_AZURE_LOG_LEVEL, originalLogLevel);
}
setPropertyToOriginalOrClear(originalLogLevel);

// Reset System.out to the original PrintStream.
System.setOut(originalSystemOut);
logCaptureStream.close();
}

/**
* Tests that a query string will be properly redacted before it is logged.
*/
@ParameterizedTest
@MethodSource("redactQueryParametersSupplier")
@ResourceLock("SYSTEM_OUT")
public void redactQueryParameters(String requestUrl, String expectedQueryString,
Set<String> allowedQueryParameters) {
HttpPipeline pipeline = new HttpPipelineBuilder()
@@ -138,7 +137,6 @@ private static Stream<Arguments> redactQueryParametersSupplier() {
*/
@ParameterizedTest(name = "[{index}] {displayName}")
@MethodSource("validateLoggingDoesNotConsumeSupplier")
@ResourceLock("SYSTEM_OUT")
public void validateLoggingDoesNotConsumeRequest(Flux<ByteBuffer> stream, byte[] data, int contentLength)
throws MalformedURLException {
URL requestUrl = new URL("https://test.com");
@@ -154,7 +152,7 @@ public void validateLoggingDoesNotConsumeRequest(Flux<ByteBuffer> stream, byte[]
.build();

StepVerifier.create(pipeline.send(new HttpRequest(HttpMethod.POST, requestUrl, requestHeaders, stream),
CONTEXT))
CONTEXT))
.verifyComplete();

String logString = convertOutputStreamToString(logCaptureStream);
@@ -166,7 +164,6 @@
*/
@ParameterizedTest(name = "[{index}] {displayName}")
@MethodSource("validateLoggingDoesNotConsumeSupplier")
@ResourceLock("SYSTEM_OUT")
public void validateLoggingDoesNotConsumeResponse(Flux<ByteBuffer> stream, byte[] data, int contentLength) {
HttpRequest request = new HttpRequest(HttpMethod.GET, "https://test.com");
HttpHeaders responseHeaders = new HttpHeaders()
@@ -276,8 +273,7 @@ public Mono<String> getBodyAsString(Charset charset) {

@ParameterizedTest(name = "[{index}] {displayName}")
@EnumSource(value = HttpLogDetailLevel.class, mode = EnumSource.Mode.INCLUDE,
names = { "BASIC", "HEADERS", "BODY", "BODY_AND_HEADERS" })
@ResourceLock("SYSTEM_OUT")
names = {"BASIC", "HEADERS", "BODY", "BODY_AND_HEADERS"})
public void loggingIncludesRetryCount(HttpLogDetailLevel logLevel) {
AtomicInteger requestCount = new AtomicInteger();
HttpRequest request = new HttpRequest(HttpMethod.GET, "https://test.com");
@@ -298,11 +294,24 @@ public void loggingIncludesRetryCount(HttpLogDetailLevel logLevel) {
assertTrue(logString.contains("Try count: 2"));
}

private void setupLogLevel(int logLevelToSet) {
originalLogLevel = Configuration.getGlobalConfiguration().get(PROPERTY_AZURE_LOG_LEVEL);
Configuration.getGlobalConfiguration().put(PROPERTY_AZURE_LOG_LEVEL, String.valueOf(logLevelToSet));
}

private void setPropertyToOriginalOrClear(String originalValue) {
if (CoreUtils.isNullOrEmpty(originalValue)) {
Configuration.getGlobalConfiguration().remove(PROPERTY_AZURE_LOG_LEVEL);
} else {
Configuration.getGlobalConfiguration().put(PROPERTY_AZURE_LOG_LEVEL, originalValue);
}
}

private static String convertOutputStreamToString(ByteArrayOutputStream stream) {
try {
return stream.toString("UTF-8");
return stream.toString(StandardCharsets.UTF_8.name());
} catch (UnsupportedEncodingException e) {
throw new RuntimeException(e);
throw new UncheckedIOException(e);
}
}
}
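HttpLoggingPolicyTests now carries @Execution(ExecutionMode.SAME_THREAD), @Isolated, and a class-level @ResourceLock(Resources.SYSTEM_OUT), so tests that replace System.out with a capture stream cannot race against other tests running in parallel. A minimal sketch of that JUnit 5 pattern, independent of azure-core (class and method names here are made up):

import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.junit.jupiter.api.parallel.Isolated;
import org.junit.jupiter.api.parallel.ResourceLock;
import org.junit.jupiter.api.parallel.Resources;

// Single-threaded, isolated from other test classes, and holding the SYSTEM_OUT lock so
// nothing else redirects System.out while these tests run.
@Execution(ExecutionMode.SAME_THREAD)
@Isolated
@ResourceLock(Resources.SYSTEM_OUT)
public class SystemOutCaptureSketchTests {
    private PrintStream originalSystemOut;
    private ByteArrayOutputStream logCaptureStream;

    @BeforeEach
    public void redirectSystemOut() {
        originalSystemOut = System.out;
        logCaptureStream = new ByteArrayOutputStream();
        System.setOut(new PrintStream(logCaptureStream, true));
    }

    @AfterEach
    public void restoreSystemOut() {
        System.setOut(originalSystemOut);
    }

    @Test
    public void capturesConsoleOutput() {
        System.out.println("hello from the test");

        String captured = new String(logCaptureStream.toByteArray(), StandardCharsets.UTF_8);
        Assertions.assertTrue(captured.contains("hello from the test"));
    }
}

The three annotations overlap to a degree; together they state plainly that nothing else touching System.out may run concurrently with this class.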
@@ -39,7 +39,6 @@
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -282,12 +281,13 @@ public Mono<HttpResponse> send(HttpRequest request) {
.verifyComplete();
}

@SuppressWarnings("ReactiveStreamsUnusedPublisher")
@Test
public void retryConsumesBody() {
final AtomicInteger bodyConsumptionCount = new AtomicInteger();
Flux<ByteBuffer> errorBody = Flux.generate(sink -> {
bodyConsumptionCount.incrementAndGet();
sink.next(ByteBuffer.wrap("Should be consumed" .getBytes(StandardCharsets.UTF_8)));
sink.next(ByteBuffer.wrap("Should be consumed".getBytes(StandardCharsets.UTF_8)));
sink.complete();
});

@@ -341,7 +341,8 @@ public Mono<String> getBodyAsString(Charset charset) {
@ParameterizedTest
@MethodSource("getWellKnownRetryDelaySupplier")
public void getWellKnownRetryDelay(HttpHeaders responseHeaders, RetryStrategy retryStrategy, Duration expected) {
assertEquals(expected, RetryPolicy.getWellKnownRetryDelay(responseHeaders, 1, retryStrategy));
assertEquals(expected, RetryPolicy.getWellKnownRetryDelay(responseHeaders, 1, retryStrategy,
OffsetDateTime::now));
}

private static Stream<Arguments> getWellKnownRetryDelaySupplier() {
@@ -381,14 +382,11 @@ private static Stream<Arguments> getWellKnownRetryDelaySupplier() {

@Test
public void retryAfterDateTime() {
HttpHeaders headers = new HttpHeaders().set("Retry-After",
new DateTimeRfc1123(OffsetDateTime.now().plusSeconds(30)).toString());
Duration actual = RetryPolicy.getWellKnownRetryDelay(headers, 1, null);

// Since DateTime based Retry-After uses OffsetDateTime.now internally make sure this result skew isn't larger
// than an allowable bound.
Duration skew = Duration.ofSeconds(30).minus(actual);
assertTrue(skew.getSeconds() < 2);
OffsetDateTime now = OffsetDateTime.now().withNano(0);
HttpHeaders headers = new HttpHeaders().set("Retry-After", new DateTimeRfc1123(now.plusSeconds(30)).toString());
Duration actual = RetryPolicy.getWellKnownRetryDelay(headers, 1, null, () -> now);

assertEquals(Duration.ofSeconds(30), actual);
}

private static RetryStrategy createStatusCodeRetryStrategy(int... retriableErrorCodes) {
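The retryConsumesBody hunk above only adds a @SuppressWarnings and removes a stray space, but the Flux.generate trick it rests on, counting how many times a response body is actually subscribed to, is worth a standalone sketch (hypothetical class name):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicInteger;

import reactor.core.publisher.Flux;

public class BodyConsumptionSketch {
    public static void main(String[] args) {
        AtomicInteger bodyConsumptionCount = new AtomicInteger();

        // Each subscription re-runs the generator, so the counter records how many times
        // the simulated response body was actually consumed.
        Flux<ByteBuffer> errorBody = Flux.generate(sink -> {
            bodyConsumptionCount.incrementAndGet();
            sink.next(ByteBuffer.wrap("Should be consumed".getBytes(StandardCharsets.UTF_8)));
            sink.complete();
        });

        errorBody.blockLast(); // first consumption
        errorBody.blockLast(); // second consumption, e.g. a retry draining the body again

        System.out.println("consumptions = " + bodyConsumptionCount.get()); // 2
    }
}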
