Skip to content

Commit 7c7aab3

Browse files
Ivanidzo4ka authored and eerhardt committed
small fixes in ensembles (dotnet#442)
* separate diversity measurement by interface
* proper friendly name for Weighted average
* update manifest.json
* missed loader signature
1 parent b25c9f8 commit 7c7aab3

File tree

15 files changed

+142
-95
lines changed

15 files changed

+142
-95
lines changed

src/Microsoft.ML.Ensemble/EntryPoints/DiversityMeasure.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@
1717
namespace Microsoft.ML.Ensemble.EntryPoints
1818
{
1919
[TlcModule.Component(Name = DisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
20-
public sealed class DisagreementDiversityFactory : ISupportDiversityMeasureFactory<Single>
20+
public sealed class DisagreementDiversityFactory : ISupportBinaryDiversityMeasureFactory
2121
{
22-
public IDiversityMeasure<float> CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure();
22+
public IBinaryDiversityMeasure CreateComponent(IHostEnvironment env) => new DisagreementDiversityMeasure();
2323
}
2424

2525
[TlcModule.Component(Name = RegressionDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
26-
public sealed class RegressionDisagreementDiversityFactory : ISupportDiversityMeasureFactory<Single>
26+
public sealed class RegressionDisagreementDiversityFactory : ISupportRegressionDiversityMeasureFactory
2727
{
28-
public IDiversityMeasure<float> CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure();
28+
public IRegressionDiversityMeasure CreateComponent(IHostEnvironment env) => new RegressionDisagreementDiversityMeasure();
2929
}
3030

3131
[TlcModule.Component(Name = MultiDisagreementDiversityMeasure.LoadName, FriendlyName = DisagreementDiversityMeasure.UserName)]
32-
public sealed class MultiDisagreementDiversityFactory : ISupportDiversityMeasureFactory<VBuffer<Single>>
32+
public sealed class MultiDisagreementDiversityFactory : ISupportMulticlassDiversityMeasureFactory
3333
{
34-
public IDiversityMeasure<VBuffer<Single>> CreateComponent(IHostEnvironment env) => new MultiDisagreementDiversityMeasure();
34+
public IMulticlassDiversityMeasure CreateComponent(IHostEnvironment env) => new MultiDisagreementDiversityMeasure();
3535
}
3636

3737
}

src/Microsoft.ML.Ensemble/OutputCombiners/IOutputCombiner.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ public interface IMultiClassOutputCombiner : IOutputCombiner<VBuffer<Single>>
4646
{
4747
}
4848

49-
5049
[TlcModule.ComponentKind("EnsembleMulticlassOutputCombiner")]
5150
public interface ISupportMulticlassOutputCombinerFactory : IComponentFactory<IMultiClassOutputCombiner>
5251
{
@@ -63,7 +62,7 @@ public interface ISupportRegressionOutputCombinerFactory : IComponentFactory<IRe
6362
{
6463

6564
}
66-
65+
6766
public interface IWeightedAverager
6867
{
6968
string WeightageMetricName { get; }

src/Microsoft.ML.Ensemble/OutputCombiners/WeightedAverage.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ private static VersionInfo GetVersionInfo()
3535
loaderSignature: LoaderSignature);
3636
}
3737

38-
[TlcModule.Component(Name = LoadName, FriendlyName = Stacking.UserName)]
38+
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
3939
public sealed class Arguments: ISupportBinaryOutputCombinerFactory
4040
{
4141
[Argument(ArgumentType.AtMostOnce, HelpText = "The metric type to be used to find the weights for each model", ShortName = "wn", SortOrder = 50)]

src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/DisagreementDiversityMeasure.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
1414
{
15-
public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>
15+
public class DisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>, IBinaryDiversityMeasure
1616
{
1717
public const string UserName = "Disagreement Diversity Measure";
1818
public const string LoadName = "DisagreementDiversityMeasure";

src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/MultiDisagreementDiversityMeasure.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
1616
{
17-
public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<VBuffer<Single>>
17+
public class MultiDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<VBuffer<Single>>, IMulticlassDiversityMeasure
1818
{
1919
public const string LoadName = "MultiDisagreementDiversityMeasure";
2020

src/Microsoft.ML.Ensemble/Selector/DiversityMeasure/RegressionDisagreementDiversityMeasure.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
namespace Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure
1414
{
15-
public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>
15+
public class RegressionDisagreementDiversityMeasure : BaseDisagreementDiversityMeasure<Single>, IRegressionDiversityMeasure
1616
{
1717
public const string LoadName = "RegressionDisagreementDiversityMeasure";
1818

src/Microsoft.ML.Ensemble/Selector/IDiversityMeasure.cs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Concurrent;
67
using System.Collections.Generic;
8+
using Microsoft.ML.Runtime.Data;
79
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
810
using Microsoft.ML.Runtime.EntryPoints;
911

@@ -17,8 +19,25 @@ List<ModelDiversityMetric<TOutput>> CalculateDiversityMeasure(IList<FeatureSubse
1719

1820
public delegate void SignatureEnsembleDiversityMeasure();
1921

20-
[TlcModule.ComponentKind("EnsembleDiversityMeasure")]
21-
public interface ISupportDiversityMeasureFactory<TOutput> : IComponentFactory<IDiversityMeasure<TOutput>>
22+
public interface IBinaryDiversityMeasure : IDiversityMeasure<Single>
23+
{ }
24+
public interface IRegressionDiversityMeasure : IDiversityMeasure<Single>
25+
{ }
26+
public interface IMulticlassDiversityMeasure : IDiversityMeasure<VBuffer<Single>>
27+
{ }
28+
29+
[TlcModule.ComponentKind("EnsembleBinaryDiversityMeasure")]
30+
public interface ISupportBinaryDiversityMeasureFactory : IComponentFactory<IBinaryDiversityMeasure>
31+
{
32+
}
33+
34+
[TlcModule.ComponentKind("EnsembleRegressionDiversityMeasure")]
35+
public interface ISupportRegressionDiversityMeasureFactory : IComponentFactory<IRegressionDiversityMeasure>
36+
{
37+
}
38+
39+
[TlcModule.ComponentKind("EnsembleMulticlassDiversityMeasure")]
40+
public interface ISupportMulticlassDiversityMeasureFactory : IComponentFactory<IMulticlassDiversityMeasure>
2241
{
2342
}
2443
}

src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BaseDiverseSelector.cs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
using System;
66
using System.Collections.Concurrent;
77
using System.Collections.Generic;
8-
using Microsoft.ML.Runtime.CommandLine;
98
using Microsoft.ML.Runtime.Data;
109
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
11-
using Microsoft.ML.Runtime.Internal.Internallearn;
10+
using Microsoft.ML.Runtime.EntryPoints;
1211
using Microsoft.ML.Runtime.Internal.Utilities;
1312
using Microsoft.ML.Runtime.Training;
1413

@@ -19,27 +18,21 @@ public abstract class BaseDiverseSelector<TOutput, TDiversityMetric> : SubModelD
1918
{
2019
public abstract class DiverseSelectorArguments : ArgumentsBase
2120
{
22-
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
23-
[TGUI(Label = "Diversity Measure Type")]
24-
public ISupportDiversityMeasureFactory<TOutput> DiversityMetricType;
2521
}
2622

27-
private readonly ISupportDiversityMeasureFactory<TOutput> _diversityMetricType;
23+
private readonly IComponentFactory<IDiversityMeasure<TOutput>> _diversityMetricType;
2824
private ConcurrentDictionary<FeatureSubsetModel<IPredictorProducing<TOutput>>, TOutput[]> _predictions;
2925

30-
protected abstract ISupportDiversityMeasureFactory<TOutput> DefaultDiversityMetricType { get; }
31-
32-
protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name)
26+
protected internal BaseDiverseSelector(IHostEnvironment env, DiverseSelectorArguments args, string name,
27+
IComponentFactory<IDiversityMeasure<TOutput>> diversityMetricType)
3328
: base(args, env, name)
3429
{
35-
_diversityMetricType = args.DiversityMetricType;
30+
_diversityMetricType = diversityMetricType;
3631
_predictions = new ConcurrentDictionary<FeatureSubsetModel<IPredictorProducing<TOutput>>, TOutput[]>();
3732
}
3833

3934
protected IDiversityMeasure<TOutput> CreateDiversityMetric()
4035
{
41-
if (_diversityMetricType == null)
42-
return DefaultDiversityMetricType.CreateComponent(Host);
4336
return _diversityMetricType.CreateComponent(Host);
4437
}
4538

src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorBinary.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
using System.Collections.Generic;
88
using Microsoft.ML.Ensemble.EntryPoints;
99
using Microsoft.ML.Runtime;
10+
using Microsoft.ML.Runtime.CommandLine;
1011
using Microsoft.ML.Runtime.Ensemble.Selector;
1112
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
1213
using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
1314
using Microsoft.ML.Runtime.EntryPoints;
15+
using Microsoft.ML.Runtime.Internal.Internallearn;
1416

1517
[assembly: LoadableClass(typeof(BestDiverseSelectorBinary), typeof(BestDiverseSelectorBinary.Arguments),
1618
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorBinary.UserName, BestDiverseSelectorBinary.LoadName)]
@@ -24,17 +26,20 @@ public sealed class BestDiverseSelectorBinary : BaseDiverseSelector<Single, Disa
2426
public const string UserName = "Best Diverse Selector";
2527
public const string LoadName = "BestDiverseSelector";
2628

27-
protected override ISupportDiversityMeasureFactory<Single> DefaultDiversityMetricType => new DisagreementDiversityFactory();
28-
2929
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
3030
public sealed class Arguments : DiverseSelectorArguments, ISupportBinarySubModelSelectorFactory
3131
{
32+
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
33+
[TGUI(Label = "Diversity Measure Type")]
34+
public ISupportBinaryDiversityMeasureFactory DiversityMetricType = new DisagreementDiversityFactory();
35+
3236
public IBinarySubModelSelector CreateComponent(IHostEnvironment env) => new BestDiverseSelectorBinary(env, this);
3337
}
3438

3539
public BestDiverseSelectorBinary(IHostEnvironment env, Arguments args)
36-
: base(env, args, LoadName)
40+
: base(env, args, LoadName, args.DiversityMetricType)
3741
{
42+
3843
}
3944

4045
public override List<ModelDiversityMetric<Single>> CalculateDiversityMeasure(IList<FeatureSubsetModel<TScalarPredictor>> models,

src/Microsoft.ML.Ensemble/Selector/SubModelSelector/BestDiverseSelectorMultiClass.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
using System.Collections.Generic;
88
using Microsoft.ML.Ensemble.EntryPoints;
99
using Microsoft.ML.Runtime;
10+
using Microsoft.ML.Runtime.CommandLine;
1011
using Microsoft.ML.Runtime.Data;
1112
using Microsoft.ML.Runtime.Ensemble.Selector;
1213
using Microsoft.ML.Runtime.Ensemble.Selector.DiversityMeasure;
1314
using Microsoft.ML.Runtime.Ensemble.Selector.SubModelSelector;
1415
using Microsoft.ML.Runtime.EntryPoints;
16+
using Microsoft.ML.Runtime.Internal.Internallearn;
1517

1618
[assembly: LoadableClass(typeof(BestDiverseSelectorMultiClass), typeof(BestDiverseSelectorMultiClass.Arguments),
1719
typeof(SignatureEnsembleSubModelSelector), BestDiverseSelectorMultiClass.UserName, BestDiverseSelectorMultiClass.LoadName)]
@@ -24,16 +26,18 @@ public sealed class BestDiverseSelectorMultiClass : BaseDiverseSelector<VBuffer<
2426
{
2527
public const string UserName = "Best Diverse Selector";
2628
public const string LoadName = "BestDiverseSelectorMultiClass";
27-
protected override ISupportDiversityMeasureFactory<VBuffer<Single>> DefaultDiversityMetricType => new MultiDisagreementDiversityFactory();
2829

29-
[TlcModule.Component(Name = BestDiverseSelectorMultiClass.LoadName, FriendlyName = UserName)]
30+
[TlcModule.Component(Name = LoadName, FriendlyName = UserName)]
3031
public sealed class Arguments : DiverseSelectorArguments, ISupportMulticlassSubModelSelectorFactory
3132
{
33+
[Argument(ArgumentType.Multiple, HelpText = "The metric type to be used to find the diversity among base learners", ShortName = "dm", SortOrder = 50)]
34+
[TGUI(Label = "Diversity Measure Type")]
35+
public ISupportMulticlassDiversityMeasureFactory DiversityMetricType = new MultiDisagreementDiversityFactory();
3236
public IMulticlassSubModelSelector CreateComponent(IHostEnvironment env) => new BestDiverseSelectorMultiClass(env, this);
3337
}
3438

3539
public BestDiverseSelectorMultiClass(IHostEnvironment env, Arguments args)
36-
: base(env, args, LoadName)
40+
: base(env, args, LoadName, args.DiversityMetricType)
3741
{
3842
}
3943

0 commit comments

Comments
 (0)