diff --git a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
index ed092be157..c074d0d1f3 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs
@@ -307,7 +307,9 @@ private ValueGetter<byte> CreateByteGetterDelegate(ColInfo colInfo)
             private ValueGetter<DateTime> CreateDateTimeGetterDelegate(ColInfo colInfo)
             {
                 int columnIndex = GetColumnIndex(colInfo);
-                return (ref DateTime value) => value = DataReader.GetDateTime(columnIndex);
+                // Check IsDBNull first so a database NULL is surfaced as default(DateTime), i.e. DateTime.MinValue,
+                // instead of being passed to GetDateTime.
+                return (ref DateTime value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetDateTime(columnIndex);
             }
 
             private ValueGetter<double> CreateDoubleGetterDelegate(ColInfo colInfo)
diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
index 21b6d96d19..4f7ebef980 100644
--- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
+++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs
@@ -7,7 +7,9 @@
 using System.Data.SqlClient;
 using System.Data.SQLite;
 using System.IO;
+using System.Linq;
 using System.Runtime.InteropServices;
+using FluentAssertions;
 using Microsoft.ML.Data;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
@@ -222,6 +224,46 @@ public void IrisSdcaMaximumEntropy()
             }).PredictedLabel);
         }
 
+        /// <summary>
+        /// Loads a DateTime column containing a database NULL and checks that the NULL row comes back as
+        /// <see cref="DateTime.MinValue"/> while the populated row keeps its value.
+        /// </summary>
+        [X86X64FactAttribute("The SQLite un-managed code, SQLite.interop, only supports x86/x64 architectures.")]
+        public void TestLoadDatetimeColumnWithNullValue()
+        {
+            var connectionString = "DataSource=Dummy;Mode=Memory;Version=3;Timeout=120;Cache=Shared";
+            using (var connection = new SQLiteConnection(connectionString))
+            {
+                connection.Open();
+                using (var command = new SQLiteCommand(connection))
+                {
+                    // Drop and recreate the table so it holds exactly two rows even if the database outlives a previous run.
+                    command.CommandText = """
+                        BEGIN;
+                        DROP TABLE IF EXISTS Datetime;
+                        CREATE TABLE Datetime (datetime Datetime NULL);
+                        INSERT INTO Datetime VALUES (NULL);
+                        INSERT INTO Datetime VALUES ('2018-01-01 00:00:00');
+                        COMMIT;
+                        """;
+                    command.ExecuteNonQuery();
+                }
+            }
+
+            var mlContext = new MLContext(seed: 1);
+            var loader = mlContext.Data.CreateDatabaseLoader(new DatabaseLoader.Column("datetime", DbType.DateTime, 0));
+            // ORDER BY rowid pins the rows to insertion order, which the index-based assertions below rely on.
+            var source = new DatabaseSource(SQLiteFactory.Instance, connectionString, "SELECT datetime FROM Datetime ORDER BY rowid");
+            var data = loader.Load(source);
+            var datetimes = data.GetColumn<DateTime>("datetime").ToArray();
+            datetimes.Should().HaveCount(2);
+
+            // Convert null value to DateTime.MinValue, aka 0001-01-01 00:00:00
+            // This is the default behavior of TextLoader as well.
+            datetimes[0].Should().Be(DateTime.MinValue);
+            datetimes[1].Should().Be(new DateTime(2018, 1, 1, 0, 0, 0));
+        }
+
         /// <summary>
         /// Non-Windows builds do not support SqlClientFactory/MSSQL databases. Hence, an equivalent
         /// SQLite database is used on Linux and MacOS builds.