diff --git a/src/EFCore.Relational/Storage/RelationalDataReader.cs b/src/EFCore.Relational/Storage/RelationalDataReader.cs index f7282ef167f..f4ae8f34535 100644 --- a/src/EFCore.Relational/Storage/RelationalDataReader.cs +++ b/src/EFCore.Relational/Storage/RelationalDataReader.cs @@ -61,6 +61,7 @@ public virtual void Initialize( _reader = reader; _commandId = commandId; _logger = logger; + _readCount = 0; _disposed = false; _startTime = DateTimeOffset.UtcNow; _stopwatch.Restart(); diff --git a/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs b/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs index 6310664e71a..e9310702876 100644 --- a/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs +++ b/test/EFCore.Relational.Tests/Storage/RelationalCommandTest.cs @@ -356,6 +356,86 @@ public async Task Can_ExecuteReaderAsync() Assert.Equal(expectedCount, fakeDbConnection.CloseCount); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task Can_ExecuteReader_multiple_times(bool async) + { + var diagnosticEvents = new List>(); + + var logger = new RelationalCommandDiagnosticsLogger( + new ListLoggerFactory(), + new FakeLoggingOptions(false), + new ListDiagnosticSource(diagnosticEvents), + new TestRelationalLoggingDefinitions(), + new NullDbContextLogger(), + CreateOptions()); + + DbDataReader CreateDbDataReader() + => new FakeDbDataReader( + new[] { "Id", "Name" }, new List { new object[] { 1, "Foo" }, new object[] { 2, "Bar" } }); + + var fakeDbConnection = new FakeDbConnection( + ConnectionString, + new FakeCommandExecutor( + executeReader: (c, b) => CreateDbDataReader(), + executeReaderAsync: (c, b, ct) => Task.FromResult(CreateDbDataReader()))); + var optionsExtension = new FakeRelationalOptionsExtension().WithConnection(fakeDbConnection); + + var options = CreateOptions(optionsExtension); + + var relationalCommand = CreateRelationalCommand(); + + await using (var relationalReader = await ExecuteReader( + relationalCommand, + new RelationalCommandParameterObject(new FakeRelationalConnection(options), null, null, null, logger), async)) + { + var dbDataReader = relationalReader.DbDataReader; + + Assert.True(await Read(relationalReader, async)); + Assert.Equal(1, dbDataReader.GetInt32(0)); + Assert.Equal("Foo", dbDataReader.GetString(1)); + + Assert.True(await Read(relationalReader, async)); + Assert.Equal(2, dbDataReader.GetInt32(0)); + Assert.Equal("Bar", dbDataReader.GetString(1)); + + Assert.False(await Read(relationalReader, async)); + + diagnosticEvents.Clear(); + } + + var diagnostic = diagnosticEvents.Single(); + Assert.Equal(RelationalEventId.DataReaderDisposing.Name, diagnostic.Item1); + var dataReaderDisposingEventData = (DataReaderDisposingEventData)diagnostic.Item2; + Assert.Equal(3, dataReaderDisposingEventData.ReadCount); + + diagnosticEvents.Clear(); + + await using (var relationalReader = await ExecuteReader( + relationalCommand, + new RelationalCommandParameterObject(new FakeRelationalConnection(options), null, null, null, logger), async)) + { + var dbDataReader = relationalReader.DbDataReader; + + Assert.True(await Read(relationalReader, async)); + Assert.Equal(1, dbDataReader.GetInt32(0)); + Assert.Equal("Foo", dbDataReader.GetString(1)); + + Assert.True(await Read(relationalReader, async)); + Assert.Equal(2, dbDataReader.GetInt32(0)); + Assert.Equal("Bar", dbDataReader.GetString(1)); + + Assert.False(await Read(relationalReader, async)); + + diagnosticEvents.Clear(); + } + + diagnostic = diagnosticEvents.Single(); + Assert.Equal(RelationalEventId.DataReaderDisposing.Name, diagnostic.Item1); + dataReaderDisposingEventData = (DataReaderDisposingEventData)diagnostic.Item2; + Assert.Equal(3, dataReaderDisposingEventData.ReadCount); + } + public static TheoryData CommandActions => new TheoryData { @@ -1299,5 +1379,19 @@ private IRelationalCommand CreateRelationalCommand( TestServiceFactory.Instance.Create())), commandText, parameters ?? Array.Empty()); + + + private Task ExecuteReader( + IRelationalCommand relationalCommand, + RelationalCommandParameterObject parameterObject, + bool async) + => async + ? relationalCommand.ExecuteReaderAsync(parameterObject) + : Task.FromResult(relationalCommand.ExecuteReader(parameterObject)); + + private Task Read(RelationalDataReader relationalReader, bool async) + => async ? relationalReader.ReadAsync() : Task.FromResult(relationalReader.Read()); + + public static IEnumerable IsAsyncData = new[] { new object[] { false }, new object[] { true } }; } }