Skip to content

Commit

Permalink
SmtpClient SendMailAsync with CancellationToken API implementation (#287
Browse files Browse the repository at this point in the history
)

* Add SmtpClient.SendMailAsync overloads with cancellation

* Rework mock SmtpClient tests

* Verify that SmtpClient uses Auth if available

* Add tests for SmtpClient.SendMailAsync using CancellationTokens

* Revert to case-insensitive comparison of hostnames in SmtpClient tests

* Disable SmtpClient NTLM test on Unix

* Address PR feedback

* Address PR feedback

* Address PR feedback

Use Interlocked.Exchange instead of locks
  • Loading branch information
MihaZupan authored Nov 27, 2019
1 parent d0974b8 commit 40b11db
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 252 deletions.
2 changes: 2 additions & 0 deletions src/libraries/System.Net.Mail/ref/System.Net.Mail.cs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ public void SendAsync(string from, string recipients, string subject, string bod
public void SendAsyncCancel() { }
public System.Threading.Tasks.Task SendMailAsync(System.Net.Mail.MailMessage message) { throw null; }
public System.Threading.Tasks.Task SendMailAsync(string from, string recipients, string subject, string body) { throw null; }
public System.Threading.Tasks.Task SendMailAsync(System.Net.Mail.MailMessage message, System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.Task SendMailAsync(string from, string recipients, string subject, string body, System.Threading.CancellationToken cancellationToken) { throw null; }
}
public enum SmtpDeliveryFormat
{
Expand Down
69 changes: 54 additions & 15 deletions src/libraries/System.Net.Mail/src/System/Net/Mail/SmtpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -782,17 +782,59 @@ public void SendAsyncCancel()
public Task SendMailAsync(string from, string recipients, string subject, string body)
{
var message = new MailMessage(from, recipients, subject, body);
return SendMailAsync(message);
return SendMailAsync(message, cancellationToken: default);
}

public Task SendMailAsync(MailMessage message)
{
return SendMailAsync(message, cancellationToken: default);
}

public Task SendMailAsync(string from, string recipients, string subject, string body, CancellationToken cancellationToken)
{
var message = new MailMessage(from, recipients, subject, body);
return SendMailAsync(message, cancellationToken);
}

public Task SendMailAsync(MailMessage message, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}

// Create a TaskCompletionSource to represent the operation
var tcs = new TaskCompletionSource<object>();

CancellationTokenRegistration ctr = default;

// Indicates whether the CTR has been set - captured in handler
int state = 0;

// Register a handler that will transfer completion results to the TCS Task
SendCompletedEventHandler handler = null;
handler = (sender, e) => HandleCompletion(tcs, e, handler);
handler = (sender, e) =>
{
if (e.UserState == tcs)
{
try
{
((SmtpClient)sender).SendCompleted -= handler;
if (Interlocked.Exchange(ref state, 1) != 0)
{
// A CTR has been set, we have to wait until it completes before completing the task
ctr.Dispose();
}
}
catch (ObjectDisposedException) { } // SendAsyncCancel will throw if SmtpClient was disposed
finally
{
if (e.Error != null) tcs.TrySetException(e.Error);
else if (e.Cancelled) tcs.TrySetCanceled();
else tcs.TrySetResult(null);
}
}
};
SendCompleted += handler;

// Start the async operation.
Expand All @@ -806,22 +848,19 @@ public Task SendMailAsync(MailMessage message)
throw;
}

// Return the task to represent the asynchronous operation
return tcs.Task;
}
ctr = cancellationToken.Register(s =>
{
((SmtpClient)s).SendAsyncCancel();
}, this);

private void HandleCompletion(TaskCompletionSource<object> tcs, AsyncCompletedEventArgs e, SendCompletedEventHandler handler)
{
if (e.UserState == tcs)
if (Interlocked.Exchange(ref state, 1) != 0)
{
try { SendCompleted -= handler; }
finally
{
if (e.Error != null) tcs.TrySetException(e.Error);
else if (e.Cancelled) tcs.TrySetCanceled();
else tcs.TrySetResult(null);
}
// SendCompleted was already invoked, ensure the CTR completes before returning the task
ctr.Dispose();
}

// Return the task to represent the asynchronous operation
return tcs.Task;
}


