Skip to content

Commit

Permalink
Add WriteCsv plus unit tests. (dotnet#2947)
Browse files Browse the repository at this point in the history
* Add WriteCsv plus unit tests.

* Add CultureInfo to WriteCsv. Remove index column param. Update unit tests.

* Add CR changes. CultureInfo. Separator.

* Format decimal types individually. Fix culture info. Fix unit tests.

* Format decimal types individually. Fix culture info. Fix unit tests.
  • Loading branch information
dcostea authored Oct 1, 2020
1 parent 4e6d801 commit 81d0ba5
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 0 deletions.
117 changes: 117 additions & 0 deletions src/Microsoft.Data.Analysis/DataFrame.IO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.ML;

namespace Microsoft.Data.Analysis
{
Expand Down Expand Up @@ -306,5 +309,119 @@ public static DataFrame LoadCsv(Stream csvStream,
return ret;
}
}

/// <summary>
/// Saves a DataFrame as a CSV file at the given location.
/// </summary>
/// <param name="dataFrame">the <see cref="DataFrame"/> to write</param>
/// <param name="path">path of the CSV file; the file is created, or overwritten when it already exists</param>
/// <param name="separator">character placed between columns</param>
/// <param name="header">whether a header line with the column names is written</param>
/// <param name="encoding">character encoding; passed unchanged to the stream-based overload, which chooses the default when this is null</param>
/// <param name="cultureInfo">culture used for formatting values; passed unchanged to the stream-based overload</param>
public static void WriteCsv(DataFrame dataFrame, string path,
    char separator = ',', bool header = true,
    Encoding encoding = null, CultureInfo cultureInfo = null)
{
    // FileMode.Create starts from an empty file whether or not one already exists.
    using (FileStream fileStream = new FileStream(path, FileMode.Create))
    {
        // All formatting work is done by the Stream overload; this method only owns the file handle.
        WriteCsv(dataFrame, fileStream, separator, header, encoding, cultureInfo);
    }
}

/// <summary>
/// Writes a DataFrame into a CSV.
/// </summary>
/// <param name="dataFrame"><see cref="DataFrame"/> to write. When null, no rows or header are written.</param>
/// <param name="csvStream">stream the CSV data is written to. The stream is left open when this method returns.</param>
/// <param name="separator">column separator</param>
/// <param name="header">whether to write a first line containing the column names</param>
/// <param name="encoding">the character encoding. Defaults to UTF8 (without a byte order mark) if not specified</param>
/// <param name="cultureInfo">culture info used to format bool, float, double and decimal values. Defaults to <see cref="CultureInfo.CurrentCulture"/> if not specified</param>
/// <exception cref="ArgumentException">The decimal separator of the culture is the same as <paramref name="separator"/>.</exception>
public static void WriteCsv(DataFrame dataFrame, Stream csvStream,
    char separator = ',', bool header = true,
    Encoding encoding = null, CultureInfo cultureInfo = null)
{
    cultureInfo = cultureInfo ?? CultureInfo.CurrentCulture;

    string columnSeparator = separator.ToString();

    // A number formatted with this culture would contain the column separator
    // and therefore be split into two columns by any reader.
    if (cultureInfo.NumberFormat.NumberDecimalSeparator.Equals(columnSeparator))
    {
        throw new ArgumentException("Decimal separator cannot match the column separator");
    }

    // Default to UTF8 as documented. ASCII would silently replace every non-ASCII
    // character with '?'. The byte order mark is suppressed so that ASCII-only
    // content produces exactly the same bytes as an ASCII encoder would.
    encoding = encoding ?? new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);

    // leaveOpen: the caller owns csvStream; disposing the writer only flushes it.
    using (StreamWriter csvFile = new StreamWriter(csvStream, encoding, bufferSize: DefaultStreamReaderBufferSize, leaveOpen: true))
    {
        if (dataFrame is null)
        {
            return;
        }

        if (header)
        {
            var columnNames = dataFrame.Columns.GetColumnNames();
            csvFile.WriteLine(string.Join(columnSeparator, columnNames));
        }

        // One builder is reused for every row to avoid a per-row allocation.
        var record = new StringBuilder();

        foreach (var row in dataFrame.Rows)
        {
            bool firstCell = true;
            foreach (var cell in row)
            {
                if (firstCell)
                {
                    firstCell = false;
                }
                else
                {
                    record.Append(separator);
                }

                // G9 / G17 are the round-trip precisions for float / double;
                // G31 is the precision the original code uses for decimal.
                string format = null;
                if (cell is bool)
                {
                    format = "{0}";
                }
                else if (cell is float)
                {
                    format = "{0:G9}";
                }
                else if (cell is double)
                {
                    format = "{0:G17}";
                }
                else if (cell is decimal)
                {
                    format = "{0:G31}";
                }

                if (format != null)
                {
                    record.AppendFormat(cultureInfo, format, cell);
                }
                else
                {
                    // Every other type (and null, which appends nothing) uses its default string form.
                    record.Append(cell);
                }
            }

            csvFile.WriteLine(record);

            record.Clear();
        }
    }
}
}
}
157 changes: 157 additions & 0 deletions tests/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using Apache.Arrow;
using Xunit;
Expand Down Expand Up @@ -603,5 +605,160 @@ Stream GetStream(string streamData)
Assert.Null(df[2, 2]);
Assert.Null(df[5, 3]);
}

