diff --git a/Dapper/SqlMapper.Async.cs b/Dapper/SqlMapper.Async.cs index a0f84afc2..9408d5735 100644 --- a/Dapper/SqlMapper.Async.cs +++ b/Dapper/SqlMapper.Async.cs @@ -1333,6 +1333,11 @@ static async IAsyncEnumerable Impl(IDbConnection cnn, Type effectiveType, Com { if (reader is not null) { + if (!reader.IsClosed) + { + try { cmd?.Cancel(); } + catch { /* don't spoil any existing exception */ } + } await reader.DisposeAsync(); } if (wasClosed) cnn.Close(); diff --git a/Dapper/SqlMapper.Settings.cs b/Dapper/SqlMapper.Settings.cs index 343906b74..cbbc3c687 100644 --- a/Dapper/SqlMapper.Settings.cs +++ b/Dapper/SqlMapper.Settings.cs @@ -11,9 +11,10 @@ public static partial class SqlMapper /// public static class Settings { - // disable single result by default; prevents errors AFTER the select being detected properly - private const CommandBehavior DefaultAllowedCommandBehaviors = ~CommandBehavior.SingleResult; + // disable single row/result by default; prevents errors AFTER the select being detected properly + private const CommandBehavior DefaultAllowedCommandBehaviors = ~(CommandBehavior.SingleResult | CommandBehavior.SingleRow); internal static CommandBehavior AllowedCommandBehaviors { get; private set; } = DefaultAllowedCommandBehaviors; + private static void SetAllowedCommandBehaviors(CommandBehavior behavior, bool enabled) { if (enabled) AllowedCommandBehaviors |= behavior; diff --git a/Dapper/SqlMapper.cs b/Dapper/SqlMapper.cs index e95d35297..fa5ce52df 100644 --- a/Dapper/SqlMapper.cs +++ b/Dapper/SqlMapper.cs @@ -1148,7 +1148,7 @@ private static GridReader QueryMultipleImpl(this IDbConnection cnn, ref CommandD if (!reader.IsClosed) { try { cmd?.Cancel(); } - catch { /* don't spoil the existing exception */ } + catch { /* don't spoil any existing exception */ } } reader.Dispose(); } @@ -1229,7 +1229,7 @@ private static IEnumerable QueryImpl(this IDbConnection cnn, CommandDefini if (!reader.IsClosed) { try { cmd?.Cancel(); } - catch { /* don't spoil the existing exception */ } + catch { /* don't spoil any existing exception */ } } reader.Dispose(); } @@ -1321,7 +1321,7 @@ private static T QueryRowImpl(IDbConnection cnn, Row row, ref CommandDefiniti if (!reader.IsClosed) { try { cmd?.Cancel(); } - catch { /* don't spoil the existing exception */ } + catch { /* don't spoil any existing exception */ } } reader.Dispose(); } diff --git a/Directory.Packages.props b/Directory.Packages.props index 23d2ece4c..4cc865132 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -1,51 +1,51 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/Dapper.Tests/Dapper.Tests.csproj b/tests/Dapper.Tests/Dapper.Tests.csproj index 5863bc8ed..f370ebbcc 100644 --- a/tests/Dapper.Tests/Dapper.Tests.csproj +++ b/tests/Dapper.Tests/Dapper.Tests.csproj @@ -6,6 +6,7 @@ $(DefineConstants);MSSQLCLIENT $(NoWarn);IDE0017;IDE0034;IDE0037;IDE0039;IDE0042;IDE0044;IDE0051;IDE0052;IDE0059;IDE0060;IDE0063;IDE1006;xUnit1004;CA1806;CA1816;CA1822;CA1825;CA2208;CA1861 enable + true @@ -16,6 +17,7 @@ + diff --git a/tests/Dapper.Tests/SingleRowTests.cs b/tests/Dapper.Tests/SingleRowTests.cs new file mode 100644 index 000000000..a26d757f8 --- /dev/null +++ b/tests/Dapper.Tests/SingleRowTests.cs @@ -0,0 +1,146 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FastMember; +using Xunit; +using Xunit.Abstractions; +using static Dapper.SqlMapper; + +namespace Dapper.Tests; + +[Collection("SingleRowTests")] +public sealed class SystemSqlClientSingleRowTests(ITestOutputHelper log) : SingleRowTests(log) +{ + protected override async Task InjectDataAsync(DbConnection conn, DbDataReader source) + { + using var bcp = new System.Data.SqlClient.SqlBulkCopy((System.Data.SqlClient.SqlConnection)conn); + bcp.DestinationTableName = "#mydata"; + bcp.EnableStreaming = true; + await bcp.WriteToServerAsync(source); + } +} +#if MSSQLCLIENT +[Collection("SingleRowTests")] +public sealed class MicrosoftSqlClientSingleRowTests(ITestOutputHelper log) : SingleRowTests(log) +{ + protected override async Task InjectDataAsync(DbConnection conn, DbDataReader source) + { + using var bcp = new Microsoft.Data.SqlClient.SqlBulkCopy((Microsoft.Data.SqlClient.SqlConnection)conn); + bcp.DestinationTableName = "#mydata"; + bcp.EnableStreaming = true; + await bcp.WriteToServerAsync(source); + } +} +#endif +public abstract class SingleRowTests(ITestOutputHelper log) : TestBase where TProvider : DatabaseProvider +{ + protected abstract Task InjectDataAsync(DbConnection connection, DbDataReader source); + + [Fact] + public async Task QueryFirst_PerformanceAndCorrectness() + { + using var conn = GetOpenConnection(); + conn.Execute("create table #mydata(id int not null, name nvarchar(250) not null)"); + + var rand = new Random(); + var data = from id in Enumerable.Range(1, 500_000) + select new MyRow { Id = rand.Next(), Name = CreateName(rand) }; + + Stopwatch watch; + using (var reader = ObjectReader.Create(data)) + { + await InjectDataAsync(conn, reader); + watch = Stopwatch.StartNew(); + var count = await conn.QuerySingleAsync("""select count(1) from #mydata"""); + watch.Stop(); + log.WriteLine($"bulk-insert complete; {count} rows in {watch.ElapsedMilliseconds}ms"); + } + + // just errors + var ex = Assert.ThrowsAny(() => conn.Execute("raiserror('bad things', 16, 1)")); + log.WriteLine(ex.Message); + ex = await Assert.ThrowsAnyAsync(async () => await conn.ExecuteAsync("raiserror('bad things', 16, 1)")); + log.WriteLine(ex.Message); + + // just data + watch = Stopwatch.StartNew(); + var row = conn.QueryFirst("select top 1 * from #mydata"); + watch.Stop(); + log.WriteLine($"sync top 1 read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); + + watch = Stopwatch.StartNew(); + row = await conn.QueryFirstAsync("select top 1 * from #mydata"); + watch.Stop(); + log.WriteLine($"async top 1 read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); + + watch = Stopwatch.StartNew(); + row = conn.QueryFirst("select * from #mydata"); + watch.Stop(); + log.WriteLine($"sync read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); + + watch = Stopwatch.StartNew(); + row = await conn.QueryFirstAsync("select * from #mydata"); + watch.Stop(); + log.WriteLine($"async read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); + + // data with trailing errors + + watch = Stopwatch.StartNew(); + ex = Assert.ThrowsAny(() => conn.QueryFirst("select * from #mydata; raiserror('bad things', 16, 1)")); + watch.Stop(); + log.WriteLine($"sync read with error complete in {watch.ElapsedMilliseconds}ms; {ex.Message}"); + + watch = Stopwatch.StartNew(); + ex = await Assert.ThrowsAnyAsync(async () => await conn.QueryFirstAsync("select * from #mydata; raiserror('bad things', 16, 1)")); + watch.Stop(); + log.WriteLine($"async read with error complete in {watch.ElapsedMilliseconds}ms; {ex.Message}"); + + // unbuffered read with trailing errors - do not expect to see this unless we consume all! + + watch = Stopwatch.StartNew(); + row = conn.Query("select * from #mydata", buffered: false).First(); + watch.Stop(); + log.WriteLine($"sync unbuffered LINQ read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); + +#if NET5_0_OR_GREATER + watch = Stopwatch.StartNew(); + row = await conn.QueryUnbufferedAsync("select * from #mydata").FirstAsync(); + watch.Stop(); + log.WriteLine($"async unbuffered LINQ read first complete; row {row.Id} in {watch.ElapsedMilliseconds}ms"); +#endif + + static unsafe string CreateName(Random rand) + { + const string Alphabet = "abcdefghijklmnopqrstuvwxyz 0123456789,;-"; + var len = rand.Next(5, 251); + char* ptr = stackalloc char[len]; + for (int i = 0; i < len; i++) + { + ptr[i] = Alphabet[rand.Next(Alphabet.Length)]; + } + return new string(ptr, 0, len); + } + + } + + public class MyRow + { + public int Id { get; set; } + public string Name { get; set; } = ""; + } +} + +internal static class AsyncLinqHelper +{ + public static async ValueTask FirstAsync(this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + await using var iter = source.GetAsyncEnumerator(cancellationToken); + if (!await iter.MoveNextAsync()) Array.Empty().First(); // for consistent error + return iter.Current; + } +}