Skip to content

Commit

Permalink
Change in project structure (dotnet#385)
Browse files Browse the repository at this point in the history
* initial changes

* Change in project structure

* correcting test

* change variable name

* fix tests

* fix tests

* fix more tests

* fix codegen errors

* added log file message

* changed name of args

* change variable names

* fix test
  • Loading branch information
srsaggam authored and Dmitry-A committed Aug 22, 2019
1 parent 49a13c1 commit 3056e3f
Show file tree
Hide file tree
Showing 28 changed files with 752 additions and 2,088 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
//*****************************************************************************************

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using TestNamespace.Model.DataModels;

namespace TestNamespace.Train
namespace TestNamespace.ConsoleApp
{
class Program
public static class ModelBuilder
{
private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv";
private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv";
private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip";

static void Main(string[] args)
{
// Create MLContext to be shared across the model creation workflow objects
// Set a random seed for repeatable/deterministic results across multiple trainings.
MLContext mlContext = new MLContext(seed: 1);
// Create MLContext to be shared across the model creation workflow objects
// Set a random seed for repeatable/deterministic results across multiple trainings.
private static MLContext mlContext = new MLContext(seed: 1);

public static void CreateModel()
{
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
path: TRAIN_DATA_FILEPATH,
Expand All @@ -49,9 +51,6 @@ namespace TestNamespace.Train

// Save model
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);

Console.WriteLine("=============== End of process, hit any key to finish ===============");
Console.ReadKey();
}

public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
Expand Down Expand Up @@ -83,7 +82,7 @@ namespace TestNamespace.Train
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
IDataView predictions = mlModel.Transform(testDataView);
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(predictions, "Label", "Score");
ConsoleHelper.PrintBinaryClassificationMetrics(metrics);
PrintBinaryClassificationMetrics(metrics);
}
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
{
Expand All @@ -104,5 +103,47 @@ namespace TestNamespace.Train

return fullPath;
}

/// <summary>
/// Writes a short console report for a binary classification model:
/// a banner, the accuracy and the area under the ROC curve, both formatted as percentages.
/// </summary>
/// <param name="metrics">Evaluation metrics whose Accuracy and AreaUnderRocCurve are printed.</param>
public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics)
{
    const string border = "************************************************************";

    Console.WriteLine(border);
    Console.WriteLine("* Metrics for binary classification model ");
    Console.WriteLine("*-----------------------------------------------------------");
    // Composite formatting keeps the same "P2" percentage format as before.
    Console.WriteLine("* Accuracy: {0:P2}", metrics.Accuracy);
    Console.WriteLine("* Auc: {0:P2}", metrics.AreaUnderRocCurve);
    Console.WriteLine(border);
}


/// <summary>
/// Prints the average accuracy across cross-validation folds, together with the
/// standard deviation and 95% confidence interval of the per-fold accuracies.
/// </summary>
/// <param name="crossValResults">Per-fold cross-validation results; each fold's Metrics.Accuracy is aggregated.</param>
public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
{
    // The LINQ projection is deferred, so materialize it once; otherwise each of the
    // aggregations below would re-enumerate crossValResults from scratch.
    List<double> accuracyValues = crossValResults.Select(r => r.Metrics.Accuracy).ToList();

    double accuracyAverage = accuracyValues.Average();
    double accuraciesStdDeviation = CalculateStandardDeviation(accuracyValues);
    double accuraciesConfidenceInterval95 = CalculateConfidenceInterval95(accuracyValues);

    Console.WriteLine($"*************************************************************************************************************");
    Console.WriteLine($"* Metrics for Binary Classification model ");
    Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
    Console.WriteLine($"* Average Accuracy: {accuracyAverage:0.###} - Standard deviation: ({accuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({accuraciesConfidenceInterval95:#.###})");
    Console.WriteLine($"*************************************************************************************************************");
}

/// <summary>
/// Computes the sample standard deviation (dividing by n - 1) of the given values.
/// </summary>
/// <param name="values">The values to measure. Enumerated exactly once.</param>
/// <returns>
/// The sample standard deviation. A single value yields <see cref="double.NaN"/>,
/// because the n - 1 divisor is zero.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
/// <exception cref="InvalidOperationException"><paramref name="values"/> is empty (thrown by Average).</exception>
public static double CalculateStandardDeviation(IEnumerable<double> values)
{
    if (values == null)
    {
        throw new ArgumentNullException(nameof(values));
    }

    // Materialize once: the input may be a deferred query, and the mean, the sum of
    // squares and the count would otherwise each trigger a separate enumeration.
    List<double> samples = values.ToList();

    double average = samples.Average();
    double sumOfSquaresOfDifferences = samples.Sum(val => (val - average) * (val - average));
    return Math.Sqrt(sumOfSquaresOfDifferences / (samples.Count - 1));
}

/// <summary>
/// Computes the half-width of a 95% confidence interval for the mean of the given values,
/// using the normal-approximation critical value 1.96 times the standard error of the mean.
/// </summary>
/// <param name="values">The values to measure. Enumerated exactly once.</param>
/// <returns>
/// 1.96 * s / sqrt(n), where s is the result of CalculateStandardDeviation and n is the number of values.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
public static double CalculateConfidenceInterval95(IEnumerable<double> values)
{
    if (values == null)
    {
        throw new ArgumentNullException(nameof(values));
    }

    // Materialize once so the standard deviation and the count see the same data
    // and a deferred query is not re-run.
    List<double> samples = values.ToList();

    // The standard error of the mean is s / sqrt(n). CalculateStandardDeviation already
    // applies the n - 1 correction, so dividing by sqrt(n - 1) here would apply it twice
    // and overstate the interval.
    return 1.96 * CalculateStandardDeviation(samples) / Math.Sqrt(samples.Count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,74 +6,102 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using TestNamespace.Model.DataModels;

namespace TestNamespace.Train
namespace TestNamespace.ConsoleApp
{
public static class ConsoleHelper
public static class ModelBuilder
{
private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv";
private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv";
private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip";

public static void PrintRegressionMetrics(RegressionMetrics metrics)
// Create MLContext to be shared across the model creation workflow objects
// Set a random seed for repeatable/deterministic results across multiple trainings.
private static MLContext mlContext = new MLContext(seed: 1);

public static void CreateModel()
{
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Metrics for regression model ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}");
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}");
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}");
Console.WriteLine($"*************************************************");
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
path: TRAIN_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
allowQuoting: true,
allowSparse: true);

IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
path: TEST_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
allowQuoting: true,
allowSparse: true);
// Build training pipeline
IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

// Train Model
ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

// Evaluate quality of Model
EvaluateModel(mlContext, mlModel, testDataView);

// Save model
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
}

