Skip to content

Commit

Permalink
Fix Saving csv with VBufferDataFrameColumn
Browse files Browse the repository at this point in the history
  • Loading branch information
asmirnov82 committed Oct 12, 2023
1 parent 64d7ebd commit a981605
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 68 deletions.
103 changes: 51 additions & 52 deletions src/Microsoft.Data.Analysis/DataFrame.IO.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
Expand All @@ -11,6 +12,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.Data;

namespace Microsoft.Data.Analysis
{
Expand Down Expand Up @@ -675,58 +677,7 @@ public static void SaveCsv(DataFrame dataFrame, Stream csvStream,

foreach (var row in dataFrame.Rows)
{
bool firstCell = true;
foreach (var cell in row)
{
if (!firstCell)
{
record.Append(separator);
}
else
{
firstCell = false;
}

Type t = cell?.GetType();

if (t == typeof(bool))
{
record.AppendFormat(cultureInfo, "{0}", cell);
continue;
}

if (t == typeof(float))
{
record.AppendFormat(cultureInfo, "{0:G9}", cell);
continue;
}

if (t == typeof(double))
{
record.AppendFormat(cultureInfo, "{0:G17}", cell);
continue;
}

if (t == typeof(decimal))
{
record.AppendFormat(cultureInfo, "{0:G31}", cell);
continue;
}

if (t == typeof(string))
{
string stringCell = (string)cell;
if (NeedsQuotes(stringCell, separator))
{
record.Append('\"');
record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
record.Append('\"');
continue;
}
}

record.Append(cell);
}
AppendValuesToRecord(record, row, separator, cultureInfo);

csvFile.WriteLine(record);

Expand All @@ -736,6 +687,54 @@ public static void SaveCsv(DataFrame dataFrame, Stream csvStream,
}
}

/// <summary>
/// Appends every item of <paramref name="values"/> to <paramref name="record"/>, writing
/// <paramref name="separator"/> between consecutive items.
/// </summary>
/// <param name="record">The builder that accumulates the text of the current line.</param>
/// <param name="values">The items to write; an item that is itself a non-string <see cref="IEnumerable"/> is written recursively.</param>
/// <param name="separator">The character placed between items and passed to <c>NeedsQuotes</c> for string items.</param>
/// <param name="cultureInfo">The format provider used for <see cref="bool"/>, <see cref="float"/>, <see cref="double"/> and <see cref="decimal"/> items.</param>
private static void AppendValuesToRecord(StringBuilder record, IEnumerable values, char separator, CultureInfo cultureInfo)
{
    bool needsSeparator = false;
    foreach (object item in values)
    {
        // Every item except the first one is preceded by the separator.
        if (needsSeparator)
        {
            record.Append(separator);
        }

        needsSeparator = true;

        if (item is bool)
        {
            record.AppendFormat(cultureInfo, "{0}", item);
        }
        else if (item is float)
        {
            record.AppendFormat(cultureInfo, "{0:G9}", item);
        }
        else if (item is double)
        {
            record.AppendFormat(cultureInfo, "{0:G17}", item);
        }
        else if (item is decimal)
        {
            record.AppendFormat(cultureInfo, "{0:G31}", item);
        }
        else if (item is string text)
        {
            // The string test must stay ahead of the IEnumerable test: a string is also enumerable.
            if (NeedsQuotes(text, separator))
            {
                record.Append('\"');
                record.Append(text.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
                record.Append('\"');
            }
            else
            {
                record.Append(text);
            }
        }
        else if (item is IEnumerable nestedItems)
        {
            // A nested sequence is written as one parenthesized cell whose items are separated by spaces.
            record.Append("(");
            AppendValuesToRecord(record, nestedItems, ' ', cultureInfo);
            record.Append(")");
        }
        else
        {
            // Every other item (including null) is appended through StringBuilder.Append(object).
            record.Append(item);
        }
    }
}

private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator)
{
bool firstColumn = true;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ protected virtual PrimitiveDataFrameColumn<T> CreateNewColumn(string name, long
return new PrimitiveDataFrameColumn<T>(name, length);
}

internal T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
protected T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];

protected override object GetValue(long rowIndex) => GetTypedValue(rowIndex);

Expand Down
11 changes: 10 additions & 1 deletion src/Microsoft.ML.DataView/VBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.ML.Internal.DataView;
Expand All @@ -27,7 +28,7 @@ namespace Microsoft.ML.Data
/// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of
/// buffers is not possible. But, things like <see cref="int"/>, <see cref="float"/>, and <see
/// cref="ReadOnlyMemory{Char}"/>, are totally fine.</typeparam>
public readonly struct VBuffer<T>
public readonly struct VBuffer<T> : IEnumerable
{
/// <summary>
/// The internal re-usable array of values.
Expand Down Expand Up @@ -403,6 +404,14 @@ public T GetItemOrDefault(int index)
public override string ToString()
=> IsDense ? $"Dense vector of size {Length}" : $"Sparse vector of size {Length}, {_count} explicit values";

/// <summary>
/// Returns an enumerator that iterates through the explicitly stored values in VBuffer.
/// </summary>
/// <remarks>
/// Only the first <c>_count</c> entries of the backing array are enumerated. The backing array is
/// re-usable and may be longer than the number of values that belong to this buffer (or may be
/// <see langword="null"/> when there are no values), so it must not be enumerated directly.
/// For a sparse buffer this yields the explicit values only, not the implicit default entries.
/// </remarks>
/// <returns>An enumerator over the explicit values, each boxed as <see cref="object"/>.</returns>
public IEnumerator GetEnumerator()
{
    // Bounding by _count (rather than _values.Length) skips stale slots left over from buffer
    // re-use, and the loop body is never entered when _count is 0, so a null array is never touched.
    for (int i = 0; i < _count; i++)
    {
        yield return _values[i];
    }
}

internal VBufferEditor<T> GetEditor()
{
return GetEditor(Length, _count);
Expand Down
58 changes: 44 additions & 14 deletions test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Xunit;
using Microsoft.ML.TestFramework.Attributes;
using System.Threading;
using Microsoft.ML.Data;

namespace Microsoft.Data.Analysis.Tests
{
Expand Down Expand Up @@ -273,7 +274,7 @@ void ReducedRowsTest(DataFrame reducedRows)
[Theory]
[InlineData(false)]
[InlineData(true)]
public void TestReadCsvNoHeader(bool useQuotes)
public void TestLoadCsvNoHeader(bool useQuotes)
{
string CMT = useQuotes ? @"""C,MT""" : "CMT";
string verifyCMT = useQuotes ? "C,MT" : "CMT";
Expand Down Expand Up @@ -349,7 +350,7 @@ void VerifyDataFrameWithNamedColumnsAndDataTypes(DataFrame df, bool verifyColumn
[InlineData(false, 0)]
[InlineData(true, 10)]
[InlineData(false, 10)]
public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows)
public void TestLoadCsvWithTypesAndGuessRows(bool header, int guessRows)
{
/* Tests this matrix
*
Expand Down Expand Up @@ -472,7 +473,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithTypesDateTime()
public void TestLoadCsvWithTypesDateTime()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date
CMT,1,1,1271,3.8,CRD,17.5,1-june-2020
Expand Down Expand Up @@ -549,7 +550,7 @@ void Verify(DataFrame df, bool verifyDataTypes)
}

[Fact]
public void TestReadCsvWithPipeSeparator()
public void TestLoadCsvWithPipeSeparator()
{
string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount
CMT|1|1|1271|3.8|CRD|17.5
Expand Down Expand Up @@ -588,7 +589,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithSemicolonSeparator()
public void TestLoadCsvWithSemicolonSeparator()
{
string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount
CMT;1;1;1271;3.8;CRD;17.5
Expand Down Expand Up @@ -627,7 +628,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithExtraColumnInHeader()
public void TestLoadCsvWithExtraColumnInHeader()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra
CMT,1,1,1271,3.8,CRD,17.5
Expand Down Expand Up @@ -656,7 +657,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0
Expand All @@ -671,7 +672,7 @@ public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
}

[Fact]
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
CMT,1,1,1271,3.8,CRD,17.5,0
Expand Down Expand Up @@ -713,7 +714,7 @@ public void TestLoadCsvWithAddIndexColumn()
}

[Fact]
public void TestReadCsvWithExtraColumnInRow()
public void TestLoadCsvWithExtraColumnInRow()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD,17.5,0
Expand All @@ -726,7 +727,7 @@ public void TestReadCsvWithExtraColumnInRow()
}

[Fact]
public void TestReadCsvWithLessColumnsInRow()
public void TestLoadCsvWithLessColumnsInRow()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
CMT,1,1,1271,3.8,CRD
Expand Down Expand Up @@ -755,7 +756,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithAllNulls()
public void TestLoadCsvWithAllNulls()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,null,null,null
Expand Down Expand Up @@ -798,7 +799,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithNullsAndDataTypes()
public void TestLoadCsvWithNullsAndDataTypes()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271
Expand Down Expand Up @@ -860,7 +861,7 @@ void Verify(DataFrame df)
}

[Fact]
public void TestReadCsvWithNulls()
public void TestLoadCsvWithNulls()
{
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
null,1,1,1271
Expand Down Expand Up @@ -922,7 +923,36 @@ void Verify(DataFrame df)
}

[Fact]
public void TestWriteCsvWithHeader()
public void TestSaveCsvVBufferColumn()
{
    // Arrange: a single-column data frame holding three dense integer vectors.
    VBuffer<int>[] vectors =
    {
        new VBuffer<int>(3, new int[] { 1, 2, 3 }),
        new VBuffer<int>(3, new int[] { 2, 3, 4 }),
        new VBuffer<int>(3, new int[] { 3, 4, 5 }),
    };
    string[] expectedCells = { "(1 2 3)", "(2 3 4)", "(3 4 5)" };

    var column = new VBufferDataFrameColumn<int>("VBuffer", vectors);
    DataFrame original = new DataFrame(column);

    // Act: write the frame as CSV, rewind the stream and load it back.
    using MemoryStream stream = new MemoryStream();
    DataFrame.SaveCsv(original, stream);
    stream.Position = 0;
    DataFrame roundTripped = DataFrame.LoadCsv(stream);

    // Assert: the shape is preserved and every vector comes back as a parenthesized string cell.
    Assert.Equal(original.Rows.Count, roundTripped.Rows.Count);
    Assert.Equal(original.Columns.Count, roundTripped.Columns.Count);
    Assert.Equal(typeof(string), roundTripped.Columns[0].DataType);

    for (int rowIndex = 0; rowIndex < expectedCells.Length; rowIndex++)
    {
        Assert.Equal(expectedCells[rowIndex], roundTripped[rowIndex, 0]);
    }
}

[Fact]
public void TestSaveCsvWithHeader()
{
using MemoryStream csvStream = new MemoryStream();
DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true);
Expand Down

0 comments on commit a981605

Please sign in to comment.