Expand Down
265 changes: 265 additions & 0 deletions src/libraries/System.Net.Mail/tests/Functional/LoopbackSmtpServer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,265 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net;
using System.Net.Mail;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Systen.Net.Mail.Tests
{
public class LoopbackSmtpServer : IDisposable
{
private static readonly ReadOnlyMemory<byte> s_messageTerminator = new byte[] { (byte)'\r', (byte)'\n' };
private static readonly ReadOnlyMemory<byte> s_bodyTerminator = new byte[] { (byte)'\r', (byte)'\n', (byte)'.', (byte)'\r', (byte)'\n' };

public bool ReceiveMultipleConnections = false;
public bool SupportSmtpUTF8 = false;
public bool AdvertiseNtlmAuthSupport = false;

private bool _disposed = false;
private readonly Socket _listenSocket;
private readonly ConcurrentBag<Socket> _socketsToDispose;
private long _messageCounter = new Random().Next(1000, 2000);

public readonly int Port;
public SmtpClient CreateClient() => new SmtpClient("localhost", Port);

public Action<Socket> OnConnected;
public Action<string> OnHelloReceived;
public Action<string, string> OnCommandReceived;
public Action<string> OnUnknownCommand;
public Action<Socket> OnQuitReceived;

public string ClientDomain { get; private set; }
public string MailFrom { get; private set; }
public string MailTo { get; private set; }
public string UsernamePassword { get; private set; }
public string Username { get; private set; }
public string Password { get; private set; }
public string AuthMethodUsed { get; private set; }
public ParsedMailMessage Message { get; private set; }

public int ConnectionCount { get; private set; }
public int MessagesReceived { get; private set; }

public LoopbackSmtpServer()
{
_socketsToDispose = new ConcurrentBag<Socket>();
_listenSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
_socketsToDispose.Add(_listenSocket);

_listenSocket.Bind(new IPEndPoint(IPAddress.Any, 0));
Port = ((IPEndPoint)_listenSocket.LocalEndPoint).Port;
_listenSocket.Listen(1);

_ = Task.Run(async () =>
{
do
{
var socket = await _listenSocket.AcceptAsync();
_socketsToDispose.Add(socket);
ConnectionCount++;
_ = Task.Run(async () => await HandleConnectionAsync(socket));
}
while (ReceiveMultipleConnections);
});
}

private async Task HandleConnectionAsync(Socket socket)
{
var buffer = new byte[1024].AsMemory();

async ValueTask<string> ReceiveMessageAsync(bool isBody = false)
{
var terminator = isBody ? s_bodyTerminator : s_messageTerminator;
int suffix = terminator.Length;

int received = 0;
do
{
int read = await socket.ReceiveAsync(buffer.Slice(received), SocketFlags.None);
if (read == 0) return null;
received += read;
}
while (received < suffix || !buffer.Slice(received - suffix, suffix).Span.SequenceEqual(terminator.Span));

MessagesReceived++;
return Encoding.UTF8.GetString(buffer.Span.Slice(0, received - suffix));
}
async ValueTask SendMessageAsync(string text)
{
var bytes = buffer.Slice(0, Encoding.UTF8.GetBytes(text, buffer.Span) + 2);
bytes.Span[^2] = (byte)'\r';
bytes.Span[^1] = (byte)'\n';
await socket.SendAsync(bytes, SocketFlags.None);
}

try
{
OnConnected?.Invoke(socket);
await SendMessageAsync("220 localhost");

string message = await ReceiveMessageAsync();
Debug.Assert(message.ToLower().StartsWith("helo ") || message.ToLower().StartsWith("ehlo "));
ClientDomain = message.Substring(5).ToLower();
OnCommandReceived?.Invoke(message.Substring(0, 4), ClientDomain);
OnHelloReceived?.Invoke(ClientDomain);

await SendMessageAsync("250-localhost, mock server here");
if (SupportSmtpUTF8) await SendMessageAsync("250-SMTPUTF8");
await SendMessageAsync("250 AUTH PLAIN LOGIN" + (AdvertiseNtlmAuthSupport ? " NTLM" : ""));

while ((message = await ReceiveMessageAsync()) != null)
{
int colonIndex = message.IndexOf(':');
string command = colonIndex == -1 ? message : message.Substring(0, colonIndex);
string argument = command.Length == message.Length ? string.Empty : message.Substring(colonIndex + 1).Trim();

OnCommandReceived?.Invoke(command, argument);

if (command.StartsWith("AUTH", StringComparison.OrdinalIgnoreCase))
{
var parts = command.Split(' ');
Debug.Assert(parts.Length > 1, "Expected an actual auth request");

AuthMethodUsed = parts[1];

if (parts[1].Equals("LOGIN", StringComparison.OrdinalIgnoreCase))
{
if (parts.Length == 2)
{
await SendMessageAsync("334 VXNlcm5hbWU6");
Username = Encoding.UTF8.GetString(Convert.FromBase64String(await ReceiveMessageAsync()));
}
else
{
Username = Encoding.UTF8.GetString(Convert.FromBase64String(parts[2]));
}
await SendMessageAsync("334 UGFzc3dvcmQ6");
Password = Encoding.UTF8.GetString(Convert.FromBase64String(await ReceiveMessageAsync()));
UsernamePassword = Username + Password;
await SendMessageAsync("235 Authentication successful");
}
else if (parts[1].Equals("NTLM", StringComparison.OrdinalIgnoreCase))
{
await SendMessageAsync("12345 I lied, I can't speak NTLM - here's an invalid response");
}
else await SendMessageAsync("504 scheme not supported");
continue;
}

switch (command.ToUpper())
{
case "MAIL FROM":
MailFrom = argument;
await SendMessageAsync("250 Ok");
break;

case "RCPT TO":
MailTo = argument;
await SendMessageAsync("250 Ok");
break;

case "DATA":
await SendMessageAsync("354 Start mail input; end with <CRLF>.<CRLF>");
string data = await ReceiveMessageAsync(true);
Message = ParsedMailMessage.Parse(data);
await SendMessageAsync("250 Ok: queued as " + Interlocked.Increment(ref _messageCounter));
break;

case "QUIT":
OnQuitReceived?.Invoke(socket);
await SendMessageAsync("221 Bye");
return;

default:
OnUnknownCommand?.Invoke(message);
await SendMessageAsync("500 Idk that command");
break;
}
}
}
catch { }
finally
{
try
{
socket.Shutdown(SocketShutdown.Both);
}
finally
{
socket?.Close();
}
}
}

public void Dispose()
{
if (!_disposed)
{
_disposed = true;
foreach (var socket in _socketsToDispose)
{
try
{
socket.Close();
}
catch { }
}
_socketsToDispose.Clear();
}
}


public class ParsedMailMessage
{
public readonly IReadOnlyDictionary<string, string> Headers;
public readonly string Body;

private string GetHeader(string name) => Headers.TryGetValue(name, out string value) ? value : "NOT-PRESENT";
public string From => GetHeader("From");
public string To => GetHeader("To");
public string Subject => GetHeader("Subject");

private ParsedMailMessage(Dictionary<string, string> headers, string body)
{
Headers = headers;
Body = body;
}

public static ParsedMailMessage Parse(string data)
{
Dictionary<string, string> headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);

ReadOnlySpan<char> dataSpan = data;
string body = null;

while (!dataSpan.IsEmpty)
{
int endOfLine = dataSpan.IndexOf('\n');
Debug.Assert(endOfLine != -1, "Expected valid \r\n terminated lines");
var line = dataSpan.Slice(0, endOfLine).TrimEnd('\r');

if (line.IsEmpty)
{
body = dataSpan.Slice(endOfLine + 1).TrimEnd(stackalloc char[] { '\r', '\n' }).ToString();
break;
}
else
{
int colon = line.IndexOf(':');
Debug.Assert(colon != -1, "Expected a valid header");
headers.Add(line.Slice(0, colon).Trim().ToString(), line.Slice(colon + 1).Trim().ToString());
dataSpan = dataSpan.Slice(endOfLine + 1);
}
}

return new ParsedMailMessage(headers, body);
}
}
}
}
Loading

0 comments on commit 40b11db

Please sign in to comment.