Skip to content

Commit

Permalink
small code cleanup (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivanidzo4ka authored and justinormont committed May 30, 2018
1 parent c259863 commit 9d19d0e
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 46 deletions.
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Core/EntryPoints/ModuleCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public sealed class EntryPointInfo
public readonly Type[] OutputKinds;
public readonly ObsoleteAttribute ObsoleteAttribute;

internal EntryPointInfo(IExceptionContext ectx, MethodInfo method,
internal EntryPointInfo(IExceptionContext ectx, MethodInfo method,
TlcModule.EntryPointAttribute attribute, ObsoleteAttribute obsoleteAttribute)
{
Contracts.AssertValueOrNull(ectx);
Expand Down Expand Up @@ -187,7 +187,7 @@ private ModuleCatalog(IExceptionContext ectx)
if (attr == null)
continue;

var info = new EntryPointInfo(ectx, methodInfo, attr,
var info = new EntryPointInfo(ectx, methodInfo, attr,
methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute);

entryPoints.Add(info);
Expand Down Expand Up @@ -315,7 +315,7 @@ public bool TryFindComponent(Type interfaceType, Type argumentType, out Componen
Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface");
Contracts.CheckValue(argumentType, nameof(argumentType));

component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType);
component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType);
return component != null;
}

Expand Down
5 changes: 1 addition & 4 deletions src/Microsoft.ML.Data/EntryPoints/ScoreModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,8 @@ public static Output Score(IHostEnvironment env, Input input)
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);


IPredictor predictor;
var inputData = input.Data;
RoleMappedData data;
input.PredictorModel.PrepareData(host, inputData, out data, out predictor);
input.PredictorModel.PrepareData(host, inputData, out RoleMappedData data, out IPredictor predictor);

IDataView scoredPipe;
using (var ch = host.Start("Creating scoring pipeline"))
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.PipelineInference/AutoInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, S
testMetricVal += 1e-10;

// Save performance score
candidate.PerformanceSummary =
candidate.PerformanceSummary =
new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
_sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate);
_history.Add(candidate);
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.PipelineInference/AutoMlUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.ML.Runtime.PipelineInference
{
public static class AutoMlUtils
{
public static double ExtractValueFromIDV(IHostEnvironment env, IDataView result, string columnName)
public static double ExtractValueFromIdv(IHostEnvironment env, IDataView result, string columnName)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(result, nameof(result));
Expand All @@ -40,8 +40,8 @@ public static double ExtractValueFromIDV(IHostEnvironment env, IDataView result,

public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null)
{
double testingMetricValue = ExtractValueFromIDV(env, result, metricColumnName);
double trainingMetricValue = trainResult != null ? ExtractValueFromIDV(env, trainResult, metricColumnName) : double.MinValue;
double testingMetricValue = ExtractValueFromIdv(env, result, metricColumnName);
double trainingMetricValue = trainResult != null ? ExtractValueFromIdv(env, trainResult, metricColumnName) : double.MinValue;
return new AutoInference.RunSummary(testingMetricValue, 0, 0, trainingMetricValue);
}

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.PipelineInference/PipelinePattern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ public void RunTrainTestExperiment(IDataView trainData, IDataView testData,

var dataOut = experiment.GetOutput(trainTestOutput.OverallMetrics);
var dataOutTraining = experiment.GetOutput(trainTestOutput.TrainingOverallMetrics);
testMetricValue = AutoMlUtils.ExtractValueFromIDV(_env, dataOut, metric.Name);
trainMetricValue = AutoMlUtils.ExtractValueFromIDV(_env, dataOutTraining, metric.Name);
testMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, dataOut, metric.Name);
trainMetricValue = AutoMlUtils.ExtractValueFromIdv(_env, dataOutTraining, metric.Name);
}

public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ public static ArrayITransformModelOutput MakeArray(IHostEnvironment env, ArrayIT
return result;
}


public sealed class ArrayIDataViewInput
{
[Argument(ArgumentType.Required, HelpText = "The data sets", SortOrder = 1)]
Expand Down
5 changes: 2 additions & 3 deletions src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public sealed class SubGraphOutput
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)]
public Var<IPredictorModel> PredictorModel;

[Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)]
public Var<ITransformModel> TransformModel;
}
Expand Down Expand Up @@ -104,7 +104,6 @@ public sealed class Output
public IDataView ConfusionMatrix;
}


public sealed class CombineMetricsInput
{
[Argument(ArgumentType.Multiple, HelpText = "Overall metrics datasets", SortOrder = 1)]
Expand Down Expand Up @@ -219,7 +218,7 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
}
else
args.Outputs.TransformModel = null;

// Set train/test trainer kind to match.
args.Kind = input.Kind;

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML/Runtime/EntryPoints/ImportTextData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static Output TextLoader(IHostEnvironment env, LoaderInput input)
var host = env.Register("ImportTextData");
env.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
var loader = host.CreateLoader(input.Arguments, new FileHandleSource(input.InputFile));
var loader = host.CreateLoader(input.Arguments, new FileHandleSource(input.InputFile));
return new Output { Data = loader };
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public sealed class SubGraphOutput
{
[Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)]
public Var<IPredictorModel> PredictorModel;

[Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)]
public Var<ITransformModel> TransformModel;
}
Expand Down Expand Up @@ -130,7 +130,7 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
if (!subGraphRunContext.TryGetVariable(varName, out dataVariable))
throw env.Except($"Invalid variable name '{varName}'.");

string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) :
string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) :
node.GetOutputVariableName(nameof(Output.PredictorModel));

foreach (var subGraphNode in subGraphNodes)
Expand Down Expand Up @@ -249,7 +249,7 @@ public static CommonOutputs.MacroOutput<Output> TrainTest(
var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings);
var evalNodeTraining = evalInputOutputTraining.Item1;
var evalOutputTraining = evalInputOutputTraining.Item2;
evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName :
evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName :
scoreNodeTrainingOutput.ScoredData.VarName;

