From 81d0ba58394207d68f71689ea7e7f1ca48a55996 Mon Sep 17 00:00:00 2001 From: Daniel Costea Date: Thu, 1 Oct 2020 21:43:58 +0200 Subject: [PATCH] Add WriteCsv plus unit tests. (#2947) * Add WriteCsv plus unit tests. * Add CultureInfo to WriteCsv. Remove index column param. Update unit tests. * Add CR changes. CultureInfo. Separator. * Format decimal types individually. Fix culture info. Fix unit tests. * Format decimal types individually. Fix culture info. Fix unit tests. --- src/Microsoft.Data.Analysis/DataFrame.IO.cs | 117 +++++++++++++ .../DataFrame.IOTests.cs | 157 ++++++++++++++++++ 2 files changed, 274 insertions(+) diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs index 084b66379d..54268172a0 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs @@ -4,8 +4,11 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.IO; +using System.Linq; using System.Text; +using Microsoft.ML; namespace Microsoft.Data.Analysis { @@ -306,5 +309,119 @@ public static DataFrame LoadCsv(Stream csvStream, return ret; } } + + /// + /// Writes a DataFrame into a CSV. + /// + /// + /// CSV file path + /// column separator + /// has a header or not + /// The character encoding. Defaults to UTF8 if not specified + /// culture info for formatting values + public static void WriteCsv(DataFrame dataFrame, string path, + char separator = ',', bool header = true, + Encoding encoding = null, CultureInfo cultureInfo = null) + { + using (FileStream csvStream = new FileStream(path, FileMode.Create)) + { + WriteCsv(dataFrame: dataFrame, csvStream: csvStream, + separator: separator, header: header, + encoding: encoding, cultureInfo: cultureInfo); + } + } + + /// + /// Writes a DataFrame into a CSV. + /// + /// + /// stream of CSV data to be write out + /// column separator + /// has a header or not + /// the character encoding. Defaults to UTF8 if not specified + /// culture info for formatting values + public static void WriteCsv(DataFrame dataFrame, Stream csvStream, + char separator = ',', bool header = true, + Encoding encoding = null, CultureInfo cultureInfo = null) + { + if (cultureInfo is null) + { + cultureInfo = CultureInfo.CurrentCulture; + } + + if (cultureInfo.NumberFormat.NumberDecimalSeparator.Equals(separator.ToString())) + { + throw new ArgumentException("Decimal separator cannot match the column separator"); + } + + if (encoding is null) + { + encoding = Encoding.ASCII; + } + + using (StreamWriter csvFile = new StreamWriter(csvStream, encoding, bufferSize: DefaultStreamReaderBufferSize, leaveOpen: true)) + { + if (dataFrame != null) + { + var columnNames = dataFrame.Columns.GetColumnNames(); + + if (header) + { + var headerColumns = string.Join(separator.ToString(), columnNames); + csvFile.WriteLine(headerColumns); + } + + var record = new StringBuilder(); + + foreach (var row in dataFrame.Rows) + { + bool firstRow = true; + foreach (var cell in row) + { + if (!firstRow) + { + record.Append(separator); + } + else + { + firstRow = false; + } + + Type t = cell?.GetType(); + + if (t == typeof(bool)) + { + record.AppendFormat(cultureInfo, "{0}", cell); + continue; + } + + if (t == typeof(float)) + { + record.AppendFormat(cultureInfo, "{0:G9}", cell); + continue; + } + + if (t == typeof(double)) + { + record.AppendFormat(cultureInfo, "{0:G17}", cell); + continue; + } + + if (t == typeof(decimal)) + { + record.AppendFormat(cultureInfo, "{0:G31}", cell); + continue; + } + + record.Append(cell); + } + + csvFile.WriteLine(record); + + record.Clear(); + } + } + } + } } } diff --git a/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 9bed04638c..b463068044 100644 --- a/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -3,7 +3,9 @@ // See the LICENSE file in the project root for more information. using System; +using System.Globalization; using System.IO; +using System.Linq; using System.Text; using Apache.Arrow; using Xunit; @@ -603,5 +605,160 @@ Stream GetStream(string streamData) Assert.Null(df[2, 2]); Assert.Null(df[5, 3]); } + + [Fact] + public void TestWriteCsvWithHeader() + { + using MemoryStream csvStream = new MemoryStream(); + DataFrame dataFrame = MakeDataFrameWithAllColumnTypes(10, true); + + DataFrame.WriteCsv(dataFrame, csvStream); + + csvStream.Seek(0, SeekOrigin.Begin); + DataFrame readIn = DataFrame.LoadCsv(csvStream); + + Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count); + Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count); + Assert.Equal(1F, readIn[1, 0]); + Assert.Equal(1F, readIn[1, 1]); + Assert.Equal(1F, readIn[1, 2]); + Assert.Equal(1F, readIn[1, 3]); + Assert.Equal(1F, readIn[1, 4]); + Assert.Equal(1F, readIn[1, 5]); + Assert.Equal(1F, readIn[1, 6]); + Assert.Equal(1F, readIn[1, 7]); + Assert.Equal(1F, readIn[1, 8]); + Assert.Equal(1F, readIn[1, 9]); + Assert.Equal(1F, readIn[1, 10]); + } + + [Fact] + public void TestWriteCsvWithCultureInfoRomanianAndSemiColon() + { + DataFrame dataFrame = MakeDataFrameWithNumericColumns(10, true); + dataFrame[1, 1] = 1.1M; + dataFrame[1, 2] = 1.2D; + dataFrame[1, 3] = 1.3F; + + using MemoryStream csvStream = new MemoryStream(); + var cultureInfo = new CultureInfo("ro-RO"); + var separator = ';'; + DataFrame.WriteCsv(dataFrame, csvStream, separator: separator, cultureInfo: cultureInfo); + + csvStream.Seek(0, SeekOrigin.Begin); + DataFrame readIn = DataFrame.LoadCsv(csvStream, separator: separator); + + Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count); + Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count); + Assert.Equal(1F, readIn[1, 0]); + + // LoadCsv does not support culture info, therefore decimal point comma (,) is seen as thousand separator and is ignored when read + Assert.Equal(11F, readIn[1, 1]); + Assert.Equal(12F, readIn[1, 2]); + Assert.Equal(129999992F, readIn[1, 3]); + + Assert.Equal(1F, readIn[1, 4]); + Assert.Equal(1F, readIn[1, 5]); + Assert.Equal(1F, readIn[1, 6]); + Assert.Equal(1F, readIn[1, 7]); + Assert.Equal(1F, readIn[1, 8]); + Assert.Equal(1F, readIn[1, 9]); + Assert.Equal(1F, readIn[1, 10]); + } + + [Fact] + public void TestWriteCsvWithCultureInfo() + { + using MemoryStream csvStream = new MemoryStream(); + DataFrame dataFrame = MakeDataFrameWithNumericColumns(10, true); + dataFrame[1, 1] = 1.1M; + dataFrame[1, 2] = 1.2D; + dataFrame[1, 3] = 1.3F; + + var cultureInfo = new CultureInfo("en-US"); + DataFrame.WriteCsv(dataFrame, csvStream, cultureInfo: cultureInfo); + + csvStream.Seek(0, SeekOrigin.Begin); + DataFrame readIn = DataFrame.LoadCsv(csvStream); + + Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count); + Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count); + Assert.Equal(1F, readIn[1, 0]); + Assert.Equal(1.1F, readIn[1, 1]); + Assert.Equal(1.2F, readIn[1, 2]); + Assert.Equal(1.3F, readIn[1, 3]); + Assert.Equal(1F, readIn[1, 4]); + Assert.Equal(1F, readIn[1, 5]); + Assert.Equal(1F, readIn[1, 6]); + Assert.Equal(1F, readIn[1, 7]); + Assert.Equal(1F, readIn[1, 8]); + Assert.Equal(1F, readIn[1, 9]); + Assert.Equal(1F, readIn[1, 10]); + } + + [Fact] + public void TestWriteCsvWithCultureInfoRomanianAndComma() + { + using MemoryStream csvStream = new MemoryStream(); + DataFrame dataFrame = MakeDataFrameWithNumericColumns(10, true); + + var cultureInfo = new CultureInfo("ro-RO"); + var separator = cultureInfo.NumberFormat.NumberDecimalSeparator.First(); + + Assert.Throws(() => DataFrame.WriteCsv(dataFrame, csvStream, separator: separator, cultureInfo: cultureInfo)); + } + + [Fact] + public void TestWriteCsvWithNoHeader() + { + using MemoryStream csvStream = new MemoryStream(); + DataFrame dataFrame = MakeDataFrameWithAllColumnTypes(10, true); + + DataFrame.WriteCsv(dataFrame, csvStream, header: false); + + csvStream.Seek(0, SeekOrigin.Begin); + DataFrame readIn = DataFrame.LoadCsv(csvStream, header: false); + + Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count); + Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count); + Assert.Equal(1F, readIn[1, 0]); + Assert.Equal(1F, readIn[1, 1]); + Assert.Equal(1F, readIn[1, 2]); + Assert.Equal(1F, readIn[1, 3]); + Assert.Equal(1F, readIn[1, 4]); + Assert.Equal(1F, readIn[1, 5]); + Assert.Equal(1F, readIn[1, 6]); + Assert.Equal(1F, readIn[1, 7]); + Assert.Equal(1F, readIn[1, 8]); + Assert.Equal(1F, readIn[1, 9]); + Assert.Equal(1F, readIn[1, 10]); + } + + [Fact] + public void TestWriteCsvWithSemicolonSeparator() + { + using MemoryStream csvStream = new MemoryStream(); + DataFrame dataFrame = MakeDataFrameWithAllColumnTypes(10, true); + + var separator = ';'; + DataFrame.WriteCsv(dataFrame, csvStream, separator: separator); + + csvStream.Seek(0, SeekOrigin.Begin); + DataFrame readIn = DataFrame.LoadCsv(csvStream, separator: separator); + + Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count); + Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count); + Assert.Equal(1F, readIn[1, 0]); + Assert.Equal(1F, readIn[1, 1]); + Assert.Equal(1F, readIn[1, 2]); + Assert.Equal(1F, readIn[1, 3]); + Assert.Equal(1F, readIn[1, 4]); + Assert.Equal(1F, readIn[1, 5]); + Assert.Equal(1F, readIn[1, 6]); + Assert.Equal(1F, readIn[1, 7]); + Assert.Equal(1F, readIn[1, 8]); + Assert.Equal(1F, readIn[1, 9]); + Assert.Equal(1F, readIn[1, 10]); + } } }