diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d5518173..4d64375b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ### unreleased * [NEW] Add .NET 5 support +* [NEW] Add `.ThrowsAsync()` that will correctly mock exception on async methods. (#609) ### 4.2.2 (Jun 2020) diff --git a/src/NSubstitute/Extensions/ExceptionExtensions.cs b/src/NSubstitute/Extensions/ExceptionExtensions.cs index 60f3b5339..012337fb9 100644 --- a/src/NSubstitute/Extensions/ExceptionExtensions.cs +++ b/src/NSubstitute/Extensions/ExceptionExtensions.cs @@ -1,4 +1,9 @@ using System; +using System.Linq; +using System.Reflection; +#if !NET45 +using System.Threading.Tasks; +#endif using NSubstitute.Core; // Disable nullability for client API, so it does not affect clients. @@ -67,5 +72,181 @@ public static ConfiguredCall ThrowsForAnyArgs(this object value) /// public static ConfiguredCall ThrowsForAnyArgs(this object value, Func createException) => value.ReturnsForAnyArgs(ci => throw createException(ci)); + +#if !NET45 + /// + /// Throw an exception for this call. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsync(this Task value, Exception ex) => + value.Returns(_ => Task.FromException(ex)); + + /// + /// Throw an exception for this call. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsync(this Task value, Exception ex) => + value.Returns(_ => Task.FromException(ex)); + + /// + /// Throw an exception of the given type for this call. + /// + /// Type of exception to throw + /// + /// + public static ConfiguredCall ThrowsAsync(this Task value) + where TException : notnull, Exception, new() + { + return value.Returns(_ => FromException(value, new TException())); + } + + /// + /// Throw an exception for this call, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsync(this Task value, Func createException) => + value.Returns(ci => Task.FromException(createException(ci))); + + /// + /// Throw an exception for this call, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsync(this Task value, Func createException) => + value.Returns(ci => Task.FromException(createException(ci))); + + /// + /// Throw an exception for this call made with any arguments. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this Task value, Exception ex) => + value.ReturnsForAnyArgs(_ => Task.FromException(ex)); + + /// + /// Throw an exception for this call made with any arguments. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this Task value, Exception ex) => + value.ReturnsForAnyArgs(_ => Task.FromException(ex)); + + /// + /// Throws an exception of the given type for this call made with any arguments. + /// + /// Type of exception to throw + /// + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this Task value) + where TException : notnull, Exception, new() + { + return value.ReturnsForAnyArgs(_ => FromException(value, new TException())); + } + + /// + /// Throws an exception for this call made with any arguments, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this Task value, Func createException) => + value.ReturnsForAnyArgs(ci => Task.FromException(createException(ci))); + + /// + /// Throws an exception for this call made with any arguments, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this Task value, Func createException) => + value.ReturnsForAnyArgs(ci => Task.FromException(createException(ci))); + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP2_1_OR_GREATER + /// + /// Throw an exception for this call. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsync(this ValueTask value, Exception ex) => + value.Returns(_ => ValueTask.FromException(ex)); + + /// + /// Throw an exception of the given type for this call. + /// + /// Type of exception to throw + /// Type of exception to throw + /// + /// + public static ConfiguredCall ThrowsAsync(this ValueTask value) + where TException : notnull, Exception, new() + { + return value.Returns(_ => ValueTask.FromException(new TException())); + } + + /// + /// Throw an exception for this call, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsync(this ValueTask value, Func createException) => + value.Returns(ci => ValueTask.FromException(createException(ci))); + + /// + /// Throws an exception of the given type for this call made with any arguments. + /// + /// + /// Type of exception to throw + /// + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this ValueTask value) + where TException : notnull, Exception, new() + { + return value.ReturnsForAnyArgs(_ => ValueTask.FromException(new TException())); + } + + /// + /// Throw an exception for this call made with any arguments. + /// + /// + /// Exception to throw + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this ValueTask value, Exception ex) => + value.ReturnsForAnyArgs(_ => ValueTask.FromException(ex)); + + /// + /// Throws an exception for this call made with any arguments, as generated by the specified function. + /// + /// + /// Func creating exception object + /// + public static ConfiguredCall ThrowsAsyncForAnyArgs(this ValueTask value, Func createException) => + value.ReturnsForAnyArgs(ci => ValueTask.FromException(createException(ci))); +#endif + + private static object FromException(object value, Exception exception) + { + // Handle Task + var valueType = value.GetType(); + if (valueType.IsConstructedGenericType) + { + var fromExceptionMethodInfo = typeof(Task).GetMethods(BindingFlags.Static | BindingFlags.Public).Single(m => m.Name == "FromException" && m.ContainsGenericParameters); + var specificFromExceptionMethod = fromExceptionMethodInfo.MakeGenericMethod(valueType.GenericTypeArguments); + return specificFromExceptionMethod.Invoke(null, new object[] {exception}); + } + + return Task.FromException(exception); + } +#endif } } diff --git a/tests/NSubstitute.Acceptance.Specs/Infrastructure/ISomething.cs b/tests/NSubstitute.Acceptance.Specs/Infrastructure/ISomething.cs index 91a3ad8ef..94501058b 100644 --- a/tests/NSubstitute.Acceptance.Specs/Infrastructure/ISomething.cs +++ b/tests/NSubstitute.Acceptance.Specs/Infrastructure/ISomething.cs @@ -19,7 +19,9 @@ public interface ISomething object this[string key] { get; set; } System.Threading.Tasks.Task Async(); + System.Threading.Tasks.Task DoAsync(object stuff); System.Threading.Tasks.Task CountAsync(); + System.Threading.Tasks.Task AnythingAsync(object stuff); System.Threading.Tasks.Task EchoAsync(int i); System.Threading.Tasks.Task SayAsync(string s); System.Threading.Tasks.Task SomeActionAsync(); @@ -27,6 +29,7 @@ public interface ISomething System.Threading.Tasks.ValueTask CountValueTaskAsync(); System.Threading.Tasks.ValueTask EchoValueTaskAsync(int i); + System.Threading.Tasks.ValueTask AnythingValueTaskAsync(object stuff); System.Threading.Tasks.ValueTask SayValueTaskAsync(string s); System.Threading.Tasks.ValueTask SomeActionValueTaskAsync(); System.Threading.Tasks.ValueTask SomeActionWithParamsValueTaskAsync(int i, string s); diff --git a/tests/NSubstitute.Acceptance.Specs/ThrowingAsyncExceptions.cs b/tests/NSubstitute.Acceptance.Specs/ThrowingAsyncExceptions.cs new file mode 100644 index 000000000..e097a6915 --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/ThrowingAsyncExceptions.cs @@ -0,0 +1,342 @@ +#if !NET45 +using System; +using System.Linq; +using System.Threading.Tasks; +using NSubstitute.Acceptance.Specs.Infrastructure; +using NSubstitute.ExceptionExtensions; +using NUnit.Framework; + +namespace NSubstitute.Acceptance.Specs +{ + [TestFixture] + public class ThrowingAsyncExceptions + { + public class WithVoidReturn + { + private ISomething _something; + + [Test] + public void ThrowAsyncException() + { + var exception = new Exception(); + _something.Async().ThrowsAsync(exception); + + AssertFaultedTaskException(() => _something.Async()); + } + + [Test] + public void ThrowAsyncExceptionWithDefaultConstructor() + { + _something.Async().ThrowsAsync(); + + AssertFaultedTaskException(() => _something.Async()); + } + + [Test] + public void ThrowExceptionWithMessage() + { + const string exceptionMessage = "This is exception's message"; + + _something.Async().ThrowsAsync(new Exception(exceptionMessage)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.Async()); + Assert.AreEqual(exceptionMessage, exceptionThrown.Message); + } + + [Test] + public void ThrowExceptionWithInnerException() + { + ArgumentException innerException = new ArgumentException(); + _something.Async().ThrowsAsync(new Exception("Exception message", innerException)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.Async()); + + Assert.IsNotNull(exceptionThrown.InnerException); + Assert.IsInstanceOf(exceptionThrown.InnerException); + } + + [Test] + public void ThrowExceptionUsingFactoryFunc() + { + _something.DoAsync("abc").ThrowsAsync(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.DoAsync("abc")); + } + + [Test] + public void DoesNotThrowForNonMatchingArgs() + { + _something.DoAsync(12).ThrowsAsync(new Exception()); + + AssertFaultedTaskException(() => _something.DoAsync(12)); + Assert.DoesNotThrowAsync(() => _something.DoAsync(11)); + } + + [Test] + public void ThrowExceptionForAnyArgs() + { + _something.DoAsync(12).ThrowsAsyncForAnyArgs(new Exception()); + + AssertFaultedTaskException(() => _something.DoAsync(null)); + AssertFaultedTaskException(() => _something.DoAsync(12)); + } + + [Test] + public void ThrowExceptionWithDefaultConstructorForAnyArgs() + { + _something.DoAsync(12).ThrowsAsyncForAnyArgs(); + + AssertFaultedTaskException(() => _something.DoAsync(null)); + } + + [Test] + public void ThrowExceptionCreatedByFactoryFuncForAnyArgs() + { + _something.DoAsync(null).ThrowsAsyncForAnyArgs(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.DoAsync(new object())); + } + + [SetUp] + public void SetUp() + { + _something = Substitute.For(); + } + + [TearDown] + public void TearDown() + { + _something = null; + } + } + + public class WithReturnValue + { + private ISomething _something; + + [Test] + public void ThrowAsyncException() + { + var exception = new Exception(); + _something.CountAsync().ThrowsAsync(exception); + + AssertFaultedTaskException(() => _something.CountAsync()); + } + + + [Test] + public void ThrowAsyncExceptionWithDefaultConstructor() + { + _something.CountAsync().ThrowsAsync(); + + AssertFaultedTaskException(() => _something.CountAsync()); + } + + [Test] + public void ThrowExceptionWithMessage() + { + const string exceptionMessage = "This is exception's message"; + + _something.CountAsync().ThrowsAsync(new Exception(exceptionMessage)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.CountAsync()); + Assert.AreEqual(exceptionMessage, exceptionThrown.Message); + } + + [Test] + public void ThrowExceptionWithInnerException() + { + ArgumentException innerException = new ArgumentException(); + _something.CountAsync().ThrowsAsync(new Exception("Exception message", innerException)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.CountAsync()); + + Assert.IsNotNull(exceptionThrown.InnerException); + Assert.IsInstanceOf(exceptionThrown.InnerException); + } + + [Test] + public void ThrowExceptionUsingFactoryFunc() + { + _something.AnythingAsync("abc").ThrowsAsync(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.AnythingAsync("abc")); + } + + [Test] + public void DoesNotThrowForNonMatchingArgs() + { + _something.AnythingAsync(12).ThrowsAsync(new Exception()); + + AssertFaultedTaskException(() => _something.AnythingAsync(12)); + Assert.DoesNotThrowAsync(() => _something.AnythingAsync(11)); + } + + [Test] + public void ThrowExceptionForAnyArgs() + { + _something.AnythingAsync(12).ThrowsAsyncForAnyArgs(new Exception()); + + AssertFaultedTaskException(() => _something.AnythingAsync(null)); + AssertFaultedTaskException(() => _something.AnythingAsync(12)); + } + + [Test] + public void ThrowExceptionWithDefaultConstructorForAnyArgs() + { + _something.AnythingAsync(12).ThrowsAsyncForAnyArgs(); + + AssertFaultedTaskException(() => _something.AnythingAsync(null)); + } + + [Test] + public void ThrowExceptionCreatedByFactoryFuncForAnyArgs() + { + _something.AnythingAsync(null).ThrowsAsyncForAnyArgs(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.AnythingAsync(new object())); + } + + [SetUp] + public void SetUp() + { + _something = Substitute.For(); + } + + [TearDown] + public void TearDown() + { + _something = null; + } + } + +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP2_1_OR_GREATER + public class ForValueTask + { + + private ISomething _something; + + [Test] + public void ThrowAsyncException() + { + var exception = new Exception(); + _something.CountValueTaskAsync().ThrowsAsync(exception); + + AssertFaultedTaskException(() => _something.CountValueTaskAsync()); + } + + [Test] + public void ThrowAsyncExceptionWithDefaultConstructor() + { + _something.CountValueTaskAsync().ThrowsAsync(); + + AssertFaultedTaskException(() => _something.CountValueTaskAsync()); + } + + [Test] + public void ThrowExceptionWithMessage() + { + const string exceptionMessage = "This is exception's message"; + + _something.CountValueTaskAsync().ThrowsAsync(new Exception(exceptionMessage)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.CountValueTaskAsync()); + Assert.AreEqual(exceptionMessage, exceptionThrown.Message); + } + + [Test] + public void ThrowExceptionWithInnerException() + { + ArgumentException innerException = new ArgumentException(); + _something.CountValueTaskAsync().ThrowsAsync(new Exception("Exception message", innerException)); + + Exception exceptionThrown = AssertFaultedTaskException(() => _something.CountValueTaskAsync()); + + Assert.IsNotNull(exceptionThrown.InnerException); + Assert.IsInstanceOf(exceptionThrown.InnerException); + } + + [Test] + public void ThrowExceptionUsingFactoryFunc() + { + _something.AnythingValueTaskAsync("abc").ThrowsAsync(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync("abc")); + } + + [Test] + public void DoesNotThrowForNonMatchingArgs() + { + _something.AnythingValueTaskAsync(12).ThrowsAsync(new Exception()); + + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync(12)); + AssertDoesNotThrow(() => _something.AnythingValueTaskAsync(11)); + } + + [Test] + public void ThrowExceptionForAnyArgs() + { + _something.AnythingValueTaskAsync(12).ThrowsAsyncForAnyArgs(new Exception()); + + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync(null)); + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync(12)); + } + + [Test] + public void ThrowExceptionWithDefaultConstructorForAnyArgs() + { + _something.AnythingValueTaskAsync(12).ThrowsAsyncForAnyArgs(); + + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync(null)); + } + + [Test] + public void ThrowExceptionCreatedByFactoryFuncForAnyArgs() + { + _something.AnythingValueTaskAsync(null).ThrowsAsyncForAnyArgs(ci => new ArgumentException("Args:" + ci.Args()[0])); + + AssertFaultedTaskException(() => _something.AnythingValueTaskAsync(new object())); + } + + [SetUp] + public void SetUp() + { + _something = Substitute.For(); + } + + [TearDown] + public void TearDown() + { + _something = null; + } + + public static TException AssertFaultedTaskException(Func> act) + where TException : Exception + { + var actual = act(); + + Assert.That(actual.IsFaulted, Is.True); + return Assert.CatchAsync(async () => await actual); + } + + public static void AssertDoesNotThrow(Func> act) + { + var actual = act(); + + Assert.That(actual.IsFaulted, Is.False); + } + } +#endif + + public static TException AssertFaultedTaskException(Func act) + where TException : Exception + { + var actual = act(); + + Assert.That(actual.Status, Is.EqualTo(TaskStatus.Faulted)); + Assert.That(actual.Exception, Is.TypeOf()); + return actual.Exception!.InnerExceptions.First() as TException; + } + } +} +#endif \ No newline at end of file