if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName))
Expand Down
50 changes: 25 additions & 25 deletions src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ public static string GetJsonFromField(string fieldName, Type fieldType)
private readonly string _regenerate;
private readonly HashSet<string> _excludedSet;
private const string RegistrationName = "CSharpApiGenerator";
public Dictionary<string, string> _typesSymbolTable = new Dictionary<string, string>();
public Dictionary<string, string> TypesSymbolTable = new Dictionary<string, string>();

public CSharpApiGenerator(IHostEnvironment env, Arguments args, string regenerate)
{
Expand Down Expand Up @@ -612,7 +612,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>))
type = type.GetGenericArguments()[0];

if (_typesSymbolTable.ContainsKey(type.FullName))
if (TypesSymbolTable.ContainsKey(type.FullName))
continue;

if (!type.IsEnum)
Expand All @@ -625,13 +625,13 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu

var enumType = Enum.GetUnderlyingType(type);

_typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace);
TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace);
if (enumType == typeof(int))
writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}");
writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}");
else
{
Contracts.Assert(enumType == typeof(byte));
writer.WriteLine($"public enum {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)} : byte");
writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)} : byte");
}

writer.Write("{");
Expand Down Expand Up @@ -707,19 +707,19 @@ private void GenerateStructs(IndentingTextWriter writer,
if (typeEnum != TlcModule.DataKind.Unknown)
continue;

if (_typesSymbolTable.ContainsKey(type.FullName))
if (TypesSymbolTable.ContainsKey(type.FullName))
continue;

_typesSymbolTable[type.FullName] = GetSymbolFromType(_typesSymbolTable, type, currentNamespace);
TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace);
string classBase = "";
if (type.IsSubclassOf(typeof(OneToOneColumn)))
classBase = $" : OneToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn";
classBase = $" : OneToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn";
else if (type.IsSubclassOf(typeof(ManyToOneColumn)))
classBase = $" : ManyToOneColumn<{_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn";
writer.WriteLine($"public sealed partial class {_typesSymbolTable[type.FullName].Substring(_typesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}");
classBase = $" : ManyToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn";
writer.WriteLine($"public sealed partial class {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}");
writer.WriteLine("{");
writer.Indent();
GenerateInputFields(writer, type, catalog, _typesSymbolTable);
GenerateInputFields(writer, type, catalog, TypesSymbolTable);
writer.Outdent();
writer.WriteLine("}");
writer.WriteLine();
Expand Down Expand Up @@ -858,12 +858,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer,
writer.Indent();
if (isArray)
{
writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(source));");
writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source));");
writer.WriteLine($"{fieldName} = list.ToArray();");
}
else
writer.WriteLine($"{fieldName} = OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(source);");
writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source);");
writer.Outdent();
writer.WriteLine("}");
writer.WriteLine();
Expand All @@ -872,12 +872,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer,
writer.Indent();
if (isArray)
{
writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source));");
writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));");
writer.WriteLine($"{fieldName} = list.ToArray();");
}
else
writer.WriteLine($"{fieldName} = OneToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source);");
writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);");
writer.Outdent();
writer.WriteLine("}");
writer.WriteLine();
Expand Down Expand Up @@ -905,12 +905,12 @@ private void GenerateColumnAddMethods(IndentingTextWriter writer,
writer.Indent();
if (isArray)
{
writer.WriteLine($"var list = {fieldName} == null ? new List<{_typesSymbolTable[type.FullName]}>() : new List<{_typesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(ManyToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source));");
writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});");
writer.WriteLine($"list.Add(ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));");
writer.WriteLine($"{fieldName} = list.ToArray();");
}
else
writer.WriteLine($"{fieldName} = ManyToOneColumn<{_typesSymbolTable[type.FullName]}>.Create(name, source);");
writer.WriteLine($"{fieldName} = ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);");
writer.Outdent();
writer.WriteLine("}");
writer.WriteLine();
Expand Down Expand Up @@ -942,10 +942,10 @@ private void GenerateInput(IndentingTextWriter writer,
foreach (var line in entryPointInfo.Description.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries))
writer.WriteLine($"/// {line}");
writer.WriteLine("/// </summary>");
if(entryPointInfo.ObsoleteAttribute != null)

if (entryPointInfo.ObsoleteAttribute != null)
writer.WriteLine($"[Obsolete(\"{entryPointInfo.ObsoleteAttribute.Message}\")]");

writer.WriteLine($"public sealed partial class {classAndMethod.Item2}{classBase}");
writer.WriteLine("{");
writer.Indent();
Expand All @@ -955,7 +955,7 @@ private void GenerateInput(IndentingTextWriter writer,

GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, classAndMethod.Item2, out Type transformType);
writer.WriteLine();
GenerateInputFields(writer, entryPointInfo.InputType, catalog, _typesSymbolTable);
GenerateInputFields(writer, entryPointInfo.InputType, catalog, TypesSymbolTable);
writer.WriteLine();

GenerateOutput(writer, entryPointInfo, out HashSet<string> outputVariableNames);
Expand Down Expand Up @@ -1191,7 +1191,7 @@ private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.Compone
writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}");
writer.WriteLine("{");
writer.Indent();
GenerateInputFields(writer, component.ArgumentType, catalog, _typesSymbolTable, "Microsoft.ML.");
GenerateInputFields(writer, component.ArgumentType, catalog, TypesSymbolTable, "Microsoft.ML.");
writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";");
writer.Outdent();
writer.WriteLine("}");
Expand Down

0 comments on commit 9d19d0e

Please sign in to comment.