diff --git a/Refit.Tests/AuthenticatedClientHandlerTests.cs b/Refit.Tests/AuthenticatedClientHandlerTests.cs index 7dce85749..98f02379e 100644 --- a/Refit.Tests/AuthenticatedClientHandlerTests.cs +++ b/Refit.Tests/AuthenticatedClientHandlerTests.cs @@ -4,8 +4,11 @@ using System.Net.Http; using System.Text; using System.Threading.Tasks; + using Refit; // for the code gen + using RichardSzalay.MockHttp; + using Xunit; namespace Refit.Tests; @@ -49,6 +52,12 @@ public interface IInheritedAuthenticatedServiceWithHeaders : IAuthenticatedServi Task GetInheritedThing(); } + public interface IInheritedAuthenticatedServiceWithHeadersCRLF : IAuthenticatedServiceWithHeaders + { + [Get("/get-inherited-thing\r\n\r\nGET /smuggled")] + Task GetInheritedThing(); + } + [Headers("Authorization: Bearer")] public interface IAuthenticatedServiceWithHeaders { @@ -347,4 +356,30 @@ public async void AuthentictedMethodFromInheritedClassWithHeadersAttributeUsesAu Assert.Equal("Ok", result); } + + [Fact] + public async void AuthentictedMethodFromInheritedClassWithHeadersAttributeUsesAuth_WithCRLFCheck() + { + var handler = new MockHttpMessageHandler(); + var settings = new RefitSettings() + { + AuthorizationHeaderValueGetter = (_, __) => Task.FromResult("tokenValue"), + HttpMessageHandlerFactory = () => handler, + }; + + handler + .Expect(HttpMethod.Get, "http://api/get-inherited-thing") + .WithHeaders("Authorization", "Bearer tokenValue") + .Respond("text/plain", "Ok"); + + await Assert.ThrowsAsync(async () => + { + var fixture = RestService.For( + "http://api", + settings + ); + + var result = await fixture.GetInheritedThing(); + }); + } } diff --git a/Refit/RequestBuilderImplementation.cs b/Refit/RequestBuilderImplementation.cs index 807cafb79..2882d3ed6 100644 --- a/Refit/RequestBuilderImplementation.cs +++ b/Refit/RequestBuilderImplementation.cs @@ -977,7 +977,7 @@ static void AddHeadersToRequest(Dictionary? headersToAdd, HttpR // sure we have an HttpContent object to add them to, // provided the HttpClient will allow it for the method if (ret.Content == null && !IsBodyless(ret.Method)) - ret.Content = new ByteArrayContent(Array.Empty()); + ret.Content = new ByteArrayContent([]); foreach (var header in headersToAdd) { @@ -1335,6 +1335,10 @@ static void SetHeader(HttpRequestMessage request, string name, string? value) if (value == null) return; + // CRLF injection protection + name = EnsureSafe(name); + value = EnsureSafe(value); + var added = request.Headers.TryAddWithoutValidation(name, value); // Don't even bother trying to add the header as a content header @@ -1345,6 +1349,14 @@ static void SetHeader(HttpRequestMessage request, string name, string? value) } } + static string EnsureSafe(string value) + { + // Remove CR and LF characters +#pragma warning disable CA1307 // Specify StringComparison for clarity + return value.Replace("\r", string.Empty).Replace("\n", string.Empty); +#pragma warning restore CA1307 // Specify StringComparison for clarity + } + static bool IsBodyless(HttpMethod method) => method == HttpMethod.Get || method == HttpMethod.Head; } } diff --git a/Refit/RestMethodInfo.cs b/Refit/RestMethodInfo.cs index a22dd5262..901dfc338 100644 --- a/Refit/RestMethodInfo.cs +++ b/Refit/RestMethodInfo.cs @@ -259,6 +259,12 @@ static void VerifyUrlPathIsSane(string relativePath) throw new ArgumentException( $"URL path {relativePath} must start with '/' and be of the form '/foo/bar/baz'" ); + + // CRLF injection protection + if (relativePath.Contains("\r") || relativePath.Contains("\n")) + throw new ArgumentException( + $"URL path {relativePath} must not contain CR or LF characters" + ); } static Dictionary BuildParameterMap(