public static void PrintRegressionFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RegressionMetrics>> crossValidationResults)
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
{
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
// Data process configuration with pipeline data transformations
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
.AppendCacheCheckpoint(mlContext);

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for Regression model ");
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
Console.WriteLine($"*************************************************************************************************************");
// Set the training algorithm
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label");
var trainingPipeline = dataProcessPipeline.Append(trainer);

return trainingPipeline;
}

public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics)
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
{
Console.WriteLine($"************************************************************");
Console.WriteLine($"* Metrics for binary classification model ");
Console.WriteLine($"*-----------------------------------------------------------");
Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}");
Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}");
Console.WriteLine($"************************************************************");
}
Console.WriteLine("=============== Training model ===============");

ITransformer model = trainingPipeline.Fit(trainingDataView);

public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
Console.WriteLine("=============== End of training process ===============");
return model;
}

private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView)
{
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
// Evaluate the model and show accuracy stats
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
IDataView predictions = mlModel.Transform(testDataView);
var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score");
PrintMulticlassClassificationMetrics(metrics);
}
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
{
// Save/persist the trained model to a .ZIP file
Console.WriteLine($"=============== Saving the model ===============");
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
mlContext.Model.Save(mlModel, modelInputSchema, fs);

var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy);
var AccuracyAverage = AccuracyValues.Average();
var AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues);
var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyValues);
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
}

public static string GetAbsolutePath(string relativePath)
{
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
string assemblyFolderPath = _dataRoot.Directory.FullName;

Console.WriteLine($"*************************************************************************************************************");
Console.WriteLine($"* Metrics for Binary Classification model ");
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
Console.WriteLine($"* Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###})");
Console.WriteLine($"*************************************************************************************************************");
string fullPath = Path.Combine(assemblyFolderPath, relativePath);

return fullPath;
}

public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
Expand Down
Loading

0 comments on commit 3056e3f

Please sign in to comment.