From d4f1555da7c72f15e97c56d8bac4a3f2eeb64e6b Mon Sep 17 00:00:00 2001 From: Matthew Kelly Date: Fri, 12 Jun 2020 21:11:09 +0100 Subject: [PATCH] #573: support for cert thumbprint, and move the opening of socket stream --- src/Listener/PodeContext.cs | 33 +++++++++---- src/Listener/PodeContextState.cs | 3 +- src/Listener/PodeRequest.cs | 21 ++------- src/Listener/PodeSocket.cs | 79 +++++++++++--------------------- src/Public/Core.ps1 | 57 ++++++++++++++++++----- 5 files changed, 102 insertions(+), 91 deletions(-) diff --git a/src/Listener/PodeContext.cs b/src/Listener/PodeContext.cs index 0e4161cf6..fa5c7841f 100644 --- a/src/Listener/PodeContext.cs +++ b/src/Listener/PodeContext.cs @@ -1,5 +1,6 @@ using System; using System.Collections; +using System.IO; using System.Net.Http; using System.Net.Sockets; using System.Security.Cryptography; @@ -22,7 +23,9 @@ public class PodeContext : IDisposable public bool CloseImmediately { - get => (State == PodeContextState.Error || string.IsNullOrWhiteSpace(Request.HttpMethod)); + get => (State == PodeContextState.Error + || string.IsNullOrWhiteSpace(Request.HttpMethod) + || (IsWebSocket && !Request.HttpMethod.Equals("GET", StringComparison.InvariantCultureIgnoreCase))); } public bool IsWebSocket @@ -30,6 +33,11 @@ public bool IsWebSocket get => (Type == PodeContextType.WebSocket); } + public bool IsErrored + { + get => (State == PodeContextState.Error || State == PodeContextState.SslError); + } + public PodeContext(Socket socket, PodeSocket podeSocket, PodeListener listener) { @@ -43,8 +51,8 @@ public PodeContext(Socket socket, PodeSocket podeSocket, PodeListener listener) Type = PodeContextType.Unknown; State = PodeContextState.New; - NewRequest(); NewResponse(); + NewRequest(); } private void NewResponse() @@ -67,12 +75,10 @@ private void NewRequest() } catch { - State = PodeContextState.Error; + State = (Request.InputStream == default(Stream) + ? PodeContextState.Error + : PodeContextState.SslError); } - - // attempt to receive data from the request stream - Receive(); - SetContextType(); } private void SetContextType() @@ -109,6 +115,8 @@ public void Receive() { State = PodeContextState.Error; } + + SetContextType(); } public void StartReceive() @@ -171,7 +179,16 @@ public void Dispose(bool force) // send the response and close, only close request if not keep alive try { - Response.Send(); + if (IsErrored) + { + Response.StatusCode = 500; + } + + if (State != PodeContextState.SslError) + { + Response.Send(); + } + Response.Dispose(); if (!IsKeepAlive || force) diff --git a/src/Listener/PodeContextState.cs b/src/Listener/PodeContextState.cs index e20b3e10a..5d0cdf91a 100644 --- a/src/Listener/PodeContextState.cs +++ b/src/Listener/PodeContextState.cs @@ -7,6 +7,7 @@ public enum PodeContextState Receiving, Received, Closed, - Error + Error, + SslError } } \ No newline at end of file diff --git a/src/Listener/PodeRequest.cs b/src/Listener/PodeRequest.cs index d02ab3d9d..4812f41e4 100644 --- a/src/Listener/PodeRequest.cs +++ b/src/Listener/PodeRequest.cs @@ -45,6 +45,7 @@ public PodeRequest(Socket socket) { Socket = socket; RemoteEndPoint = socket.RemoteEndPoint; + Protocol = "HTTP/1.1"; } public void Open(X509Certificate certificate, SslProtocols protocols) @@ -53,16 +54,15 @@ public void Open(X509Certificate certificate, SslProtocols protocols) IsSsl = (certificate != default(X509Certificate)); // open the socket's stream - var stream = new NetworkStream(Socket, true); + InputStream = new NetworkStream(Socket, true); if (!IsSsl) { // if not ssl, use the main network stream - InputStream = stream; return; } // otherwise, convert the stream to an ssl stream - var ssl = new SslStream(stream, false, new RemoteCertificateValidationCallback(ValidateCertificateCallback)); + var ssl = new SslStream(InputStream, false, new RemoteCertificateValidationCallback(ValidateCertificateCallback)); ssl.AuthenticateAsServer(certificate, false, protocols, false); InputStream = ssl; } @@ -82,22 +82,7 @@ public void Receive() try { Error = default(HttpRequestException); - var allBytes = new List(); - if (IsSsl) - { - try - { - // the stream gets reset on ssl upgrade - Socket.Receive(new byte[0]); - } - catch - { - var err = new HttpRequestException(); - err.Data.Add("PodeStatusCode", 408); - throw err; - } - } while (Socket.Available > 0) { diff --git a/src/Listener/PodeSocket.cs b/src/Listener/PodeSocket.cs index a022e0095..b19935de0 100644 --- a/src/Listener/PodeSocket.cs +++ b/src/Listener/PodeSocket.cs @@ -89,19 +89,23 @@ public void Start() } } - public void StartReceive(PodeContext context) + private void StartReceive(Socket acceptedSocket) { - var args = GetReceiveConnection(); - args.AcceptSocket = context.Socket; - args.UserToken = context; - StartReceive(args); + var context = new PodeContext(acceptedSocket, this, Listener); + if (context.IsErrored) + { + context.Dispose(true); + return; + } + + StartReceive(context); } - private void StartReceive(Socket acceptedSocket) + public void StartReceive(PodeContext context) { var args = GetReceiveConnection(); - args.AcceptSocket = acceptedSocket; - args.UserToken = this; + args.AcceptSocket = context.Socket; + args.UserToken = context; StartReceive(args); } @@ -170,13 +174,9 @@ private void ProcessReceive(SocketAsyncEventArgs args) { // get details var received = args.AcceptSocket; - var token = args.UserToken; + var context = (PodeContext)args.UserToken; var error = args.SocketError; - var isContext = (token is PodeContext); - var context = (isContext ? (PodeContext)token : default(PodeContext)); - var socket = (isContext ? default(PodeSocket) : (PodeSocket)token); - // remove the socket from pending RemovePendingSocket(received); @@ -190,10 +190,7 @@ private void ProcessReceive(SocketAsyncEventArgs args) } // close the context - if (isContext) - { - context.Dispose(true); - } + context.Dispose(true); // add args back to connections ClearSocketAsyncEvent(args); @@ -206,46 +203,22 @@ private void ProcessReceive(SocketAsyncEventArgs args) // add context to be processed? var process = true; - // deal with existing context - if (isContext) + // deal with context + context.Receive(); + + // if we need to exit now, dispose and exit + if (context.CloseImmediately) { - context.Receive(); - - // if we need to exit now, dispose and exit - if (context.CloseImmediately) - { - PodeHelpers.WriteException(context.Request.Error, Listener); - context.Dispose(true); - process = false; - } + PodeHelpers.WriteException(context.Request.Error, Listener); + context.Dispose(true); + process = false; } - // else, create a new context - else + // if it's a websocket, upgrade it + else if (context.IsWebSocket) { - context = new PodeContext(received, socket, Listener); - - // if we need to exit now, dispose and exit - if (context.CloseImmediately) - { - PodeHelpers.WriteException(context.Request.Error, Listener); - context.Dispose(true); - process = false; - } - - // if websocket, and httpmethod != GET, close! - else if (context.IsWebSocket && !context.Request.HttpMethod.Equals("GET", StringComparison.InvariantCultureIgnoreCase)) - { - context.Dispose(true); - process = false; - } - - // if it's a websocket, upgrade it - else if (context.IsWebSocket) - { - context.UpgradeWebSocket(); - process = false; - } + context.UpgradeWebSocket(); + process = false; } // add the context for processing diff --git a/src/Public/Core.ps1 b/src/Public/Core.ps1 index 8352cde7b..69d9aa042 100644 --- a/src/Public/Core.ps1 +++ b/src/Public/Core.ps1 @@ -622,19 +622,16 @@ The Port number of the endpoint. The protocol of the supplied endpoint. .PARAMETER Certificate -A certificate name to find and bind onto HTTPS endpoints (Windows only). - -.PARAMETER CertificateThumbprint -A certificate thumbprint to bind onto HTTPS endpoints (Windows only). - -.PARAMETER CertificateFile -The path to a certificate that can be use to enable HTTPS (Cross-platform) +The path to a certificate that can be use to enable HTTPS .PARAMETER CertificatePassword -The password for the certificate referenced in CertificateFile (Cross-platform) +The password for the certificate file referenced in Certificate -.PARAMETER RawCertificate -The raw X509 certificate that can be use to enable HTTPS (Cross-platform) +.PARAMETER CertificateThumbprint +A certificate thumbprint to bind onto HTTPS endpoints (Windows). + +.PARAMETER X509Certificate +The raw X509 certificate that can be use to enable HTTPS .PARAMETER Name An optional name for the endpoint, that can be used with other functions. @@ -649,7 +646,7 @@ A quick description of the Endpoint - normally used in OpenAPI. Ignore Adminstrator checks for non-localhost endpoints. .PARAMETER SelfSigned -Create and bind a self-signed certifcate onto HTTPS endpoints (Windows only). +Create and bind a self-signed certifcate for HTTPS endpoints. .EXAMPLE Add-PodeEndpoint -Address localhost -Port 8090 -Protocol Http @@ -688,6 +685,10 @@ function Add-PodeEndpoint [string] $CertificatePassword = $null, + [Parameter(Mandatory=$true, ParameterSetName='CertThumb')] + [string] + $CertificateThumbprint, + [Parameter(Mandatory=$true, ParameterSetName='CertRaw')] [Parameter()] [X509Certificate] @@ -810,6 +811,40 @@ function Add-PodeEndpoint } } + # if we're dealing with a certificate thumbprint, attempt to retrieve it (if windows) + if (!$isIIS -and !$isHeroku -and ($PSCmdlet.ParameterSetName -ieq 'certthumb')) { + # fail if protocol is not https + if (@('https', 'wss') -inotcontains $Protocol) { + throw "Certificate thumbprint supplied for non-HTTPS/WSS endpoint" + } + + # fail if not windows + if (!(Test-IsWindows)) { + throw "Certificate thumbprints are only supported on Windows" + } + + $x509store = [System.Security.Cryptography.X509Certificates.X509Store]::new( + [System.Security.Cryptography.X509Certificates.StoreName]::My, + [System.Security.Cryptography.X509Certificates.StoreLocation]::CurrentUser + ) + + $x509store.Open([System.Security.Cryptography.X509Certificates.OpenFlags]::ReadOnly) + + $x509certs = $x509store.Certificates.Find( + [System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, + $CertificateThumbprint, + $false + ) + + Close-PodeDisposable -Disposable $x509store -Close + + if (($null -eq $x509certs) -or ($x509certs.Count -eq 0)) { + throw "No certificate could be found in CurrentUser\My for Thumbprint: $($CertificateThumbprint)" + } + + $obj.Certificate.Raw = [X509Certificate2]($x509certs[0]) + } + # if we're dealing with a self-signed certificate, create it if (!$isIIS -and !$isHeroku -and $SelfSigned) { # fail if protocol is not https