[Fact]
public void TestWriteCsvWithHeader()
{
    // Round trip with default options: write to memory, rewind, load again.
    DataFrame original = MakeDataFrameWithAllColumnTypes(10, true);
    using var buffer = new MemoryStream();

    DataFrame.WriteCsv(original, buffer);

    buffer.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(buffer);

    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);

    // Every one of the eleven columns of row 1 is expected to read back as 1.
    for (int column = 0; column <= 10; column++)
    {
        Assert.Equal(1F, roundTripped[1, column]);
    }
}

[Fact]
public void TestWriteCsvWithCultureInfoRomanianAndSemiColon()
{
    const char columnSeparator = ';';
    CultureInfo romanian = new CultureInfo("ro-RO");

    DataFrame original = MakeDataFrameWithNumericColumns(10, true);
    original[1, 1] = 1.1M;
    original[1, 2] = 1.2D;
    original[1, 3] = 1.3F;

    using var buffer = new MemoryStream();
    DataFrame.WriteCsv(original, buffer, separator: columnSeparator, cultureInfo: romanian);

    buffer.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(buffer, separator: columnSeparator);

    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);
    Assert.Equal(1F, roundTripped[1, 0]);

    // LoadCsv has no culture support: the decimal comma written for ro-RO is treated
    // as a thousands separator and dropped, so the digits run together when read back.
    Assert.Equal(11F, roundTripped[1, 1]);
    Assert.Equal(12F, roundTripped[1, 2]);
    Assert.Equal(129999992F, roundTripped[1, 3]);

    // The remaining columns of row 1 were not modified and still hold 1.
    for (int column = 4; column <= 10; column++)
    {
        Assert.Equal(1F, roundTripped[1, column]);
    }
}

[Fact]
public void TestWriteCsvWithCultureInfo()
{
    CultureInfo english = new CultureInfo("en-US");

    DataFrame original = MakeDataFrameWithNumericColumns(10, true);
    original[1, 1] = 1.1M;
    original[1, 2] = 1.2D;
    original[1, 3] = 1.3F;

    using var buffer = new MemoryStream();
    DataFrame.WriteCsv(original, buffer, cultureInfo: english);

    buffer.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(buffer);

    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);
    Assert.Equal(1F, roundTripped[1, 0]);

    // With en-US the fractional values written above are read back unchanged.
    Assert.Equal(1.1F, roundTripped[1, 1]);
    Assert.Equal(1.2F, roundTripped[1, 2]);
    Assert.Equal(1.3F, roundTripped[1, 3]);

    // The remaining columns of row 1 were not modified and still hold 1.
    for (int column = 4; column <= 10; column++)
    {
        Assert.Equal(1F, roundTripped[1, column]);
    }
}

[Fact]
public void TestWriteCsvWithCultureInfoRomanianAndComma()
{
    CultureInfo romanian = new CultureInfo("ro-RO");

    // Deliberately pick the culture's own decimal separator as the column separator.
    char clashingSeparator = romanian.NumberFormat.NumberDecimalSeparator.First();

    DataFrame original = MakeDataFrameWithNumericColumns(10, true);
    using var buffer = new MemoryStream();

    Action write = () => DataFrame.WriteCsv(original, buffer, separator: clashingSeparator, cultureInfo: romanian);

    // WriteCsv must refuse a column separator that equals the decimal separator.
    Assert.Throws<ArgumentException>(write);
}

[Fact]
public void TestWriteCsvWithNoHeader()
{
    // Round trip with the header line switched off on both the write and the read side.
    DataFrame original = MakeDataFrameWithAllColumnTypes(10, true);
    using var buffer = new MemoryStream();

    DataFrame.WriteCsv(original, buffer, header: false);

    buffer.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(buffer, header: false);

    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);

    // Every one of the eleven columns of row 1 is expected to read back as 1.
    for (int column = 0; column <= 10; column++)
    {
        Assert.Equal(1F, roundTripped[1, column]);
    }
}

[Fact]
public void TestWriteCsvWithSemicolonSeparator()
{
    const char columnSeparator = ';';

    // Round trip using ';' instead of the default column separator.
    DataFrame original = MakeDataFrameWithAllColumnTypes(10, true);
    using var buffer = new MemoryStream();

    DataFrame.WriteCsv(original, buffer, separator: columnSeparator);

    buffer.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(buffer, separator: columnSeparator);

    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);

    // Every one of the eleven columns of row 1 is expected to read back as 1.
    for (int column = 0; column <= 10; column++)
    {
        Assert.Equal(1F, roundTripped[1, column]);
    }
}
}
}

0 comments on commit 81d0ba5

Please sign in to comment.