diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 2ecc2a3932abff..30280e347a68f4 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -188,94 +188,46 @@ private SecurityStatusPal EncryptData(ReadOnlyMemory buffer, ref byte[] ou // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private Task? ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) + private Task ProcessAuthenticationAsync(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) { ThrowIfExceptional(); if (NetSecurityTelemetry.Log.IsEnabled()) { - return ProcessAuthenticationWithTelemetry(isAsync, isApm, cancellationToken); + return ProcessAuthenticationWithTelemetryAsync(isAsync, isApm, cancellationToken); } else { - if (isAsync) - { - return ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm); - } - else - { - ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null).GetAwaiter().GetResult(); - return null; - } + return isAsync ? + ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm) : + ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null); } } - private Task? ProcessAuthenticationWithTelemetry(bool isAsync, bool isApm, CancellationToken cancellationToken) + private async Task ProcessAuthenticationWithTelemetryAsync(bool isAsync, bool isApm, CancellationToken cancellationToken) { NetSecurityTelemetry.Log.HandshakeStart(_context!.IsServer, _sslAuthenticationOptions!.TargetHost); ValueStopwatch stopwatch = ValueStopwatch.StartNew(); try { - if (isAsync) - { - Task task = ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context.IsServer, null, isApm); + Task task = isAsync? + ForceAuthenticationAsync(new AsyncReadWriteAdapter(InnerStream, cancellationToken), _context!.IsServer, null, isApm) : + ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context!.IsServer, null); - return task.ContinueWith((t, s) => - { - var tuple = ((SslStream, ValueStopwatch))s!; - SslStream thisRef = tuple.Item1; - ValueStopwatch stopwatch = tuple.Item2; + await task.ConfigureAwait(false); - if (t.IsCompletedSuccessfully) - { - LogSuccess(thisRef, stopwatch); - } - else - { - LogFailure(thisRef._context!.IsServer, stopwatch, t.Exception?.Message ?? "Operation canceled."); + // SslStream could already have been disposed at this point, in which case _connectionOpenedStatus == 2 + // Make sure that we increment the open connection counter only if it is guaranteed to be decremented in dispose/finalize + bool connectionOpen = Interlocked.CompareExchange(ref _connectionOpenedStatus, 1, 0) == 0; - // Throw the same exception we would if not using Telemetry - t.GetAwaiter().GetResult(); - } - }, - state: (this, stopwatch), - cancellationToken: default, - TaskContinuationOptions.ExecuteSynchronously, - TaskScheduler.Current); - } - else - { - ForceAuthenticationAsync(new SyncReadWriteAdapter(InnerStream), _context.IsServer, null).GetAwaiter().GetResult(); - LogSuccess(this, stopwatch); - return null; - } + NetSecurityTelemetry.Log.HandshakeCompleted(GetSslProtocolInternal(), stopwatch, connectionOpen); } - catch (Exception ex) when (LogFailure(_context.IsServer, stopwatch, ex.Message)) + catch (Exception ex) { - Debug.Fail("LogFailure should return false"); + NetSecurityTelemetry.Log.HandshakeFailed(_context.IsServer, stopwatch, ex.Message); throw; } - - static bool LogFailure(bool isServer, ValueStopwatch stopwatch, string exceptionMessage) - { - NetSecurityTelemetry.Log.HandshakeFailed(isServer, stopwatch, exceptionMessage); - return false; - } - - static void LogSuccess(SslStream thisRef, ValueStopwatch stopwatch) - { - // SslStream could already have been disposed at this point, in which case _connectionOpenedStatus == 2 - // Make sure that we increment the open connection counter only if it is guaranteed to be decremented in dispose/finalize - - // Using a field of a marshal-by-reference class as a ref or out value or taking its address may cause a runtime exception - // Justification: thisRef is a reference to 'this', not a proxy object -#pragma warning disable CS0197 - bool connectionOpen = Interlocked.CompareExchange(ref thisRef._connectionOpenedStatus, 1, 0) == 0; -#pragma warning restore CS0197 - - NetSecurityTelemetry.Log.HandshakeCompleted(thisRef.GetSslProtocolInternal(), stopwatch, connectionOpen); - } } // diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 75c2938d444af3..252f6d2ea094fb 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -293,7 +293,7 @@ public void AuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthent SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - ProcessAuthentication(); + ProcessAuthenticationAsync().GetAwaiter().GetResult(); } public virtual void AuthenticateAsServer(X509Certificate serverCertificate) @@ -330,7 +330,7 @@ public void AuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthent SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - ProcessAuthentication(); + ProcessAuthenticationAsync().GetAwaiter().GetResult(); } #endregion @@ -365,7 +365,7 @@ public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAu ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - return ProcessAuthentication(true, false, cancellationToken)!; + return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); } private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) @@ -375,7 +375,7 @@ private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAut ValidateCreateContext(sslClientAuthenticationOptions, _userCertificateValidationCallback, _certSelectionDelegate); - return ProcessAuthentication(true, true, cancellationToken)!; + return ProcessAuthenticationAsync(isAsync: true, isApm: true, cancellationToken); } public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) => @@ -418,7 +418,7 @@ public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAu SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - return ProcessAuthentication(true, false, cancellationToken)!; + return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); } private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) @@ -426,13 +426,13 @@ private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAut SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - return ProcessAuthentication(true, true, cancellationToken)!; + return ProcessAuthenticationAsync(isAsync: true, isApm: true, cancellationToken); } public Task AuthenticateAsServerAsync(ServerOptionsSelectionCallback optionsCallback, object? state, CancellationToken cancellationToken = default) { ValidateCreateContext(new SslAuthenticationOptions(optionsCallback, state, _userCertificateValidationCallback)); - return ProcessAuthentication(isAsync: true, isApm: false, cancellationToken)!; + return ProcessAuthenticationAsync(isAsync: true, isApm: false, cancellationToken); } public virtual Task ShutdownAsync() diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TelemetryTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TelemetryTest.cs index 72b71d04637815..eb8fa3d482a914 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/TelemetryTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TelemetryTest.cs @@ -33,28 +33,44 @@ public static void EventSource_SuccessfulHandshake_LogsStartStop() RemoteExecutor.Invoke(async () => { using var listener = new TestEventListener("System.Net.Security", EventLevel.Verbose, eventCounterInterval: 0.1d); + listener.AddActivityTracking(); - var events = new ConcurrentQueue(); - await listener.RunWithCallbackAsync(events.Enqueue, async () => + var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>(); + await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () => { // Invoke tests that'll cause some events to be generated var test = new SslStreamStreamToStreamTest_Async(); await test.SslStream_StreamToStream_Authentication_Success(); - await Task.Delay(300); + await WaitForEventCountersAsync(events); }); - Assert.DoesNotContain(events, ev => ev.EventId == 0); // errors from the EventSource itself + Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself - EventWrittenEventArgs[] starts = events.Where(e => e.EventName == "HandshakeStart").ToArray(); + (EventWrittenEventArgs Event, Guid ActivityId)[] starts = events.Where(e => e.Event.EventName == "HandshakeStart").ToArray(); Assert.Equal(2, starts.Length); - Assert.All(starts, s => Assert.Equal(2, s.Payload.Count)); - Assert.Single(starts, s => s.Payload[0] is bool isServer && isServer); - Assert.Single(starts, s => s.Payload[1] is string targetHost && targetHost.Length == 0); + Assert.All(starts, s => Assert.Equal(2, s.Event.Payload.Count)); + Assert.All(starts, s => Assert.NotEqual(Guid.Empty, s.ActivityId)); - EventWrittenEventArgs[] stops = events.Where(e => e.EventName == "HandshakeStop").ToArray(); + // isServer + (EventWrittenEventArgs Event, Guid ActivityId) serverStart = Assert.Single(starts, s => (bool)s.Event.Payload[0]); + (EventWrittenEventArgs Event, Guid ActivityId) clientStart = Assert.Single(starts, s => !(bool)s.Event.Payload[0]); + + // targetHost + Assert.Empty(Assert.IsType(serverStart.Event.Payload[1])); + Assert.NotEmpty(Assert.IsType(clientStart.Event.Payload[1])); + + Assert.NotEqual(serverStart.ActivityId, clientStart.ActivityId); + + (EventWrittenEventArgs Event, Guid ActivityId)[] stops = events.Where(e => e.Event.EventName == "HandshakeStop").ToArray(); Assert.Equal(2, stops.Length); - Assert.All(stops, s => ValidateHandshakeStopEventPayload(s, failure: false)); - Assert.DoesNotContain(events, e => e.EventName == "HandshakeFailed"); + EventWrittenEventArgs serverStop = Assert.Single(stops, s => s.ActivityId == serverStart.ActivityId).Event; + EventWrittenEventArgs clientStop = Assert.Single(stops, s => s.ActivityId == clientStart.ActivityId).Event; + + SslProtocols serverProtocol = ValidateHandshakeStopEventPayload(serverStop); + SslProtocols clientProtocol = ValidateHandshakeStopEventPayload(clientStop); + Assert.Equal(serverProtocol, clientProtocol); + + Assert.DoesNotContain(events, e => e.Event.EventName == "HandshakeFailed"); VerifyEventCounters(events, shouldHaveFailures: false); }).Dispose(); @@ -67,38 +83,57 @@ public static void EventSource_UnsuccessfulHandshake_LogsStartFailureStop() RemoteExecutor.Invoke(async () => { using var listener = new TestEventListener("System.Net.Security", EventLevel.Verbose, eventCounterInterval: 0.1d); + listener.AddActivityTracking(); - var events = new ConcurrentQueue(); - await listener.RunWithCallbackAsync(events.Enqueue, async () => + var events = new ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)>(); + await listener.RunWithCallbackAsync(e => events.Enqueue((e, e.ActivityId)), async () => { // Invoke tests that'll cause some events to be generated var test = new SslStreamStreamToStreamTest_Async(); await test.SslStream_ServerLocalCertificateSelectionCallbackReturnsNull_Throw(); - await Task.Delay(300); + await WaitForEventCountersAsync(events); }); - Assert.DoesNotContain(events, ev => ev.EventId == 0); // errors from the EventSource itself + Assert.DoesNotContain(events, ev => ev.Event.EventId == 0); // errors from the EventSource itself - EventWrittenEventArgs[] starts = events.Where(e => e.EventName == "HandshakeStart").ToArray(); + (EventWrittenEventArgs Event, Guid ActivityId)[] starts = events.Where(e => e.Event.EventName == "HandshakeStart").ToArray(); Assert.Equal(2, starts.Length); - Assert.All(starts, s => Assert.Equal(2, s.Payload.Count)); - Assert.Single(starts, s => s.Payload[0] is bool isServer && isServer); - Assert.Single(starts, s => s.Payload[1] is string targetHost && targetHost.Length == 0); + Assert.All(starts, s => Assert.Equal(2, s.Event.Payload.Count)); + Assert.All(starts, s => Assert.NotEqual(Guid.Empty, s.ActivityId)); - EventWrittenEventArgs[] failures = events.Where(e => e.EventName == "HandshakeFailed").ToArray(); - Assert.Equal(2, failures.Length); - Assert.All(failures, f => Assert.Equal(3, f.Payload.Count)); - Assert.Single(failures, f => f.Payload[0] is bool isServer && isServer); - Assert.All(failures, f => Assert.NotEmpty(f.Payload[2] as string)); // exceptionMessage + // isServer + (EventWrittenEventArgs Event, Guid ActivityId) serverStart = Assert.Single(starts, s => (bool)s.Event.Payload[0]); + (EventWrittenEventArgs Event, Guid ActivityId) clientStart = Assert.Single(starts, s => !(bool)s.Event.Payload[0]); + + // targetHost + Assert.Empty(Assert.IsType(serverStart.Event.Payload[1])); + Assert.NotEmpty(Assert.IsType(clientStart.Event.Payload[1])); - EventWrittenEventArgs[] stops = events.Where(e => e.EventName == "HandshakeStop").ToArray(); + Assert.NotEqual(serverStart.ActivityId, clientStart.ActivityId); + + (EventWrittenEventArgs Event, Guid ActivityId)[] stops = events.Where(e => e.Event.EventName == "HandshakeStop").ToArray(); Assert.Equal(2, stops.Length); - Assert.All(stops, s => ValidateHandshakeStopEventPayload(s, failure: true)); + Assert.All(stops, s => ValidateHandshakeStopEventPayload(s.Event, failure: true)); + + EventWrittenEventArgs serverStop = Assert.Single(stops, s => s.ActivityId == serverStart.ActivityId).Event; + EventWrittenEventArgs clientStop = Assert.Single(stops, s => s.ActivityId == clientStart.ActivityId).Event; + + (EventWrittenEventArgs Event, Guid ActivityId)[] failures = events.Where(e => e.Event.EventName == "HandshakeFailed").ToArray(); + Assert.Equal(2, failures.Length); + Assert.All(failures, f => Assert.Equal(3, f.Event.Payload.Count)); + Assert.All(failures, f => Assert.NotEmpty(f.Event.Payload[2] as string)); // exceptionMessage + + EventWrittenEventArgs serverFailure = Assert.Single(failures, f => f.ActivityId == serverStart.ActivityId).Event; + EventWrittenEventArgs clientFailure = Assert.Single(failures, f => f.ActivityId == clientStart.ActivityId).Event; + + // isServer + Assert.Equal(true, serverFailure.Payload[0]); + Assert.Equal(false, clientFailure.Payload[0]); VerifyEventCounters(events, shouldHaveFailures: true); }).Dispose(); } - private static void ValidateHandshakeStopEventPayload(EventWrittenEventArgs stopEvent, bool failure) + private static SslProtocols ValidateHandshakeStopEventPayload(EventWrittenEventArgs stopEvent, bool failure = false) { Assert.Equal("HandshakeStop", stopEvent.EventName); Assert.Equal(1, stopEvent.Payload.Count); @@ -114,11 +149,14 @@ private static void ValidateHandshakeStopEventPayload(EventWrittenEventArgs stop { Assert.NotEqual(SslProtocols.None, protocol); } + + return protocol; } - private static void VerifyEventCounters(ConcurrentQueue events, bool shouldHaveFailures) + private static void VerifyEventCounters(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events, bool shouldHaveFailures) { Dictionary eventCounters = events + .Select(e => e.Event) .Where(e => e.EventName == "EventCounters") .Select(e => (IDictionary)e.Payload.Single()) .GroupBy(d => (string)d["Name"], d => (double)(d.ContainsKey("Mean") ? d["Mean"] : d["Increment"])) @@ -174,5 +212,29 @@ private static void VerifyEventCounters(ConcurrentQueue e Assert.Contains(allHandshakeDurations, d => d > 0); } } + + private static async Task WaitForEventCountersAsync(ConcurrentQueue<(EventWrittenEventArgs Event, Guid ActivityId)> events) + { + DateTime startTime = DateTime.UtcNow; + int startCount = events.Count; + + while (events.Skip(startCount).Count(e => IsTlsHandshakeRateEventCounter(e.Event)) < 3) + { + if (DateTime.UtcNow.Subtract(startTime) > TimeSpan.FromSeconds(30)) + throw new TimeoutException($"Timed out waiting for EventCounters"); + + await Task.Delay(100); + } + + static bool IsTlsHandshakeRateEventCounter(EventWrittenEventArgs e) + { + if (e.EventName != "EventCounters") + return false; + + var dictionary = (IDictionary)e.Payload.Single(); + + return (string)dictionary["Name"] == "tls-handshake-rate"; + } + } } } diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs index 4e146153626b31..d97d1e8bcb67bf 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs @@ -54,7 +54,7 @@ private void CloseInternal() // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) + private Task ProcessAuthenticationAsync(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) { return Task.Run(() => {}); } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TelemetryTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TelemetryTest.cs index 7b507a18f40373..acd9029136eb7c 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TelemetryTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TelemetryTest.cs @@ -357,7 +357,7 @@ private static async Task WaitForEventCountersAsync(ConcurrentQueue<(EventWritte DateTime startTime = DateTime.UtcNow; int startCount = events.Count; - while (events.Skip(startCount).Count(e => IsBytesSentEventCounter(e.Event)) < 2) + while (events.Skip(startCount).Count(e => IsBytesSentEventCounter(e.Event)) < 3) { if (DateTime.UtcNow.Subtract(startTime) > TimeSpan.FromSeconds(30)) throw new TimeoutException($"Timed out waiting for EventCounters");