Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support SparkML model -- FPM #1096

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.E2ETest.IpcTests.ML.Feature;
using Microsoft.Spark.ML.Feature.Param;
using Microsoft.Spark.ML.Fpm;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Microsoft.Spark.Utils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Fpm
{
[Collection("Spark E2E Tests")]
public class FPGrowthModelTests : FeatureBaseTests<FPGrowthModel>
{
    private readonly SparkSession _spark;

    public FPGrowthModelTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Fit a <see cref="FPGrowthModel"/> and test the available methods: the parameter
    /// getters, prediction through Transform, and a Save/Load round trip. Test the
    /// FeatureBase methods using <see cref="FeatureBaseTests{T}"/>.
    /// </summary>
    [Fact]
    public void TestFPGrowthModel()
    {
        // Both the training data and the data to predict on use the same
        // single-column schema: an array of strings named "items".
        var schema = new StructType(new List<StructField>
        {
            new StructField("items", new ArrayType(new StringType())),
        });
        string[] expectedPrediction = { "x", "y", "z" };

        var fpGrowth = new FPGrowth();
        fpGrowth.SetMinSupport(0.2)
            .SetMinConfidence(0.7);

        DataFrame dataFrame = _spark.CreateDataFrame(
            new List<GenericRow>
            {
                new GenericRow(new object[] { new string[] { "r", "z", "h", "k", "p" }}),
                new GenericRow(new object[] { new string[] { "z", "y", "x", "w", "v", "u", "t", "s" }}),
                new GenericRow(new object[] { new string[] { "s", "x", "o", "n", "r" }}),
                new GenericRow(new object[] { new string[] { "x", "z", "y", "m", "t", "s", "q", "e" }}),
                new GenericRow(new object[] { new string[] { "z" }}),
                new GenericRow(new object[] { new string[] { "x", "z", "y", "r", "q", "t", "p" }}),
            },
            schema);

        FPGrowthModel fpm = fpGrowth.Fit(dataFrame);
        fpm.SetPredictionCol("newPrediction");
        Assert.Equal(0.2, fpm.GetMinSupport());
        Assert.Equal(0.7, fpm.GetMinConfidence());

        DataFrame newData = _spark.CreateDataFrame(
            new List<GenericRow>
            {
                new GenericRow(new object[] { new string[] { "t", "s" }})
            },
            schema);

        // xUnit convention: expected value first, actual value second.
        Assert.Equal(expectedPrediction, SortedPrediction(fpm, newData));

        using (var tempDirectory = new TemporaryDirectory())
        {
            string savePath = Path.Join(tempDirectory.Path, "fpm");
            fpm.Save(savePath);

            FPGrowthModel loadedFPGrowthModel = FPGrowthModel.Load(savePath);
            Assert.Equal(fpm.Uid(), loadedFPGrowthModel.Uid());

            // The reloaded model must predict the same items as the original.
            Assert.Equal(expectedPrediction, SortedPrediction(loadedFPGrowthModel, newData));
        }

        TestFeatureBase(fpm, "itemsCol", "items");
        TestFeatureBase(fpm, "minConfidence", 0.7);
        TestFeatureBase(fpm, "minSupport", 0.2);
        TestFeatureBase(fpm, "numPartitions", 2);
        TestFeatureBase(fpm, "predictionCol", "prediction");
    }

    /// <summary>
    /// Transforms <paramref name="data"/> with <paramref name="model"/> and returns the
    /// "newPrediction" value of the first row as a sorted array.
    /// </summary>
    /// <param name="model">Model whose prediction column is "newPrediction".</param>
    /// <param name="data">The <see cref="DataFrame"/> to transform.</param>
    /// <returns>The predicted items of the first row, sorted ordinally.</returns>
    private static string[] SortedPrediction(FPGrowthModel model, DataFrame data)
    {
        string[] prediction = TypeConverter.ConvertTo<string[]>(
            model.Transform(data).Select("newPrediction").First().Values[0]);

        // Sort so the comparison does not depend on the order the items come back in;
        // the ordinal comparer keeps the result independent of the current culture.
        Array.Sort(prediction, StringComparer.Ordinal);
        return prediction;
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.E2ETest.IpcTests.ML.Feature;
using Microsoft.Spark.ML.Feature.Param;
using Microsoft.Spark.ML.Fpm;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Microsoft.Spark.Utils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Fpm
{
[Collection("Spark E2E Tests")]
public class FPGrowthTests : FeatureBaseTests<FPGrowth>
{
    private readonly SparkSession _spark;

    public FPGrowthTests(SparkFixture fixture) : base(fixture)
    {
        _spark = fixture.Spark;
    }

    /// <summary>
    /// Create a <see cref="FPGrowth"/> and test the available methods: the
    /// minSupport/minConfidence setters and getters, Fit, and a Save/Load round trip.
    /// Test the FeatureBase methods using <see cref="FeatureBaseTests{T}"/>.
    /// </summary>
    [Fact]
    public void TestFPGrowth()
    {
        double minSupport = 0.2;
        double minConfidence = 0.7;

        var fpGrowth = new FPGrowth();
        fpGrowth.SetMinSupport(minSupport)
            .SetMinConfidence(minConfidence);

        Assert.Equal(minSupport, fpGrowth.GetMinSupport());
        Assert.Equal(minConfidence, fpGrowth.GetMinConfidence());

        DataFrame dataFrame = _spark.CreateDataFrame(
            new List<GenericRow>
            {
                new GenericRow(new object[] { new string[] { "r", "z", "h", "k", "p" }}),
            },
            new StructType(new List<StructField>
            {
                new StructField("items", new ArrayType(new StringType())),
            }));

        // Check the fitted model instead of discarding it: it should carry the
        // parameter values that were configured on the estimator.
        FPGrowthModel fpm = fpGrowth.Fit(dataFrame);
        Assert.Equal(minSupport, fpm.GetMinSupport());
        Assert.Equal(minConfidence, fpm.GetMinConfidence());

        using (var tempDirectory = new TemporaryDirectory())
        {
            string savePath = Path.Join(tempDirectory.Path, "fpgrowth");
            fpGrowth.Save(savePath);

            FPGrowth loadedFPGrowth = FPGrowth.Load(savePath);
            Assert.Equal(fpGrowth.Uid(), loadedFPGrowth.Uid());
        }

        TestFeatureBase(fpGrowth, "itemsCol", "items");
        TestFeatureBase(fpGrowth, "numPartitions", 2);
    }
}
}
170 changes: 170 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Fpm/FPGrowth.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;

namespace Microsoft.Spark.ML.Fpm
{
/// <summary>
/// <see cref="FPGrowth"/> is the estimator wrapper around the JVM class
/// <c>org.apache.spark.ml.fpm.FPGrowth</c>. Fitting it on a <see cref="DataFrame"/>
/// produces a <see cref="FPGrowthModel"/>.
/// </summary>
public class FPGrowth : JavaEstimator<FPGrowthModel>, IJavaMLWritable, IJavaMLReadable<FPGrowth>
{
    private static readonly string s_className = "org.apache.spark.ml.fpm.FPGrowth";

    /// <summary>
    /// Creates a <see cref="FPGrowth"/> without any parameters.
    /// </summary>
    public FPGrowth() : base(s_className)
    {
    }

    /// <summary>
    /// Creates a <see cref="FPGrowth"/> with a UID that is used to give the
    /// <see cref="FPGrowth"/> a unique ID.
    /// </summary>
    /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
    public FPGrowth(string uid) : base(s_className, uid)
    {
    }

    /// <summary>
    /// Wraps an existing JVM object reference as a <see cref="FPGrowth"/>.
    /// </summary>
    /// <param name="jvmObject">Reference to the JVM-side object to wrap.</param>
    internal FPGrowth(JvmObjectReference jvmObject) : base(jvmObject)
    {
    }

    /// <summary>
    /// Sets value for itemsCol
    /// </summary>
    /// <param name="value">
    /// items column name
    /// </param>
    /// <returns>
    /// New <see cref="FPGrowth"/> object wrapping the reference returned by the JVM setter
    /// </returns>
    public FPGrowth SetItemsCol(string value) =>
        WrapAsFPGrowth(Reference.Invoke("setItemsCol", (object)value));

    /// <summary>
    /// Sets value for minConfidence
    /// </summary>
    /// <param name="value">
    /// minimal confidence for generating Association Rule
    /// </param>
    /// <returns>
    /// New <see cref="FPGrowth"/> object wrapping the reference returned by the JVM setter
    /// </returns>
    public FPGrowth SetMinConfidence(double value) =>
        WrapAsFPGrowth(Reference.Invoke("setMinConfidence", (object)value));

    /// <summary>
    /// Sets value for minSupport
    /// </summary>
    /// <param name="value">
    /// the minimal support level of a frequent pattern
    /// </param>
    /// <returns>
    /// New <see cref="FPGrowth"/> object wrapping the reference returned by the JVM setter
    /// </returns>
    public FPGrowth SetMinSupport(double value) =>
        WrapAsFPGrowth(Reference.Invoke("setMinSupport", (object)value));

    /// <summary>
    /// Sets value for numPartitions
    /// </summary>
    /// <param name="value">
    /// Number of partitions used by parallel FP-growth
    /// </param>
    /// <returns>
    /// New <see cref="FPGrowth"/> object wrapping the reference returned by the JVM setter
    /// </returns>
    public FPGrowth SetNumPartitions(int value) =>
        WrapAsFPGrowth(Reference.Invoke("setNumPartitions", (object)value));

    /// <summary>
    /// Sets value for predictionCol
    /// </summary>
    /// <param name="value">
    /// prediction column name
    /// </param>
    /// <returns>
    /// New <see cref="FPGrowth"/> object wrapping the reference returned by the JVM setter
    /// </returns>
    public FPGrowth SetPredictionCol(string value) =>
        WrapAsFPGrowth(Reference.Invoke("setPredictionCol", (object)value));

    /// <summary>
    /// Gets itemsCol value
    /// </summary>
    /// <returns>
    /// itemsCol: items column name
    /// </returns>
    public string GetItemsCol() =>
        (string)Reference.Invoke("getItemsCol");

    /// <summary>
    /// Gets minConfidence value
    /// </summary>
    /// <returns>
    /// minConfidence: minimal confidence for generating Association Rule
    /// </returns>
    public double GetMinConfidence() =>
        (double)Reference.Invoke("getMinConfidence");

    /// <summary>
    /// Gets minSupport value
    /// </summary>
    /// <returns>
    /// minSupport: the minimal support level of a frequent pattern
    /// </returns>
    public double GetMinSupport() =>
        (double)Reference.Invoke("getMinSupport");

    /// <summary>
    /// Gets numPartitions value
    /// </summary>
    /// <returns>
    /// numPartitions: Number of partitions used by parallel FP-growth
    /// </returns>
    public int GetNumPartitions() =>
        (int)Reference.Invoke("getNumPartitions");

    /// <summary>
    /// Gets predictionCol value
    /// </summary>
    /// <returns>
    /// predictionCol: prediction column name
    /// </returns>
    public string GetPredictionCol() =>
        (string)Reference.Invoke("getPredictionCol");

    /// <summary>Fits a model to the input data.</summary>
    /// <param name="dataset">The <see cref="DataFrame"/> to fit the model to.</param>
    /// <returns><see cref="FPGrowthModel"/></returns>
    public override FPGrowthModel Fit(DataFrame dataset) =>
        new FPGrowthModel(
            (JvmObjectReference)Reference.Invoke("fit", dataset));

    /// <summary>
    /// Loads the <see cref="FPGrowth"/> that was previously saved using Save(string).
    /// </summary>
    /// <param name="path">The path the previous <see cref="FPGrowth"/> was saved to</param>
    /// <returns>New <see cref="FPGrowth"/> object, loaded from path.</returns>
    public static FPGrowth Load(string path) => WrapAsFPGrowth(
        SparkEnvironment.JvmBridge.CallStaticJavaMethod(s_className, "load", path));

    /// <summary>
    /// Saves the object so that it can be loaded later using Load. Note that these objects
    /// can be shared with Scala by Loading or Saving in Scala.
    /// </summary>
    /// <param name="path">The path to save the object to</param>
    public void Save(string path) => Reference.Invoke("save", path);

    /// <summary>
    /// Get the corresponding JavaMLWriter instance.
    /// </summary>
    /// <returns>a <see cref="JavaMLWriter"/> instance for this ML instance.</returns>
    public JavaMLWriter Write() =>
        new JavaMLWriter((JvmObjectReference)Reference.Invoke("write"));

    /// <summary>
    /// Get the corresponding JavaMLReader instance.
    /// </summary>
    /// <returns>an <see cref="JavaMLReader{FPGrowth}"/> instance for this ML instance.</returns>
    public JavaMLReader<FPGrowth> Read() =>
        new JavaMLReader<FPGrowth>((JvmObjectReference)Reference.Invoke("read"));

    /// <summary>
    /// Casts the object returned by a JVM call to a <see cref="JvmObjectReference"/> and
    /// wraps it in a new <see cref="FPGrowth"/>.
    /// </summary>
    /// <param name="obj">The value returned by the JVM bridge.</param>
    /// <returns>New <see cref="FPGrowth"/> wrapping <paramref name="obj"/>.</returns>
    private static FPGrowth WrapAsFPGrowth(object obj) =>
        new FPGrowth((JvmObjectReference)obj);
}
}


Loading