Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.Shared.DiagnosticIds;
using Microsoft.Shared.Diagnostics;
using Polly;
using Polly.Retry;

namespace Microsoft.Extensions.Http.Resilience;

Expand Down Expand Up @@ -52,15 +53,16 @@ public static void DisableFor(this HttpRetryStrategyOptions options, params Http
{
var result = await shouldHandle(args).ConfigureAwait(args.Context.ContinueOnCapturedContext);

if (result &&
args.Outcome.Result is HttpResponseMessage response &&
response.RequestMessage is HttpRequestMessage request)
if (result && GetRequestMessage(args) is HttpRequestMessage request)
{
return !methods.Contains(request.Method);
}

return result;
};
}

private static HttpRequestMessage? GetRequestMessage(RetryPredicateArguments<HttpResponseMessage> args) =>
args.Outcome.Result?.RequestMessage ?? args.Context.GetRequestMessage();
}

Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,16 @@ public async Task DisableFor_RespectsOriginalShouldHandlePredicate()
}

[Fact]
public async Task DisableFor_ResponseMessageIsNull_DoesNotDisableRetries()
public async Task DisableFor_ResponseMessageIsNull_RetrievesRequestMessageFromContext()
{
var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.True() };
options.DisableFor(HttpMethod.Post);

Assert.True(await options.ShouldHandle(CreatePredicateArguments(null)));
using var request = new HttpRequestMessage { Method = HttpMethod.Post };
var context = ResilienceContextPool.Shared.Get();
context.SetRequestMessage(request);

Assert.False(await options.ShouldHandle(CreatePredicateArguments(null, context)));
}

[Fact]
Expand All @@ -80,8 +84,10 @@ public async Task DisableFor_RequestMessageIsNull_DoesNotDisableRetries()
options.DisableFor(HttpMethod.Post);

using var response = new HttpResponseMessage { RequestMessage = null };
var context = ResilienceContextPool.Shared.Get();
context.SetRequestMessage(null);

Assert.True(await options.ShouldHandle(CreatePredicateArguments(response)));
Assert.True(await options.ShouldHandle(CreatePredicateArguments(response, context)));
}

[Theory]
Expand All @@ -105,10 +111,10 @@ public async Task DisableForUnsafeHttpMethods_PositiveScenario(string httpMethod
Assert.Equal(shouldHandle, await options.ShouldHandle(CreatePredicateArguments(response)));
}

private static RetryPredicateArguments<HttpResponseMessage> CreatePredicateArguments(HttpResponseMessage? response)
private static RetryPredicateArguments<HttpResponseMessage> CreatePredicateArguments(HttpResponseMessage? response, ResilienceContext? context = null)
{
return new RetryPredicateArguments<HttpResponseMessage>(
ResilienceContextPool.Shared.Get(),
context ?? ResilienceContextPool.Shared.Get(),
Outcome.FromResult(response),
attemptNumber: 1);
}
Expand Down
Loading