diff --git a/README.md b/README.md index e1d2935e..715edbc1 100644 --- a/README.md +++ b/README.md @@ -569,6 +569,31 @@ app.UseEndpoints(endpoints => await app.RunAsync(); ``` +In order to ensure that all requests trigger CORS preflight requests, by default the server +will reject requests that do not meet one of the following criteria: + +- The request is a POST request that includes a Content-Type header that is not + `application/x-www-form-urlencoded`, `multipart/form-data`, or `text/plain`. +- The request includes a non-empty `GraphQL-Require-Preflight` header. + +To disable this behavior, set the `CsrfProtectionEnabled` option to `false` in the `GraphQLServerOptions`. + +```csharp +app.UseGraphQL("/graphql", config => +{ + config.CsrfProtectionEnabled = false; +}); +``` + +You may also change the allowed headers by modifying the `CsrfProtectionHeaders` option. + +```csharp +app.UseGraphQL("/graphql", config => +{ + config.CsrfProtectionHeaders = ["MyCustomHeader"]; +}); +``` + ### Response compression ASP.NET Core supports response compression independently of GraphQL, with brotli and gzip @@ -660,6 +685,8 @@ methods allowing for different options for each configured endpoint. | `AuthorizationRequired` | Requires `HttpContext.User` to represent an authenticated user. | False | | `AuthorizedPolicy` | If set, requires `HttpContext.User` to pass authorization of the specified policy. | | | `AuthorizedRoles` | If set, requires `HttpContext.User` to be a member of any one of a list of roles. | | +| `CsrfProtectionEnabled` | Enables cross-site request forgery (CSRF) protection for both GET and POST requests. | True | +| `CsrfProtectionHeaders` | Sets the headers used for CSRF protection when necessary. | `GraphQL-Require-Preflight` | | `DefaultResponseContentType` | Sets the default response content type used within responses. | `application/graphql-response+json; charset=utf-8` | | `EnableBatchedRequests` | Enables handling of batched GraphQL requests for POST requests when formatted as JSON. | True | | `ExecuteBatchedRequestsInParallel` | Enables parallel execution of batched GraphQL requests. | True | diff --git a/docs/migration/migration8.md b/docs/migration/migration8.md index 8e2b3667..21100f94 100644 --- a/docs/migration/migration8.md +++ b/docs/migration/migration8.md @@ -4,6 +4,8 @@ - When using `FormFileGraphType` with type-first schemas, you may specify the allowed media types for the file by using the new `[MediaType]` attribute on the argument or input object field. +- Cross-site request forgery (CSRF) protection has been added for both GET and POST requests, + enabled by default. ## Breaking changes @@ -11,6 +13,11 @@ GraphQL.NET library. Please see the GraphQL.NET v8 migration document for more information. - The obsolete (v6 and prior) authorization validation rule has been removed. See the v7 migration document for more information on how to migrate to the v7/v8 authorization validation rule. +- Cross-site request forgery (CSRF) protection has been enabled for all requests by default. + This will require that the `GraphQL-Require-Preflight` header be sent with all GET requests and + all form-POST requests. To disable this feature, set the `CsrfProtectionEnabled` property on the + `GraphQLMiddlewareOptions` class to `false`. You may also configure the headers list by modifying + the `CsrfProtectionHeaders` property on the same class. See the readme for more details. ## Other changes diff --git a/src/Transports.AspNetCore/Errors/CsrfProtectionError.cs b/src/Transports.AspNetCore/Errors/CsrfProtectionError.cs new file mode 100644 index 00000000..980c1249 --- /dev/null +++ b/src/Transports.AspNetCore/Errors/CsrfProtectionError.cs @@ -0,0 +1,16 @@ +namespace GraphQL.Server.Transports.AspNetCore.Errors; + +/// +/// Represents an error indicating that the request may not have triggered a CORS preflight request. +/// +public class CsrfProtectionError : RequestError +{ + /// + public CsrfProtectionError(IEnumerable headersRequired) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}.") { } + + /// + public CsrfProtectionError(IEnumerable headersRequired, Exception innerException) : base($"This request requires a non-empty header from the following list: {FormatHeaders(headersRequired)}. {innerException.Message}") { } + + private static string FormatHeaders(IEnumerable headersRequired) + => string.Join(", ", headersRequired.Select(x => $"'{x}'")); +} diff --git a/src/Transports.AspNetCore/GraphQLHttpMiddleware.cs b/src/Transports.AspNetCore/GraphQLHttpMiddleware.cs index adf2a4cf..a8ed6d37 100644 --- a/src/Transports.AspNetCore/GraphQLHttpMiddleware.cs +++ b/src/Transports.AspNetCore/GraphQLHttpMiddleware.cs @@ -142,6 +142,10 @@ protected virtual async Task InvokeAsync(HttpContext context, RequestDelegate ne return; } + // Perform CSRF protection if necessary + if (await HandleCsrfProtectionAsync(context, next)) + return; + // Authenticate request if necessary if (await HandleAuthorizeAsync(context, next)) return; @@ -484,7 +488,36 @@ static void ApplyFileToRequest(IFormFile file, string target, GraphQLRequest? re } /// - /// Perform authentication, if required, and return if the + /// Performs CSRF protection, if required, and returns if the + /// request was handled (typically by returning an error message). If + /// is returned, the request is processed normally. + /// + protected virtual async ValueTask HandleCsrfProtectionAsync(HttpContext context, RequestDelegate next) + { + if (!_options.CsrfProtectionEnabled) + return false; + if (context.Request.Headers.TryGetValue("Content-Type", out var contentTypes) && contentTypes.Count > 0 && contentTypes[0] != null) + { + var contentType = contentTypes[0]!; + if (contentType.IndexOf(';') > 0) + { + contentType = contentType.Substring(0, contentType.IndexOf(';')); + } + contentType = contentType.Trim().ToLowerInvariant(); + if (!(contentType == "text/plain" || contentType == "application/x-www-form-urlencoded" || contentType == "multipart/form-data")) + return false; + } + foreach (var header in _options.CsrfProtectionHeaders) + { + if (context.Request.Headers.TryGetValue(header, out var values) && values.Count > 0 && values[0]?.Length > 0) + return false; + } + await HandleCsrfProtectionErrorAsync(context, next); + return true; + } + + /// + /// Perform authentication, if required, and returns if the /// request was handled (typically by returning an error message). If /// is returned, the request is processed normally. /// @@ -1034,21 +1067,29 @@ protected virtual Task HandleNotAuthorizedPolicyAsync(HttpContext context, Reque /// protected virtual async ValueTask HandleDeserializationErrorAsync(HttpContext context, RequestDelegate next, Exception exception) { - await WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new JsonInvalidError(exception)); + await WriteErrorResponseAsync(context, new JsonInvalidError(exception)); return true; } + /// + /// Writes a '.' message to the output. + /// + protected virtual async Task HandleCsrfProtectionErrorAsync(HttpContext context, RequestDelegate next) + { + await WriteErrorResponseAsync(context, new CsrfProtectionError(_options.CsrfProtectionHeaders)); + } + /// /// Writes a '400 Batched requests are not supported.' message to the output. /// protected virtual Task HandleBatchedRequestsNotSupportedAsync(HttpContext context, RequestDelegate next) - => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new BatchedRequestsNotSupportedError()); + => WriteErrorResponseAsync(context, new BatchedRequestsNotSupportedError()); /// /// Writes a '400 Invalid requested WebSocket sub-protocol(s).' message to the output. /// protected virtual Task HandleWebSocketSubProtocolNotSupportedAsync(HttpContext context, RequestDelegate next) - => WriteErrorResponseAsync(context, HttpStatusCode.BadRequest, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols)); + => WriteErrorResponseAsync(context, new WebSocketSubProtocolNotSupportedError(context.WebSockets.WebSocketRequestedProtocols)); /// /// Writes a '415 Invalid Content-Type header: could not be parsed.' message to the output. @@ -1079,6 +1120,12 @@ protected virtual Task HandleInvalidHttpMethodErrorAsync(HttpContext context, Re return next(context); } + /// + /// Writes the specified error as a JSON-formatted GraphQL response. + /// + protected virtual Task WriteErrorResponseAsync(HttpContext context, ExecutionError executionError) + => WriteErrorResponseAsync(context, executionError is IHasPreferredStatusCode withCode ? withCode.PreferredStatusCode : HttpStatusCode.BadRequest, executionError); + /// /// Writes the specified error message as a JSON-formatted GraphQL response, with the specified HTTP status code. /// diff --git a/src/Transports.AspNetCore/GraphQLHttpMiddlewareOptions.cs b/src/Transports.AspNetCore/GraphQLHttpMiddlewareOptions.cs index 1bb01339..6936568d 100644 --- a/src/Transports.AspNetCore/GraphQLHttpMiddlewareOptions.cs +++ b/src/Transports.AspNetCore/GraphQLHttpMiddlewareOptions.cs @@ -70,6 +70,22 @@ public class GraphQLHttpMiddlewareOptions : IAuthorizationOptions /// public bool ReadFormOnPost { get; set; } = true; // TODO: change to false for v9 + /// + /// Enables cross-site request forgery (CSRF) protection for both GET and POST requests. + /// Requires a non-empty header from the list to be + /// present, or a POST request with a Content-Type header that is not text/plain, + /// application/x-www-form-urlencoded, or multipart/form-data. + /// + public bool CsrfProtectionEnabled { get; set; } = true; + + /// + /// When is enabled, requests require a non-empty + /// header from this list or a POST request with a Content-Type header that is not + /// text/plain, application/x-www-form-urlencoded, or multipart/form-data. + /// Defaults to GraphQL-Require-Preflight. + /// + public List CsrfProtectionHeaders { get; set; } = ["GraphQL-Require-Preflight"]; // see https://github.com/graphql/graphql-over-http/pull/303 + /// /// Enables reading variables from the query string. /// Variables are interpreted as JSON and deserialized before being diff --git a/tests/ApiApprovalTests/net50+net60+net80/GraphQL.Server.Transports.AspNetCore.approved.txt b/tests/ApiApprovalTests/net50+net60+net80/GraphQL.Server.Transports.AspNetCore.approved.txt index d4eda1ba..d72fac9e 100644 --- a/tests/ApiApprovalTests/net50+net60+net80/GraphQL.Server.Transports.AspNetCore.approved.txt +++ b/tests/ApiApprovalTests/net50+net60+net80/GraphQL.Server.Transports.AspNetCore.approved.txt @@ -99,6 +99,8 @@ namespace GraphQL.Server.Transports.AspNetCore protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList gqlRequests) { } protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.ValueTask HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.ValueTask HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { } protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } @@ -115,6 +117,7 @@ namespace GraphQL.Server.Transports.AspNetCore "BatchRequest"})] protected virtual System.Threading.Tasks.Task?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { } protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { } + protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { } protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { } @@ -126,6 +129,8 @@ namespace GraphQL.Server.Transports.AspNetCore public bool AuthorizationRequired { get; set; } public string? AuthorizedPolicy { get; set; } public System.Collections.Generic.List AuthorizedRoles { get; set; } + public bool CsrfProtectionEnabled { get; set; } + public System.Collections.Generic.List CsrfProtectionHeaders { get; set; } public Microsoft.Net.Http.Headers.MediaTypeHeaderValue DefaultResponseContentType { get; set; } public bool EnableBatchedRequests { get; set; } public bool ExecuteBatchedRequestsInParallel { get; set; } @@ -199,6 +204,11 @@ namespace GraphQL.Server.Transports.AspNetCore.Errors { public BatchedRequestsNotSupportedError() { } } + public class CsrfProtectionError : GraphQL.Execution.RequestError + { + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired) { } + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired, System.Exception innerException) { } + } public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.Server.Transports.AspNetCore.Errors.IHasPreferredStatusCode { public FileCountExceededError() { } diff --git a/tests/ApiApprovalTests/netcoreapp21+netstandard20/GraphQL.Server.Transports.AspNetCore.approved.txt b/tests/ApiApprovalTests/netcoreapp21+netstandard20/GraphQL.Server.Transports.AspNetCore.approved.txt index ffa1da76..3de02297 100644 --- a/tests/ApiApprovalTests/netcoreapp21+netstandard20/GraphQL.Server.Transports.AspNetCore.approved.txt +++ b/tests/ApiApprovalTests/netcoreapp21+netstandard20/GraphQL.Server.Transports.AspNetCore.approved.txt @@ -106,6 +106,8 @@ namespace GraphQL.Server.Transports.AspNetCore protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList gqlRequests) { } protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.ValueTask HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.ValueTask HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { } protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } @@ -122,6 +124,7 @@ namespace GraphQL.Server.Transports.AspNetCore "BatchRequest"})] protected virtual System.Threading.Tasks.Task?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { } protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { } + protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { } protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { } @@ -133,6 +136,8 @@ namespace GraphQL.Server.Transports.AspNetCore public bool AuthorizationRequired { get; set; } public string? AuthorizedPolicy { get; set; } public System.Collections.Generic.List AuthorizedRoles { get; set; } + public bool CsrfProtectionEnabled { get; set; } + public System.Collections.Generic.List CsrfProtectionHeaders { get; set; } public Microsoft.Net.Http.Headers.MediaTypeHeaderValue DefaultResponseContentType { get; set; } public bool EnableBatchedRequests { get; set; } public bool ExecuteBatchedRequestsInParallel { get; set; } @@ -217,6 +222,11 @@ namespace GraphQL.Server.Transports.AspNetCore.Errors { public BatchedRequestsNotSupportedError() { } } + public class CsrfProtectionError : GraphQL.Execution.RequestError + { + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired) { } + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired, System.Exception innerException) { } + } public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.Server.Transports.AspNetCore.Errors.IHasPreferredStatusCode { public FileCountExceededError() { } diff --git a/tests/ApiApprovalTests/netcoreapp31/GraphQL.Server.Transports.AspNetCore.approved.txt b/tests/ApiApprovalTests/netcoreapp31/GraphQL.Server.Transports.AspNetCore.approved.txt index 70369b81..2122e599 100644 --- a/tests/ApiApprovalTests/netcoreapp31/GraphQL.Server.Transports.AspNetCore.approved.txt +++ b/tests/ApiApprovalTests/netcoreapp31/GraphQL.Server.Transports.AspNetCore.approved.txt @@ -99,6 +99,8 @@ namespace GraphQL.Server.Transports.AspNetCore protected virtual System.Threading.Tasks.Task HandleBatchRequestAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Collections.Generic.IList gqlRequests) { } protected virtual System.Threading.Tasks.Task HandleBatchedRequestsNotSupportedAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleContentTypeCouldNotBeParsedErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.ValueTask HandleCsrfProtectionAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } + protected virtual System.Threading.Tasks.Task HandleCsrfProtectionErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.ValueTask HandleDeserializationErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, System.Exception exception) { } protected virtual System.Threading.Tasks.Task HandleInvalidContentTypeErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } protected virtual System.Threading.Tasks.Task HandleInvalidHttpMethodErrorAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next) { } @@ -115,6 +117,7 @@ namespace GraphQL.Server.Transports.AspNetCore "BatchRequest"})] protected virtual System.Threading.Tasks.Task?>?> ReadPostContentAsync(Microsoft.AspNetCore.Http.HttpContext context, Microsoft.AspNetCore.Http.RequestDelegate next, string? mediaType, System.Text.Encoding? sourceEncoding) { } protected virtual string SelectResponseContentType(Microsoft.AspNetCore.Http.HttpContext context) { } + protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, GraphQL.ExecutionError executionError) { } protected virtual System.Threading.Tasks.Task WriteErrorResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, string errorMessage) { } protected virtual System.Threading.Tasks.Task WriteJsonResponseAsync(Microsoft.AspNetCore.Http.HttpContext context, System.Net.HttpStatusCode httpStatusCode, TResult result) { } @@ -126,6 +129,8 @@ namespace GraphQL.Server.Transports.AspNetCore public bool AuthorizationRequired { get; set; } public string? AuthorizedPolicy { get; set; } public System.Collections.Generic.List AuthorizedRoles { get; set; } + public bool CsrfProtectionEnabled { get; set; } + public System.Collections.Generic.List CsrfProtectionHeaders { get; set; } public Microsoft.Net.Http.Headers.MediaTypeHeaderValue DefaultResponseContentType { get; set; } public bool EnableBatchedRequests { get; set; } public bool ExecuteBatchedRequestsInParallel { get; set; } @@ -199,6 +204,11 @@ namespace GraphQL.Server.Transports.AspNetCore.Errors { public BatchedRequestsNotSupportedError() { } } + public class CsrfProtectionError : GraphQL.Execution.RequestError + { + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired) { } + public CsrfProtectionError(System.Collections.Generic.IEnumerable headersRequired, System.Exception innerException) { } + } public class FileCountExceededError : GraphQL.Execution.RequestError, GraphQL.Server.Transports.AspNetCore.Errors.IHasPreferredStatusCode { public FileCountExceededError() { } diff --git a/tests/Samples.AzureFunctions.Tests/EndToEndTests.cs b/tests/Samples.AzureFunctions.Tests/EndToEndTests.cs index 714baa5b..dbeab21c 100644 --- a/tests/Samples.AzureFunctions.Tests/EndToEndTests.cs +++ b/tests/Samples.AzureFunctions.Tests/EndToEndTests.cs @@ -17,6 +17,7 @@ public async Task GraphQLGet() { request.Method = "GET"; request.QueryString = new QueryString("?query={count}"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); }, GraphQL.RunGraphQL); statusCode.ShouldBe(200); diff --git a/tests/Samples.Complex.Tests/BaseTest.cs b/tests/Samples.Complex.Tests/BaseTest.cs index 19471189..3c7c2df8 100644 --- a/tests/Samples.Complex.Tests/BaseTest.cs +++ b/tests/Samples.Complex.Tests/BaseTest.cs @@ -36,6 +36,10 @@ protected Task SendRequestAsync(HttpMethod httpMethod, Http { Content = httpContent }; + if (httpMethod == HttpMethod.Get) + { + request.Headers.Add("GraphQL-Require-Preflight", "true"); + } return Client.SendAsync(request); } @@ -87,7 +91,11 @@ protected async Task SendRequestAsync(GraphQLRequest request, RequestTyp case RequestType.Get: // Details passed in query string urlWithParams = url + "?" + await Serializer.ToQueryStringParamsAsync(request); - response = await Client.GetAsync(urlWithParams); + using (var getRequest = new HttpRequestMessage(HttpMethod.Get, urlWithParams)) + { + getRequest.Headers.Add("GraphQL-Require-Preflight", "true"); + response = await Client.SendAsync(getRequest); + } break; case RequestType.PostWithJson: @@ -113,7 +121,12 @@ protected async Task SendRequestAsync(GraphQLRequest request, RequestTyp case RequestType.PostWithForm: // Details passed in form body as form url encoded, with url query string params also allowed var formContent = Serializer.ToFormUrlEncodedContent(request); - response = await Client.PostAsync(url, formContent); + using (var formPostRequest = new HttpRequestMessage(HttpMethod.Post, url)) + { + formPostRequest.Content = formContent; + formPostRequest.Headers.Add("GraphQL-Require-Preflight", "true"); + response = await Client.SendAsync(formPostRequest); + } break; default: diff --git a/tests/Samples.Tests/TestServerExtensions.cs b/tests/Samples.Tests/TestServerExtensions.cs index 9ea4a9a1..0369ed20 100644 --- a/tests/Samples.Tests/TestServerExtensions.cs +++ b/tests/Samples.Tests/TestServerExtensions.cs @@ -36,6 +36,7 @@ public static async Task VerifyGraphQLGetAsync( { using var client = server.CreateClient(); using var request = new HttpRequestMessage(HttpMethod.Get, url + "?query=" + Uri.EscapeDataString(query)); + request.Headers.Add("GraphQL-Require-Preflight", "true"); if (jwtToken != null) request.Headers.Authorization = new("Bearer", jwtToken); using var response = await client.SendAsync(request); diff --git a/tests/Samples.Upload.Tests/EndToEndTests.cs b/tests/Samples.Upload.Tests/EndToEndTests.cs index b37a41b6..dd8436c3 100644 --- a/tests/Samples.Upload.Tests/EndToEndTests.cs +++ b/tests/Samples.Upload.Tests/EndToEndTests.cs @@ -33,6 +33,7 @@ public async Task RotateImage() form.Add(triangleContent, "file0", "triangle.jpg"); using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql"); request.Content = form; + request.Headers.Add("GraphQL-Require-Preflight", "true"); using var response = await client.SendAsync(request); response.StatusCode.ShouldBe(HttpStatusCode.OK); var ret = await response.Content.ReadAsStringAsync(); @@ -66,6 +67,7 @@ public async Task RotateImage_WrongType() form.Add(triangleContent, "file0", "hello-world.txt"); using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql"); request.Content = form; + request.Headers.Add("GraphQL-Require-Preflight", "true"); using var response = await client.SendAsync(request); response.StatusCode.ShouldBe(HttpStatusCode.BadRequest); var ret = await response.Content.ReadAsStringAsync(); diff --git a/tests/Transports.AspNetCore.Tests/AuthorizationTests.cs b/tests/Transports.AspNetCore.Tests/AuthorizationTests.cs index bc83b03f..f364b81e 100644 --- a/tests/Transports.AspNetCore.Tests/AuthorizationTests.cs +++ b/tests/Transports.AspNetCore.Tests/AuthorizationTests.cs @@ -759,7 +759,7 @@ public async Task EndToEnd(bool authenticated) context.User = _principal; return next(context); }); - app.UseGraphQL(); + app.UseGraphQL(configureMiddleware: c => c.CsrfProtectionEnabled = false); }); using var server = new TestServer(hostBuilder); diff --git a/tests/Transports.AspNetCore.Tests/BuilderMethodTests.cs b/tests/Transports.AspNetCore.Tests/BuilderMethodTests.cs index e26f703a..bced3a2d 100644 --- a/tests/Transports.AspNetCore.Tests/BuilderMethodTests.cs +++ b/tests/Transports.AspNetCore.Tests/BuilderMethodTests.cs @@ -373,7 +373,7 @@ private async Task VerifyUserContextAsync(string value) _hostBuilder.Configure(app => { app.UseWebSockets(); - app.UseGraphQL(); + app.UseGraphQL(configureMiddleware: c => c.CsrfProtectionEnabled = false); }); using var server = new TestServer(_hostBuilder); var str = await server.ExecuteGet("/graphql?query={userInfo}"); diff --git a/tests/Transports.AspNetCore.Tests/Middleware/AuthorizationTests.cs b/tests/Transports.AspNetCore.Tests/Middleware/AuthorizationTests.cs index a042a02f..07a018f0 100644 --- a/tests/Transports.AspNetCore.Tests/Middleware/AuthorizationTests.cs +++ b/tests/Transports.AspNetCore.Tests/Middleware/AuthorizationTests.cs @@ -107,7 +107,9 @@ public async Task NotAuthorized_Get() { _options.AuthorizationRequired = true; var client = _server.CreateClient(); - using var response = await client.GetAsync("/graphql?query={ __typename }"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={ __typename }"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); response.StatusCode.ShouldBe(HttpStatusCode.Unauthorized); var actual = await response.Content.ReadAsStringAsync(); actual.ShouldBe("""{"errors":[{"message":"Access denied for schema.","extensions":{"code":"ACCESS_DENIED","codes":["ACCESS_DENIED"]}}]}"""); diff --git a/tests/Transports.AspNetCore.Tests/Middleware/Cors/EndpointTests.cs b/tests/Transports.AspNetCore.Tests/Middleware/Cors/EndpointTests.cs index 951bb452..c2fb00e5 100644 --- a/tests/Transports.AspNetCore.Tests/Middleware/Cors/EndpointTests.cs +++ b/tests/Transports.AspNetCore.Tests/Middleware/Cors/EndpointTests.cs @@ -90,7 +90,7 @@ public async Task NoCorsConfig(string httpMethod, string url) httpMethod == "POST" ? HttpMethod.Post : httpMethod == "OPTIONS" ? HttpMethod.Options : httpMethod == "GET" ? HttpMethod.Get : throw new ArgumentOutOfRangeException(nameof(httpMethod)), configureCors: _ => { }, configureCorsPolicy: _ => { }, - configureGraphQl: _ => { }, + configureGraphQl: o => o.CsrfProtectionEnabled = false, configureGraphQlEndpoint: _ => { }, configureHeaders: headers => { diff --git a/tests/Transports.AspNetCore.Tests/Middleware/FileUploadTests.cs b/tests/Transports.AspNetCore.Tests/Middleware/FileUploadTests.cs index 7da724d9..8bffb0d9 100644 --- a/tests/Transports.AspNetCore.Tests/Middleware/FileUploadTests.cs +++ b/tests/Transports.AspNetCore.Tests/Middleware/FileUploadTests.cs @@ -50,7 +50,10 @@ public async Task Basic(bool withOtherVariables) var fileContent = new ByteArrayContent(fileData); fileContent.Headers.ContentType = new("application/octet-stream"); content.Add(fileContent, "file", "filename.bin"); - using var response = await client.PostAsync("/graphql", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql"); + request.Content = content; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); if (withOtherVariables) { await response.ShouldBeAsync("""{"data":{"convertToBase64":"pre-filename.bin-YWJjZA=="}}"""); diff --git a/tests/Transports.AspNetCore.Tests/Middleware/GetTests.cs b/tests/Transports.AspNetCore.Tests/Middleware/GetTests.cs index 0a5265b0..751ad347 100644 --- a/tests/Transports.AspNetCore.Tests/Middleware/GetTests.cs +++ b/tests/Transports.AspNetCore.Tests/Middleware/GetTests.cs @@ -31,10 +31,12 @@ public GetTests() app.UseWebSockets(); app.UseGraphQL("/graphql", opts => { + opts.CsrfProtectionEnabled = false; _options = opts; }); app.UseGraphQL("/graphql2", opts => { + opts.CsrfProtectionEnabled = false; _options2 = opts; }); }); @@ -65,6 +67,48 @@ public async Task BasicTest() await response.ShouldBeAsync("""{"data":{"count":0}}"""); } + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task CsrfBasicTests(bool requireCsrf, bool sendCsrf) + { + _options.CsrfProtectionEnabled = requireCsrf; + var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + if (sendCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (requireCsrf && !sendCsrf) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"count":0}}"""); + } + + [Theory] + [InlineData(null, null, false)] + [InlineData("Header1", "true", true)] + [InlineData("Header1", "", false)] + [InlineData("Header1", null, false)] + [InlineData("Header2", "true", true)] + [InlineData("Header3", "true", false)] + [InlineData("GraphQL-Require-Preflight", "true", false)] + public async Task CsrfCustomTests(string? header, string? value, bool success) + { + _options.CsrfProtectionEnabled = true; + _options.CsrfProtectionHeaders = ["Header1", "Header2"]; + var client = _server.CreateClient(); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + if (header != null) + request.Headers.Add(header, value); + using var response = await client.SendAsync(request); + if (!success) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027Header1\u0027, \u0027Header2\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"count":0}}"""); + } + [Theory] [InlineData(null, "application/graphql+json", "application/graphql+json; charset=utf-8")] [InlineData(null, "application/json", "application/json; charset=utf-8")] @@ -180,7 +224,9 @@ public async Task NoUseWebSockets() using var server = new TestServer(hostBuilder); var client = server.CreateClient(); - using var response = await client.GetAsync("/graphql?query={count}"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={count}"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync("""{"data":{"count":0}}"""); } diff --git a/tests/Transports.AspNetCore.Tests/Middleware/PostTests.cs b/tests/Transports.AspNetCore.Tests/Middleware/PostTests.cs index fe777f75..fd16248b 100644 --- a/tests/Transports.AspNetCore.Tests/Middleware/PostTests.cs +++ b/tests/Transports.AspNetCore.Tests/Middleware/PostTests.cs @@ -147,9 +147,15 @@ public async Task AltCharset_Invalid() } #endif - [Fact] - public async Task FormMultipart_Legacy() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormMultipart_Legacy(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); var content = new MultipartFormDataContent(); var queryContent = new StringContent("query op1{ext} query op2($test:String!){ext var(test:$test)}"); @@ -164,13 +170,25 @@ public async Task FormMultipart_Legacy() content.Add(variablesContent, "variables"); content.Add(extensionsContent, "extensions"); content.Add(operationNameContent, "operationName"); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (!requireCsrf || supplyCsrf) + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + else + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); } - [Fact] - public async Task FormMultipart_Upload() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormMultipart_Upload(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); using var content = new MultipartFormDataContent(); var jsonContent = new StringContent(""" @@ -182,8 +200,14 @@ public async Task FormMultipart_Upload() } """, Encoding.UTF8, "application/json"); content.Add(jsonContent, "operations"); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (!requireCsrf || supplyCsrf) + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + else + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); } // successful queries @@ -345,7 +369,9 @@ public async Task FormMultipart_Upload_Matrix(int testIndex, string? operations, content.Add(new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt"); if (file1) content.Add(new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html"); - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync((HttpStatusCode)expectedStatusCode, expectedResponse); } @@ -366,13 +392,21 @@ public async Task FormMultipart_Upload_Validation(int? maxFileCount, int? maxFil { new StringContent("test1", Encoding.UTF8, "text/text"), "file0", "example1.txt" }, { new StringContent("test2", Encoding.UTF8, "text/html"), "file1", "example2.html" } }; - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); await response.ShouldBeAsync(expectedStatusCode, expectedResponse); } - [Fact] - public async Task FormUrlEncoded() + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task FormUrlEncoded(bool requireCsrf, bool supplyCsrf) { + if (!requireCsrf) + _options2.CsrfProtectionEnabled = false; var client = _server.CreateClient(); var content = new FormUrlEncodedContent(new[] { new KeyValuePair("query", "query op1{ext} query op2($test:String!){ext var(test:$test)}"), @@ -380,8 +414,14 @@ public async Task FormUrlEncoded() new KeyValuePair("extensions", """{"test":"2"}"""), new KeyValuePair("operationName", "op2"), }); - using var response = await client.PostAsync("/graphql2", content); - await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + if (supplyCsrf) + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); + if (requireCsrf && !supplyCsrf) + await response.ShouldBeAsync(true, """{"errors":[{"message":"This request requires a non-empty header from the following list: \u0027GraphQL-Require-Preflight\u0027.","extensions":{"code":"CSRF_PROTECTION","codes":["CSRF_PROTECTION"]}}]}"""); + else + await response.ShouldBeAsync("""{"data":{"ext":"2","var":"1"}}"""); } [Theory] @@ -395,7 +435,9 @@ public async Task FormUrlEncoded_DeserializationError(bool badRequest) new KeyValuePair("query", "{ext}"), new KeyValuePair("variables", "{"), }); - using var response = await client.PostAsync("/graphql2", content); + using var request = new HttpRequestMessage(HttpMethod.Post, "/graphql2") { Content = content }; + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); // always returns BadRequest here await response.ShouldBeAsync(true, """{"errors":[{"message":"JSON body text could not be parsed. Expected depth to be zero at the end of the JSON payload. There is an open JSON object or array that should be closed. Path: $ | LineNumber: 0 | BytePositionInLine: 1.","extensions":{"code":"JSON_INVALID","codes":["JSON_INVALID"]}}]}"""); } @@ -435,6 +477,7 @@ public async Task ContentType_GraphQLJson(string contentType) [InlineData(true, false, "application/x-www-form-urlencoded")] public async Task UnknownContentType(bool badRequest, bool allowFormBody, string contentType) { + _options.CsrfProtectionEnabled = false; _options.ValidationErrorsReturnBadRequest = badRequest; _options.ReadFormOnPost = allowFormBody; var client = _server.CreateClient(); @@ -467,6 +510,7 @@ public async Task CannotParseContentType(bool badRequest) var client = _server.CreateClient(); var content = new StringContent(""); content.Headers.ContentType = null; + content.Headers.Add("GraphQL-Require-Preflight", "true"); var response = await client.PostAsync("/graphql2", content); // always returns unsupported media type response.StatusCode.ShouldBe(HttpStatusCode.UnsupportedMediaType); diff --git a/tests/Transports.AspNetCore.Tests/TestServerExtensions.cs b/tests/Transports.AspNetCore.Tests/TestServerExtensions.cs index 07b78f00..fb936b43 100644 --- a/tests/Transports.AspNetCore.Tests/TestServerExtensions.cs +++ b/tests/Transports.AspNetCore.Tests/TestServerExtensions.cs @@ -5,7 +5,9 @@ internal static class TestServerExtensions public static async Task ExecuteGet(this TestServer server, string url) { var client = server.CreateClient(); - using var response = await client.GetAsync(url); + using var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await client.SendAsync(request); response.EnsureSuccessStatusCode(); var contentType = response.Content.Headers.ContentType; contentType.ShouldNotBeNull(); diff --git a/tests/Transports.AspNetCore.Tests/UserContextBuilderTests.cs b/tests/Transports.AspNetCore.Tests/UserContextBuilderTests.cs index bd910081..4059b42a 100644 --- a/tests/Transports.AspNetCore.Tests/UserContextBuilderTests.cs +++ b/tests/Transports.AspNetCore.Tests/UserContextBuilderTests.cs @@ -107,7 +107,9 @@ public async Task Async_Payload_Works() private async Task Test(string name) { - using var response = await _client.GetAsync("/graphql?query={test}"); + using var request = new HttpRequestMessage(HttpMethod.Get, "/graphql?query={test}"); + request.Headers.Add("GraphQL-Require-Preflight", "true"); + using var response = await _client.SendAsync(request); response.EnsureSuccessStatusCode(); var actual = await response.Content.ReadAsStringAsync(); actual.ShouldBe(@"{""data"":{""test"":""" + name + @"""}}");