diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 58e24041f1..3d285d06d8 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -97,6 +97,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -329,6 +333,22 @@ Global
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release|Any CPU.Build.0 = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{3E4ABF07-7970-4BE6-B45B-A13D3C397545}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release|Any CPU.Build.0 = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -367,6 +387,8 @@ Global
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
+ {A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/build/Dependencies.props b/build/Dependencies.props
index 5325011f05..79ae31c598 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -8,5 +8,7 @@
4.3.0
1.0.0-beta-62824-02
2.1.2.2
+ 0.0.0.5
+ 4.5.0
diff --git a/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj
new file mode 100644
index 0000000000..a531c8e403
--- /dev/null
+++ b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET additional learners making use of hardware acceleration. They depend on the MlNetMklDeps NuGet package.
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj
new file mode 100644
index 0000000000..248ae82414
--- /dev/null
+++ b/pkg/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj
new file mode 100644
index 0000000000..8bdef45d07
--- /dev/null
+++ b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj
@@ -0,0 +1,13 @@
+
+
+
+ netstandard2.0
+ ML.NET component for Image support
+
+
+
+
+
+
+
+
diff --git a/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj
new file mode 100644
index 0000000000..b36800ea0b
--- /dev/null
+++ b/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.symbols.nupkgproj
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
index 341e3a72af..e940ea9d4d 100644
--- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
@@ -198,6 +198,7 @@ private Delegate CreateGetter(int index)
Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
del = CreateDirectVBufferGetterDelegate;
+ genericType = colType.ItemType.RawType;
}
else if (colType.IsPrimitive)
{
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index f6ebaf687f..cd8198e14d 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -349,6 +349,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
Ch.Assert(fieldType.GetGenericTypeDefinition() == typeof(VBuffer<>));
Ch.Assert(fieldType.GetGenericArguments()[0] == colType.ItemType.RawType);
del = CreateVBufferToVBufferSetter;
+ genericType = colType.ItemType.RawType;
}
else if (colType.IsPrimitive)
{
diff --git a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
index 405e207773..64fe3b5b80 100644
--- a/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
+++ b/src/Microsoft.ML.Core/CommandLine/ArgumentAttribute.cs
@@ -34,6 +34,7 @@ public enum VisibilityType
private string _specialPurpose;
private VisibilityType _visibility;
private string _name;
+ private Type _signatureType;
///
/// Allows control of command line parsing.
@@ -139,5 +140,11 @@ public bool IsRequired
{
get { return ArgumentType.Required == (_type & ArgumentType.Required); }
}
+
+ public Type SignatureType
+ {
+ get { return _signatureType; }
+ set { _signatureType = value; }
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
index eb85fcce12..191321213f 100644
--- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
+++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs
@@ -249,6 +249,18 @@ public enum SettingsFlags
Default = ShortNames | NoSlashes
}
+ ///
+ /// An IComponentFactory that is used in the command line.
+ ///
+ /// This allows components to be created by name, signature type, and a settings string.
+ ///
+ public interface ICommandLineComponentFactory : IComponentFactory
+ {
+ Type SignatureType { get; }
+ string Name { get; }
+ string GetSettingsString();
+ }
+
///
/// Parser for command line arguments.
///
@@ -797,7 +809,8 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat
ModuleCatalog.ComponentInfo component;
if (IsCurlyGroup(value) && value.Length == 2)
arg.Field.SetValue(destination, null);
- else if (_catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component))
+ else if (!arg.IsCollection &&
+ _catalog.Value.TryFindComponentCaseInsensitive(arg.Field.FieldType, value, out component))
{
var activator = Activator.CreateInstance(component.ArgumentType);
if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1]))
@@ -810,8 +823,9 @@ private bool ParseArgumentList(ArgumentInfo info, string[] strs, object destinat
}
else
{
- Report("Error: Failed to find component with name '{0}' for option '{1}'", value, arg.LongName);
- hadError |= true;
+ hadError |= !arg.SetValue(this, ref values[arg.Index], value, tag, destination);
+ if (!IsCurlyGroup(value) && i + 1 < strs.Length && IsCurlyGroup(strs[i + 1]))
+ hadError |= !arg.SetValue(this, ref values[arg.Index], strs[++i], "", destination);
}
continue;
}
@@ -1532,6 +1546,8 @@ private sealed class Argument
// Used for help and composing settings strings.
public readonly object DefaultValue;
+ private readonly Type _signatureType;
+
// For custom types.
private readonly ArgumentInfo _infoCustom;
private readonly ConstructorInfo _ctorCustom;
@@ -1559,6 +1575,7 @@ public Argument(int index, string name, string[] nicks, object defaults, Argumen
IsDefault = attr is DefaultArgumentAttribute;
Contracts.Assert(!IsDefault || Utils.Size(ShortNames) == 0);
IsHidden = attr.Hide;
+ _signatureType = attr.SignatureType;
if (field.FieldType.IsArray)
{
@@ -1664,6 +1681,40 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
Field.SetValue(destination, com);
}
+ else if (IsSingleComponentFactory)
+ {
+ bool haveName = false;
+ string name = null;
+ string[] settings = null;
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string str = (string)values[i].Value;
+ if (str.StartsWith("{"))
+ {
+ i++;
+ continue;
+ }
+ if (haveName)
+ {
+ owner.Report("Duplicate component kind for argument {0}", LongName);
+ error = true;
+ }
+ name = str;
+ haveName = true;
+ values.RemoveAt(i);
+ }
+
+ if (Utils.Size(values) > 0)
+ settings = values.Select(x => (string)x.Value).ToArray();
+
+ Contracts.Check(_signatureType != null, "ComponentFactory Arguments need a SignatureType set.");
+ var factory = ComponentFactoryFactory.CreateComponentFactory(
+ ItemType,
+ _signatureType,
+ name,
+ settings);
+ Field.SetValue(destination, factory);
+ }
else if (IsMultiSubComponent)
{
// REVIEW: the kind should not be separated from settings: everything related
@@ -1711,6 +1762,63 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
Field.SetValue(destination, arr);
}
}
+ else if (IsMultiComponentFactory)
+ {
+ // REVIEW: the kind should not be separated from settings: everything related
+ // to one item should go into one value, not multiple values
+ if (IsTaggedCollection)
+ {
+ // Tagged collection of IComponentFactory
+ var comList = new List>();
+
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string tag = values[i].Key;
+ string name = (string)values[i++].Value;
+ string[] settings = null;
+ if (i < values.Count && IsCurlyGroup((string)values[i].Value) && string.IsNullOrEmpty(values[i].Key))
+ settings = new string[] { (string)values[i++].Value };
+ var factory = ComponentFactoryFactory.CreateComponentFactory(
+ ItemValueType,
+ _signatureType,
+ name,
+ settings);
+ comList.Add(new KeyValuePair(tag, factory));
+ }
+
+ var arr = Array.CreateInstance(ItemType, comList.Count);
+ for (int i = 0; i < arr.Length; i++)
+ {
+ var kvp = Activator.CreateInstance(ItemType, comList[i].Key, comList[i].Value);
+ arr.SetValue(kvp, i);
+ }
+
+ Field.SetValue(destination, arr);
+ }
+ else
+ {
+ // Collection of IComponentFactory
+ var comList = new List();
+ for (int i = 0; i < Utils.Size(values);)
+ {
+ string name = (string)values[i++].Value;
+ string[] settings = null;
+ if (i < values.Count && IsCurlyGroup((string)values[i].Value))
+ settings = new string[] { (string)values[i++].Value };
+ var factory = ComponentFactoryFactory.CreateComponentFactory(
+ ItemValueType,
+ _signatureType,
+ name,
+ settings);
+ comList.Add(factory);
+ }
+
+ var arr = Array.CreateInstance(ItemValueType, comList.Count);
+ for (int i = 0; i < arr.Length; i++)
+ arr.SetValue(comList[i], i);
+ Field.SetValue(destination, arr);
+ }
+ }
else if (IsTaggedCollection)
{
var res = Array.CreateInstance(ItemType, Utils.Size(values));
@@ -1732,6 +1840,118 @@ public bool Finish(CmdParser owner, ArgValue val, object destination)
return error;
}
+ ///
+ /// A factory class for creating IComponentFactory instances.
+ ///
+ private static class ComponentFactoryFactory
+ {
+ public static IComponentFactory CreateComponentFactory(
+ Type factoryType,
+ Type signatureType,
+ string name,
+ string[] settings)
+ {
+ Contracts.Check(factoryType != null &&
+ typeof(IComponentFactory).IsAssignableFrom(factoryType) &&
+ factoryType.IsGenericType);
+
+ Type componentFactoryType;
+ if (factoryType.GenericTypeArguments.Length == 1)
+ {
+ componentFactoryType = typeof(ComponentFactory<>);
+ }
+ else if (factoryType.GenericTypeArguments.Length == 2)
+ {
+ componentFactoryType = typeof(ComponentFactory<,>);
+ }
+ else
+ {
+ throw Contracts.ExceptNotImpl("ComponentFactoryFactory can only create components with 1 or 2 type args.");
+ }
+
+ return (IComponentFactory)Activator.CreateInstance(
+ componentFactoryType.MakeGenericType(factoryType.GenericTypeArguments),
+ signatureType,
+ name,
+ settings);
+ }
+
+ private abstract class ComponentFactory : ICommandLineComponentFactory
+ {
+ public Type SignatureType { get; }
+ public string Name { get; }
+ private string[] Settings { get; }
+
+ protected ComponentFactory(Type signatureType, string name, string[] settings)
+ {
+ SignatureType = signatureType;
+ Name = name;
+
+ if (settings == null || (settings.Length == 1 && string.IsNullOrEmpty(settings[0])))
+ {
+ settings = Array.Empty();
+ }
+ Settings = settings;
+ }
+
+ public string GetSettingsString()
+ {
+ return CombineSettings(Settings);
+ }
+
+ public override string ToString()
+ {
+ if (string.IsNullOrEmpty(Name) && Settings.Length == 0)
+ return "{}";
+
+ if (Settings.Length == 0)
+ return Name;
+
+ string str = CombineSettings(Settings);
+ StringBuilder sb = new StringBuilder();
+ CmdQuoter.QuoteValue(str, sb, true);
+ return Name + sb.ToString();
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString());
+ }
+ }
+
+ private class ComponentFactory : ComponentFactory, IComponentFactory
+ where TComponent : class
+ {
+ public ComponentFactory(Type signatureType, string name, string[] settings)
+ : base(signatureType, name, settings)
+ {
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+ {
+ return ComponentCatalog.CreateInstance(
+ env,
+ SignatureType,
+ Name,
+ GetSettingsString(),
+ argument1);
+ }
+ }
+ }
+
private bool ReportMissingRequiredArgument(CmdParser owner, ArgValue val)
{
if (!IsRequired || val != null)
@@ -1784,7 +2004,7 @@ public bool SetValue(CmdParser owner, ref ArgValue val, string value, string tag
}
val.Values.Add(new KeyValuePair(tag, newValue));
}
- else if (IsSingleSubComponent)
+ else if (IsSingleSubComponent || IsComponentFactory)
{
Contracts.Assert(newValue is string || newValue == null);
Contracts.Assert((string)newValue != "");
@@ -1834,7 +2054,7 @@ private bool ParseValue(CmdParser owner, string data, out object value)
return false;
}
- if (IsSubComponentItemType)
+ if (IsSubComponentItemType || IsComponentFactory)
{
value = data;
return true;
@@ -2186,19 +2406,28 @@ private string GetString(IExceptionContext ectx, object value, StringBuilder buf
string name;
var catalog = ModuleCatalog.CreateInstance(ectx);
var type = value.GetType();
- bool success = catalog.TryGetComponentShortName(type, out name);
- Contracts.Assert(success);
-
- var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
- buffer.Clear();
- buffer.Append(name);
- if (!string.IsNullOrWhiteSpace(settings))
+ bool isModuleComponent = catalog.TryGetComponentShortName(type, out name);
+ if (isModuleComponent)
{
- StringBuilder sb = new StringBuilder();
- CmdQuoter.QuoteValue(settings, sb, true);
- buffer.Append(sb);
+ var settings = GetSettings(ectx, value, Activator.CreateInstance(type));
+ buffer.Clear();
+ buffer.Append(name);
+ if (!string.IsNullOrWhiteSpace(settings))
+ {
+ StringBuilder sb = new StringBuilder();
+ CmdQuoter.QuoteValue(settings, sb, true);
+ buffer.Append(sb);
+ }
+ return buffer.ToString();
+ }
+ else if (value is ICommandLineComponentFactory)
+ {
+ return value.ToString();
+ }
+ else
+ {
+ throw ectx.Except($"IComponentFactory instances either need to be EntryPointComponents or implement {nameof(ICommandLineComponentFactory)}.");
}
- return buffer.ToString();
}
return value.ToString();
@@ -2344,6 +2573,16 @@ public bool IsMultiSubComponent {
get { return IsSubComponentItemType && Field.FieldType.IsArray; }
}
+ public bool IsSingleComponentFactory
+ {
+ get { return IsComponentFactory && !Field.FieldType.IsArray; }
+ }
+
+ public bool IsMultiComponentFactory
+ {
+ get { return IsComponentFactory && Field.FieldType.IsArray; }
+ }
+
public bool IsCustomItemType {
get { return _infoCustom != null; }
}
diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
index 3b56e8bb36..ddbbd2a500 100644
--- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
+++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
@@ -343,7 +343,7 @@ private static bool ShouldSkipPath(string path)
case "libvw.dll":
case "matrixinterf.dll":
case "Microsoft.ML.neuralnetworks.gpucuda.dll":
- case "Microsoft.ML.mklimports.dll":
+ case "MklImports.dll":
case "microsoft.research.controls.decisiontrees.dll":
case "Microsoft.ML.neuralnetworks.sse.dll":
case "neuraltreeevaluator.dll":
@@ -832,10 +832,15 @@ public static LoadableClassInfo[] FindLoadableClasses()
public static LoadableClassInfo GetLoadableClassInfo(string loadName)
{
- Contracts.CheckParam(typeof(TSig).BaseType == typeof(MulticastDelegate), nameof(TSig), "TSig must be a delegate type");
+ return GetLoadableClassInfo(loadName, typeof(TSig));
+ }
+
+ public static LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType)
+ {
+ Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type");
Contracts.CheckValueOrNull(loadName);
loadName = (loadName ?? "").ToLowerInvariant().Trim();
- return FindClassCore(new LoadableClassInfo.Key(loadName, typeof(TSig)));
+ return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType));
}
public static LoadableClassInfo GetLoadableClassInfo(SubComponent sub)
@@ -886,6 +891,18 @@ public static TRes CreateInstance(this SubComponent comp
throw Contracts.Except("Unknown loadable class: {0}", comp.Kind).MarkSensitive(MessageSensitivity.None);
}
+ ///
+ /// Create an instance of the indicated component with the given extra parameters.
+ ///
+ public static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra)
+ where TRes : class
+ {
+ TRes result;
+ if (TryCreateInstance(env, signatureType, out result, name, options, extra))
+ return result;
+ throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None);
+ }
+
///
/// Try to create an instance of the indicated component with the given extra parameters. If there is no
/// such component in the catalog, returns false. Any other error results in an exception.
@@ -913,13 +930,19 @@ public static bool TryCreateInstance(IHostEnvironment env, out TRes
///
public static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra)
where TRes : class
+ {
+ return TryCreateInstance(env, typeof(TSig), out result, name, options, extra);
+ }
+
+ private static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
+ where TRes : class
{
Contracts.CheckValue(env, nameof(env));
- env.Check(typeof(TSig).BaseType == typeof(MulticastDelegate));
+ env.Check(signatureType.BaseType == typeof(MulticastDelegate));
env.CheckValueOrNull(name);
string nameLower = (name ?? "").ToLowerInvariant().Trim();
- LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, typeof(TSig)));
+ LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));
if (info == null)
{
result = null;
diff --git a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
index 9334f0f225..d69a9d0b93 100644
--- a/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
+++ b/src/Microsoft.ML.Core/EntryPoints/ComponentFactory.cs
@@ -37,6 +37,26 @@ public interface IComponentFactory : IComponentFactory
TComponent CreateComponent(IHostEnvironment env, TArg1 argument1);
}
+ ///
+ /// A class for creating a component when we take one extra parameter
+ /// (and an ) that simply wraps a delegate which
+ /// creates the component.
+ ///
+ public class SimpleComponentFactory : IComponentFactory
+ {
+ private Func _factory;
+
+ public SimpleComponentFactory(Func factory)
+ {
+ _factory = factory;
+ }
+
+ public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1)
+ {
+ return _factory(env, argument1);
+ }
+ }
+
///
/// An interface for creating a component when we take two extra parameters (and an ).
///
diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
index 26ec32d3fe..affd949064 100644
--- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
@@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -69,8 +70,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether we should cache input training data", ShortName = "cache")]
public bool? CacheData;
- [Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf")]
- public KeyValuePair>[] PreTransform;
+ [Argument(ArgumentType.Multiple, HelpText = "Transforms to apply prior to splitting the data into folds", ShortName = "prexf", SignatureType = typeof(SignatureDataTransform))]
+ public KeyValuePair>[] PreTransform;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The validation data file", ShortName = "valid")]
public string ValidationFile;
@@ -159,16 +160,18 @@ private void RunCore(IChannel ch, string cmd)
string name = TrainUtils.MatchNameOrDefaultOrNull(ch, loader.Schema, nameof(Args.NameColumn), Args.NameColumn, DefaultColumnNames.Name);
if (name == null)
{
- var args = new GenerateNumberTransform.Arguments();
- args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, };
- args.UseCounter = true;
- var options = CmdParser.GetSettings(ch, args, new GenerateNumberTransform.Arguments());
preXf = preXf.Concat(
new[]
{
- new KeyValuePair>(
- "", new SubComponent(
- GenerateNumberTransform.LoadName, options))
+ new KeyValuePair>(
+ "", new SimpleComponentFactory(
+ (env, input) =>
+ {
+ var args = new GenerateNumberTransform.Arguments();
+ args.Column = new[] { new GenerateNumberTransform.Column() { Name = DefaultColumnNames.Name }, };
+ args.UseCounter = true;
+ return new GenerateNumberTransform(env, args, input);
+ }))
}).ToArray();
}
}
@@ -263,7 +266,7 @@ private RoleMappedData ApplyAllTransformsToData(IHostEnvironment env, IChannel c
private RoleMappedData CreateRoleMappedData(IHostEnvironment env, IChannel ch, IDataView data, ITrainer trainer)
{
foreach (var kvp in Args.Transform)
- data = kvp.Value.CreateInstance(env, data);
+ data = kvp.Value.CreateComponent(env, data);
var schema = data.Schema;
string label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
diff --git a/src/Microsoft.ML.Data/Commands/DataCommand.cs b/src/Microsoft.ML.Data/Commands/DataCommand.cs
index 2a62d78901..8f489e270f 100644
--- a/src/Microsoft.ML.Data/Commands/DataCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/DataCommand.cs
@@ -8,6 +8,8 @@
using System.Linq;
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data.IO;
+using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
@@ -20,8 +22,8 @@ public static class DataCommand
{
public abstract class ArgumentsBase
{
- [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "")]
- public SubComponent Loader;
+ [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "", SignatureType = typeof(SignatureDataLoader))]
+ public IComponentFactory Loader;
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
public string DataFile;
@@ -51,8 +53,8 @@ public abstract class ArgumentsBase
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
public int? Parallel;
- [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
- public KeyValuePair>[] Transform;
+ [Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))]
+ public KeyValuePair>[] Transform;
}
public abstract class ImplBase : ICommand
@@ -125,6 +127,17 @@ protected void SendTelemetryComponent(IPipe pipe, SubComponent
pipe.Send(TelemetryMessage.CreateTrainer(sub.Kind, sub.SubComponentSettings));
}
+ protected void SendTelemetryComponent(IPipe pipe, IComponentFactory factory)
+ {
+ Host.AssertValue(pipe);
+ Host.AssertValueOrNull(factory);
+
+ if (factory is ICommandLineComponentFactory commandLineFactory)
+ pipe.Send(TelemetryMessage.CreateTrainer(commandLineFactory.Name, commandLineFactory.GetSettingsString()));
+ else
+ pipe.Send(TelemetryMessage.CreateTrainer("Unknown", "Non-ICommandLineComponentFactory object"));
+ }
+
protected virtual void SendTelemetryCore(IPipe pipe)
{
Contracts.AssertValue(pipe);
@@ -212,9 +225,9 @@ protected void SaveLoader(IDataLoader loader, string path)
LoaderUtils.SaveLoader(loader, file);
}
- protected IDataLoader CreateAndSaveLoader(string defaultLoader = "TextLoader")
+ protected IDataLoader CreateAndSaveLoader(Func defaultLoaderFactory = null)
{
- var loader = CreateLoader(defaultLoader);
+ var loader = CreateLoader(defaultLoaderFactory);
if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
{
using (var file = Host.CreateOutputFile(Args.OutputModelFile))
@@ -268,12 +281,12 @@ protected void LoadModelObjects(
}
// Next create the loader.
- var sub = Args.Loader;
+ var loaderFactory = Args.Loader;
IDataLoader trainPipe = null;
- if (sub.IsGood())
+ if (loaderFactory != null)
{
// The loader is overridden from the command line.
- pipe = sub.CreateInstance(Host, new MultiFileSource(Args.DataFile));
+ pipe = loaderFactory.CreateComponent(Host, new MultiFileSource(Args.DataFile));
if (Args.LoadTransforms == true)
{
Host.CheckUserArg(!string.IsNullOrWhiteSpace(Args.InputModelFile), nameof(Args.InputModelFile));
@@ -316,9 +329,9 @@ protected void LoadModelObjects(
}
}
- protected IDataLoader CreateLoader(string defaultLoader = "TextLoader")
+ protected IDataLoader CreateLoader(Func defaultLoaderFactory = null)
{
- var loader = CreateRawLoader(defaultLoader);
+ var loader = CreateRawLoader(defaultLoaderFactory);
loader = CreateTransformChain(loader);
return loader;
}
@@ -328,13 +341,15 @@ private IDataLoader CreateTransformChain(IDataLoader loader)
return CompositeDataLoader.Create(Host, loader, Args.Transform);
}
- protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", string dataFile = null)
+ protected IDataLoader CreateRawLoader(
+ Func defaultLoaderFactory = null,
+ string dataFile = null)
{
if (string.IsNullOrWhiteSpace(dataFile))
dataFile = Args.DataFile;
IDataLoader loader;
- if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && !Args.Loader.IsGood())
+ if (!string.IsNullOrWhiteSpace(Args.InputModelFile) && Args.Loader == null)
{
// Load the loader from the data model.
using (var file = Host.OpenInputFile(Args.InputModelFile))
@@ -345,8 +360,9 @@ protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", strin
else
{
// Either there is no input model file, or there is, but the loader is overridden.
- var sub = Args.Loader;
- if (!sub.IsGood())
+ IMultiStreamSource fileSource = new MultiFileSource(dataFile);
+ var loaderFactory = Args.Loader;
+ if (loaderFactory == null)
{
var ext = Path.GetExtension(dataFile);
var isText =
@@ -354,12 +370,17 @@ protected IDataLoader CreateRawLoader(string defaultLoader = "TextLoader", strin
string.Equals(ext, ".tlc", StringComparison.OrdinalIgnoreCase);
var isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
var isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
- sub =
- new SubComponent(
- isText ? "TextLoader" : isBinary ? "BinaryLoader" : isTranspose ? "TransposeLoader" : defaultLoader);
- }
- loader = sub.CreateInstance(Host, new MultiFileSource(dataFile));
+ return isText ? new TextLoader(Host, new TextLoader.Arguments(), fileSource) :
+ isBinary ? new BinaryLoader(Host, new BinaryLoader.Arguments(), fileSource) :
+ isTranspose ? new TransposeLoader(Host, new TransposeLoader.Arguments(), fileSource) :
+ defaultLoaderFactory != null ? defaultLoaderFactory(Host, fileSource) :
+ new TextLoader(Host, new TextLoader.Arguments(), fileSource);
+ }
+ else
+ {
+ loader = loaderFactory.CreateComponent(Host, fileSource);
+ }
if (Args.LoadTransforms == true)
{
diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
index 77bdf0e32f..c122c65ffa 100644
--- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs
@@ -217,7 +217,8 @@ private void RunCore(IChannel ch)
Host.AssertValue(ch);
ch.Trace("Creating loader");
- IDataView view = CreateAndSaveLoader(IO.BinaryLoader.LoadName);
+ IDataView view = CreateAndSaveLoader(
+ (env, source) => new IO.BinaryLoader(env, new IO.BinaryLoader.Arguments(), source));
ch.Trace("Binding columns");
ISchema schema = view.Schema;
diff --git a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
index 607bf119d7..f69c35231d 100644
--- a/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
+++ b/src/Microsoft.ML.Data/Commands/ScoreCommand.cs
@@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Command;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
@@ -62,8 +63,8 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to include hidden columns", ShortName = "keep")]
public bool KeepHidden;
- [Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf")]
- public KeyValuePair>[] PostTransform;
+ [Argument(ArgumentType.Multiple, HelpText = "Post processing transform", ShortName = "pxf", SignatureType = typeof(SignatureDataTransform))]
+ public KeyValuePair>[] PostTransform;
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output all columns or just scores", ShortName = "all")]
public bool? OutputAllColumns;
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
index a2ab3a7b16..d83af5d824 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
@@ -14,6 +14,7 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.EntryPoints;
[assembly: LoadableClass(typeof(IDataLoader), typeof(CompositeDataLoader), typeof(CompositeDataLoader.Arguments), typeof(SignatureDataLoader),
"Composite Data Loader", "CompositeDataLoader", "Composite", "PipeData", "Pipe", "PipeDataLoader")]
@@ -34,11 +35,11 @@ public sealed class CompositeDataLoader : IDataLoader, ITransposeDataView
{
public sealed class Arguments
{
- [Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader")]
- public SubComponent Loader;
+ [Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SignatureType = typeof(SignatureDataLoader))]
+ public IComponentFactory Loader;
- [Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")]
- public KeyValuePair>[] Transform;
+ [Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf", SignatureType = typeof(SignatureDataTransform))]
+ public KeyValuePair>[] Transform;
}
private struct TransformEx
@@ -98,10 +99,10 @@ public static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStr
var h = env.Register(RegistrationName);
h.CheckValue(args, nameof(args));
- h.CheckUserArg(args.Loader.IsGood(), nameof(args.Loader));
+ h.CheckValue(args.Loader, nameof(args.Loader));
h.CheckValue(files, nameof(files));
- var loader = args.Loader.CreateInstance(h, files);
+ var loader = args.Loader.CreateComponent(h, files);
return CreateCore(h, loader, args.Transform);
}
@@ -111,7 +112,7 @@ public static IDataLoader Create(IHostEnvironment env, Arguments args, IMultiStr
/// If there are no transforms, the is returned.
///
public static IDataLoader Create(IHostEnvironment env, IDataLoader srcLoader,
- params KeyValuePair>[] transformArgs)
+ params KeyValuePair>[] transformArgs)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
@@ -122,7 +123,7 @@ public static IDataLoader Create(IHostEnvironment env, IDataLoader srcLoader,
}
private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader,
- KeyValuePair>[] transformArgs)
+ KeyValuePair>[] transformArgs)
{
Contracts.AssertValue(host, "host");
host.AssertValue(srcLoader, "srcLoader");
@@ -131,8 +132,15 @@ private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader,
if (Utils.Size(transformArgs) == 0)
return srcLoader;
+ string GetTagData(IComponentFactory factory)
+ {
+ // When coming from the command line, preserve the string arguments.
+ // For other factories, we aren't able to get the string.
+ return (factory as ICommandLineComponentFactory)?.ToString();
+ }
+
var tagData = transformArgs
- .Select(x => new KeyValuePair(x.Key, x.Value.ToString()))
+ .Select(x => new KeyValuePair(x.Key, GetTagData(x.Value)))
.ToArray();
// Warn if tags coincide with ones already present in the loader.
@@ -152,7 +160,7 @@ private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader,
}
return ApplyTransformsCore(host, srcLoader, tagData,
- (prov, index, data) => transformArgs[index].Value.CreateInstance(prov, data));
+ (env, index, data) => transformArgs[index].Value.CreateComponent(env, data));
}
///
diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
index 6eaf48e995..6365e27a54 100644
--- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs
@@ -117,8 +117,8 @@ public abstract class ArgumentsBase : TransformInputBase
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 110, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string DataFile;
- [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 111, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
- public SubComponent Loader;
+ [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 111, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))]
+ public IComponentFactory Loader;
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 112, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string TermsColumn;
@@ -309,12 +309,19 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
string file = args.DataFile;
// First column using the file.
string src = args.TermsColumn;
- var sub = args.Loader;
+ IMultiStreamSource fileSource = new MultiFileSource(file);
+
+ var loaderFactory = args.Loader;
// If the user manually specifies a loader, or this is already a pre-processed binary
// file, then we assume the user knows what they're doing and do not attempt to convert
// to the desired type ourselves.
bool autoConvert = false;
- if (!sub.IsGood())
+ IDataLoader loader;
+ if (loaderFactory != null)
+ {
+ loader = loaderFactory.CreateComponent(env, fileSource);
+ }
+ else
{
// Determine the default loader from the extension.
var ext = Path.GetExtension(file);
@@ -326,11 +333,11 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn),
"Must be specified");
if (isBinary)
- sub = new SubComponent("BinaryLoader");
+ loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
else
{
ch.Assert(isTranspose);
- sub = new SubComponent("TransposeLoader");
+ loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
}
}
else
@@ -341,7 +348,21 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
"{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
nameof(Arguments.TermsColumn), src);
}
- sub = new SubComponent("TextLoader", "sep=tab col=Term:TX:0");
+ loader = new TextLoader(env,
+ new TextLoader.Arguments()
+ {
+ Separator = "tab",
+ Column = new[]
+ {
+ new TextLoader.Column()
+ {
+ Name ="Term",
+ Type = DataKind.TX,
+ Source = new[] { new TextLoader.Range() { Min = 0 } }
+ }
+ }
+ },
+ fileSource);
src = "Term";
autoConvert = true;
}
@@ -349,8 +370,6 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, Argu
ch.AssertNonEmpty(src);
int colSrc;
- var loader = sub.CreateInstance(env, new MultiFileSource(file));
-
if (!loader.Schema.TryGetColumnIndex(src, out colSrc))
throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src);
var typeSrc = loader.Schema.GetColumnType(colSrc);
@@ -395,7 +414,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info
ch.AssertValue(trainingData);
if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) &&
- (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader.IsGood() ||
+ (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null ||
!string.IsNullOrWhiteSpace(args.TermsColumn)))
{
ch.Warning("Explicit term list specified. Data file arguments will be ignored");
diff --git a/src/Microsoft.ML.Data/Transforms/doc.xml b/src/Microsoft.ML.Data/Transforms/doc.xml
index a3d4ba9f5e..13f108a107 100644
--- a/src/Microsoft.ML.Data/Transforms/doc.xml
+++ b/src/Microsoft.ML.Data/Transforms/doc.xml
@@ -28,7 +28,7 @@
The TextToKeyConverter transform builds up term vocabularies (dictionaries).
- The TextToKey Converter and the are the two one primary mechanisms by which raw input is transformed into keys.
+ The TextToKeyConverter and the are the two primary mechanisms by which raw input is transformed into keys.
If multiple columns are used, each column builds/uses exactly one vocabulary.
The output columns are KeyType-valued.
The Key value is the one-based index of the item in the dictionary.
@@ -49,6 +49,52 @@
+
+
+
+ Handle missing values by replacing them with either the default value or the indicated value.
+
+
+ This transform handles missing values in the input columns. For each input column, it creates an output column
+ where the missing values are replaced by one of these specified values:
+
+ -
+ The default value of the appropriate type.
+
+ -
+ The mean value of the appropriate type.
+
+ -
+ The max value of the appropriate type.
+
+ -
+ The min value of the appropriate type.
+
+
+ The last three work only for numeric/TimeSpan/DateTime kind columns.
+
+ The output column can also optionally include an indicator vector for which slots were missing in the input column.
+ This can be done only when the indicator vector type can be converted to the input column type, i.e. only for numeric columns.
+
+
+ When computing the mean/max/min value, there is also an option to compute it over the whole column instead of per slot.
+ This option has a default value of true for variable length vectors, and false for known length vectors.
+ It can be changed to true for known length vectors, but it results in an error if changed to false for variable length vectors.
+
+
+
+
+
+
+
+
+ pipeline.Add(new MissingValueHandler("FeatureCol", "CleanFeatureCol")
+ {
+ ReplaceWith = NAHandleTransformReplacementKind.Mean
+ });
+
+
+
diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs
index fad40495f4..9f24b4bc09 100644
--- a/src/Microsoft.ML.FastTree/FastTree.cs
+++ b/src/Microsoft.ML.FastTree/FastTree.cs
@@ -95,7 +95,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args)
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
// Finally, even the binary classifiers, being logitboost, tend to not benefit from external calibration.
- Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration);
+ Info = new TrainerInfo(normalization: false, caching: false, calibration: NeedCalibration, supportValid: true);
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
{
diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
index 0e5a4c1862..f404f3ae95 100644
--- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
+++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
@@ -807,7 +807,7 @@ public static partial class TreeFeaturize
Desc = TreeEnsembleFeaturizerTransform.TreeEnsembleSummary,
UserName = TreeEnsembleFeaturizerTransform.UserName,
ShortName = TreeEnsembleFeaturizerBindableMapper.LoadNameShort,
- XmlInclude = new[] { @"" })]
+ XmlInclude = new[] { @"" })]
public static CommonOutputs.TransformOutput Featurizer(IHostEnvironment env, TreeEnsembleFeaturizerTransform.ArgumentsForEntryPoint input)
{
Contracts.CheckValue(env, nameof(env));
diff --git a/src/Microsoft.ML.FastTree/doc.xml b/src/Microsoft.ML.FastTree/doc.xml
index 8678654182..26d3c8c129 100644
--- a/src/Microsoft.ML.FastTree/doc.xml
+++ b/src/Microsoft.ML.FastTree/doc.xml
@@ -95,7 +95,7 @@
Generally, ensemble models provide better coverage and accuracy than single decision trees.
Each tree in a decision forest outputs a Gaussian distribution.
For more see:
-
+
- Wikipedia: Random forest
- Quantile regression forest
- From Stumps to Trees to Forests
@@ -146,7 +146,7 @@
Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector
to three outputs:
-
+
- A vector containing the individual tree outputs of the tree ensemble.
- A vector indicating the leaves that the feature vector falls on in the tree ensemble.
- A vector indicating the paths that the feature vector falls on in the tree ensemble.
@@ -157,28 +157,28 @@
In machine learning it is a pretty common and powerful approach to utilize the already trained model in the process of defining features.
- One such example would be the use of model's scores as features to downstream models. For example, we might run clustering on the original features,
+ One such example would be the use of model's scores as features to downstream models. For example, we might run clustering on the original features,
and use the cluster distances as the new feature set.
- Instead of consuming the model's output, we could go deeper, and extract the 'intermediate outputs' that are used to produce the final score.
+ Instead of consuming the model's output, we could go deeper, and extract the 'intermediate outputs' that are used to produce the final score.
There are a number of famous or popular examples of this technique:
-
- - A deep neural net trained on the ImageNet dataset, with the last layer removed, is commonly used to compute the 'projection' of the image into the 'semantic feature space'.
- It is observed that the Euclidean distance in this space often correlates with the 'semantic similarity': that is, all pictures of pizza are located close together,
+
+ - A deep neural net trained on the ImageNet dataset, with the last layer removed, is commonly used to compute the 'projection' of the image into the 'semantic feature space'.
+ It is observed that the Euclidean distance in this space often correlates with the 'semantic similarity': that is, all pictures of pizza are located close together,
and far away from pictures of kittens.
- - A matrix factorization and/or LDA model is also often used to extract the 'latent topics' or 'latent features' associated with users and items.
- - The weights of the linear model are often used as a crude indicator of 'feature importance'. At the very minimum, the 0-weight features are not needed by the model,
- and there's no reason to compute them.
+ - A matrix factorization and/or LDA model is also often used to extract the 'latent topics' or 'latent features' associated with users and items.
+ - The weights of the linear model are often used as a crude indicator of 'feature importance'. At the very minimum, the 0-weight features are not needed by the model,
+ and there's no reason to compute them.
Tree featurizer uses the decision tree ensembles for feature engineering in the same fashion as above.
- Let's assume that we've built a tree ensemble of 100 trees with 100 leaves each (it doesn't matter whether boosting was used or not in training).
+ Let's assume that we've built a tree ensemble of 100 trees with 100 leaves each (it doesn't matter whether boosting was used or not in training).
If we associate each leaf of each tree with a sequential integer, we can, for every incoming example x,
- produce an indicator vector L(x), where Li(x) = 1 if the example x 'falls' into the leaf #i, and 0 otherwise.
+ produce an indicator vector L(x), where Li(x) = 1 if the example x 'falls' into the leaf #i, and 0 otherwise.
Thus, for every example x, we produce a 10000-valued vector L, with exactly 100 1s and the rest zeroes.
- This 'leaf indicator' vector can be considered the ensemble-induced 'footprint' of the example.
- The 'distance' between two examples in the L-space is actually a Hamming distance, and is equal to the number of trees that do not distinguish the two examples.
+ This 'leaf indicator' vector can be considered the ensemble-induced 'footprint' of the example.
+ The 'distance' between two examples in the L-space is actually a Hamming distance, and is equal to the number of trees that do not distinguish the two examples.
We could repeat the same thought process for the non-leaf, or internal, nodes of the trees (we know that each tree has exactly 99 of them in our 100-leaf example),
- and produce another indicator vector, N (size 9900), for each example, indicating the 'trajectory' of each example through each of the trees.
- The distance in the combined 19900-dimensional LN-space will be equal to the number of 'decisions' in all trees that 'agree' on the given pair of examples.
+ and produce another indicator vector, N (size 9900), for each example, indicating the 'trajectory' of each example through each of the trees.
+ The distance in the combined 19900-dimensional LN-space will be equal to the number of 'decisions' in all trees that 'agree' on the given pair of examples.
The TreeLeafFeaturizer is also producing the third vector, T, which is defined as Ti(x) = output of tree #i on example x.
diff --git a/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj b/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj
new file mode 100644
index 0000000000..a5f3c4b748
--- /dev/null
+++ b/src/Microsoft.ML.HalLearners/Microsoft.ML.HalLearners.csproj
@@ -0,0 +1,16 @@
+
+
+
+ netstandard2.0
+ Microsoft.ML.HalLearners
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
similarity index 95%
rename from src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs
rename to src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
index 7f47271f68..dbf2999657 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs
+++ b/src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
@@ -7,14 +7,15 @@
using System;
using System.Collections.Generic;
using System.IO;
-using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.HalLearners;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.Model;
-using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Training;
using System.Runtime.InteropServices;
@@ -28,8 +29,11 @@
"OLS Linear Regression Executor",
OlsLinearRegressionPredictor.LoaderSignature)]
-namespace Microsoft.ML.Runtime.Learners
+[assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)]
+
+namespace Microsoft.ML.Runtime.HalLearners
{
+ ///
public sealed class OlsLinearRegressionTrainer : TrainerBase
{
public sealed class Arguments : LearnerInputBaseWithWeight
@@ -51,11 +55,6 @@ public sealed class Arguments : LearnerInputBaseWithWeight
public const string ShortName = "ols";
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
+ "that minimizes the square loss function.";
- internal const string Remarks = @"
-Ordinary least squares (OLS) is a parameterized regression method.
-It assumes that the conditional mean of the dependent variable follows a linear function of the dependent variables.
-By minimizing the squares of the difference between observed values and the predictions, the parameters of the regressor can be estimated.
-";
private readonly Float _l2Weight;
private readonly bool _perParameterSignificance;
@@ -222,7 +221,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
catch (DllNotFoundException)
{
// REVIEW: Is there no better way?
- throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing.");
+ throw ch.ExceptNotSupp("The MKL library (libMklImports) or one of its dependencies is missing.");
}
// Solve for beta in (LL')beta = X'y:
Mkl.Pptrs(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, m, 1, xtx, xty, 1);
@@ -331,7 +330,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
internal static class Mkl
{
- private const string DllName = "Microsoft.ML.MklImports.dll";
+ private const string DllName = "MklImports";
public enum Layout
{
@@ -463,6 +462,24 @@ public static void Pptri(Layout layout, UpLo uplo, int n, Double[] ap)
}
}
}
+
+ [TlcModule.EntryPoint(Name = "Trainers.OrdinaryLeastSquaresRegressor",
+ Desc = "Train an OLS regression model.",
+ UserName = UserNameValue,
+ ShortName = ShortName,
+ XmlInclude = new[] { @"" })]
+ public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, Arguments input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var host = env.Register("TrainOLS");
+ host.CheckValue(input, nameof(input));
+ EntryPointUtils.CheckInputArgs(host, input);
+
+ return LearnerEntryPointsUtils.Train(host, input,
+ () => new OlsLinearRegressionTrainer(host, input),
+ () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
+ () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
+ }
}
///
diff --git a/src/Microsoft.ML.HalLearners/doc.xml b/src/Microsoft.ML.HalLearners/doc.xml
new file mode 100644
index 0000000000..d7ec04bb89
--- /dev/null
+++ b/src/Microsoft.ML.HalLearners/doc.xml
@@ -0,0 +1,27 @@
+
+
+
+
+
+
+ Train an OLS regression model.
+
+
+ Ordinary least squares (OLS) is a parameterized regression method.
+ It assumes that the conditional mean of the dependent variable follows a linear function of the dependent variables.
+ The parameters of the regressor can be estimated by minimizing the squares of the difference between observed values and the predictions.
+
+
+
+ new OrdinaryLeastSquaresRegressor
+ {
+ L2Weight = 0.1,
+ PerParameterSignificance = false,
+ NormalizeFeatures = Microsoft.ML.Models.NormalizeOption.Yes
+ }
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
new file mode 100644
index 0000000000..97c613485f
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/EntryPoints/ImageAnalytics.cs
@@ -0,0 +1,79 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
+
+[assembly: LoadableClass(typeof(void), typeof(ImageAnalytics), null, typeof(SignatureEntryPointModule), "ImageAnalytics")]
+namespace Microsoft.ML.Runtime.ImageAnalytics.EntryPoints
+{
+ public static class ImageAnalytics
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.ImageLoader", Desc = ImageLoaderTransform.Summary,
+ UserName = ImageLoaderTransform.UserName, ShortName = ImageLoaderTransform.LoaderSignature)]
+ public static CommonOutputs.TransformOutput ImageLoader(IHostEnvironment env, ImageLoaderTransform.Arguments input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageLoaderTransform", input);
+ var xf = new ImageLoaderTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+
+ [TlcModule.EntryPoint(Name = "Transforms.ImageResizer", Desc = ImageResizerTransform.Summary,
+ UserName = ImageResizerTransform.UserName, ShortName = ImageResizerTransform.LoaderSignature)]
+ public static CommonOutputs.TransformOutput ImageResizer(IHostEnvironment env, ImageResizerTransform.Arguments input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageResizerTransform", input);
+ var xf = new ImageResizerTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+
+ [TlcModule.EntryPoint(Name = "Transforms.ImagePixelExtractor", Desc = ImagePixelExtractorTransform.Summary,
+ UserName = ImagePixelExtractorTransform.UserName, ShortName = ImagePixelExtractorTransform.LoaderSignature)]
+ public static CommonOutputs.TransformOutput ImagePixelExtractor(IHostEnvironment env, ImagePixelExtractorTransform.Arguments input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImagePixelExtractorTransform", input);
+ var xf = new ImagePixelExtractorTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+
+ [TlcModule.EntryPoint(Name = "Transforms.ImageGrayscale", Desc = ImageGrayscaleTransform.Summary,
+ UserName = ImageGrayscaleTransform.UserName, ShortName = ImageGrayscaleTransform.LoaderSignature)]
+ public static CommonOutputs.TransformOutput ImageGrayscale(IHostEnvironment env, ImageGrayscaleTransform.Arguments input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "ImageGrayscaleTransform", input);
+ var xf = new ImageGrayscaleTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+
+ [TlcModule.EntryPoint(Name = "Transforms.VectorToImage", Desc = VectorToImageTransform.Summary,
+ UserName = VectorToImageTransform.UserName, ShortName = VectorToImageTransform.LoaderSignature)]
+ public static CommonOutputs.TransformOutput VectorToImage(IHostEnvironment env, VectorToImageTransform.Arguments input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "VectorToImageTransform", input);
+ var xf = new VectorToImageTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
new file mode 100644
index 0000000000..7a267cf1b8
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs
@@ -0,0 +1,171 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Drawing;
+using System.Drawing.Imaging;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.ImageAnalytics;
+
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), typeof(ImageGrayscaleTransform.Arguments), typeof(SignatureDataTransform),
+ ImageGrayscaleTransform.UserName, "ImageGrayscaleTransform", "ImageGrayscale")]
+
+[assembly: LoadableClass(ImageGrayscaleTransform.Summary, typeof(ImageGrayscaleTransform), null, typeof(SignatureLoadDataTransform),
+ ImageGrayscaleTransform.UserName, ImageGrayscaleTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+ // REVIEW: Rewrite as LambdaTransform to simplify.
+ // REVIEW: Should it be separate transform or part of ImageResizerTransform?
+ ///
+ /// Transform which takes one or many columns of type in IDataView and
+ /// converts them to a grayscale representation of the same image.
+ ///
+ public sealed class ImageGrayscaleTransform : OneToOneTransformBase
+ {
+ public sealed class Column : OneToOneColumn
+ {
+ public static Column Parse(string str)
+ {
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+ }
+
+ internal const string Summary = "Convert image into grayscale.";
+
+ internal const string UserName = "Image Greyscale Transform";
+ public const string LoaderSignature = "ImageGrayscaleTransform";
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "IMGGRAYT",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ private const string RegistrationName = "ImageGrayscale";
+
+ // Public constructor corresponding to SignatureDataTransform.
+ public ImageGrayscaleTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertNonEmpty(Infos);
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+ Metadata.Seal();
+ }
+
+ private ImageGrayscaleTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertValue(ctx);
+ // *** Binary format ***
+ //
+ Host.AssertNonEmpty(Infos);
+ Metadata.Seal();
+ }
+
+ public static ImageGrayscaleTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+ return h.Apply("Loading Model", ch => new ImageGrayscaleTransform(h, ctx, input));
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ //
+ SaveBase(ctx);
+ }
+
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Assert(0 <= iinfo & iinfo < Infos.Length);
+ return Infos[iinfo].TypeSrc;
+ }
+
+ private static readonly ColorMatrix _grayscaleColorMatrix = new ColorMatrix(
+ new float[][]
+ {
+ new float[] {.3f, .3f, .3f, 0, 0},
+ new float[] {.59f, .59f, .59f, 0, 0},
+ new float[] {.11f, .11f, .11f, 0, 0},
+ new float[] {0, 0, 0, 1, 0},
+ new float[] {0, 0, 0, 0, 1}
+ });
+
+ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ {
+ Host.AssertValueOrNull(ch);
+ Host.AssertValue(input);
+ Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+
+ var src = default(Bitmap);
+ var getSrc = GetSrcGetter(input, iinfo);
+
+ disposer =
+ () =>
+ {
+ if (src != null)
+ {
+ src.Dispose();
+ src = null;
+ }
+ };
+
+ ValueGetter del =
+ (ref Bitmap dst) =>
+ {
+ if (dst != null)
+ dst.Dispose();
+
+ getSrc(ref src);
+ if (src == null || src.Height <= 0 || src.Width <= 0)
+ return;
+
+ dst = new Bitmap(src.Width, src.Height);
+ ImageAttributes attributes = new ImageAttributes();
+ attributes.SetColorMatrix(_grayscaleColorMatrix);
+ var srcRectangle = new Rectangle(0, 0, src.Width, src.Height);
+ using (var g = Graphics.FromImage(dst))
+ {
+ g.DrawImage(src, srcRectangle, 0, 0, src.Width, src.Height, GraphicsUnit.Pixel, attributes);
+ }
+ Host.Assert(dst.Width == src.Width && dst.Height == src.Height);
+ };
+
+ return del;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
new file mode 100644
index 0000000000..488c710743
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs
@@ -0,0 +1,178 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Drawing;
+using System.IO;
+using System.Text;
+using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), typeof(ImageLoaderTransform.Arguments), typeof(SignatureDataTransform),
+ ImageLoaderTransform.UserName, "ImageLoaderTransform", "ImageLoader")]
+
+[assembly: LoadableClass(ImageLoaderTransform.Summary, typeof(ImageLoaderTransform), null, typeof(SignatureLoadDataTransform),
+ ImageLoaderTransform.UserName, ImageLoaderTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+ // REVIEW: Rewrite as LambdaTransform to simplify.
+ ///
+ /// Transform which takes one or many columns of type and loads them as
+ ///
+ public sealed class ImageLoaderTransform : OneToOneTransformBase
+ {
+ public sealed class Column : OneToOneColumn
+ {
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public sealed class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
+ ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Folder where to search for images", ShortName = "folder")]
+ public string ImageFolder;
+ }
+
+ internal const string Summary = "Load images from files.";
+ internal const string UserName = "Image Loader Transform";
+ public const string LoaderSignature = "ImageLoaderTransform";
+
+        // Model-file versioning; 0x00010002 marks the move of the image backend
+        // from OpenCV to System.Drawing.Bitmap.
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "IMGLOADT",
+                //verWrittenCur: 0x00010001, // Initial
+                verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap
+                verReadableCur: 0x00010002,
+                verWeCanReadBack: 0x00010002,
+                loaderSignature: LoaderSignature);
+        }
+
+ private readonly ImageType _type;
+ private readonly string _imageFolder;
+
+ private const string RegistrationName = "ImageLoader";
+
+ // Public constructor corresponding to SignatureDataTransform.
+ public ImageLoaderTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, TestIsText)
+ {
+ Host.AssertNonEmpty(Infos);
+ _imageFolder = args.ImageFolder;
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+ _type = new ImageType();
+ Metadata.Seal();
+ }
+
+ private ImageLoaderTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, TestIsText)
+ {
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ //
+ _imageFolder = ctx.Reader.ReadString();
+ _type = new ImageType();
+ Metadata.Seal();
+ }
+
+ public static ImageLoaderTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+ return h.Apply("Loading Model", ch => new ImageLoaderTransform(h, ctx, input));
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ //
+ ctx.Writer.Write(_imageFolder);
+ SaveBase(ctx);
+ }
+
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Check(0 <= iinfo && iinfo < Infos.Length);
+ return _type;
+ }
+
+        /// <summary>
+        /// Produces a getter that loads a <see cref="Bitmap"/> from the path in the
+        /// source text column, optionally rooted at the configured image folder.
+        /// Load failures are logged to the channel and surface as a null bitmap.
+        /// </summary>
+        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+        {
+            // ch is dereferenced in the catch handler below, so it must be non-null.
+            Host.AssertValue(ch, nameof(ch));
+            Host.AssertValue(input);
+            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+            disposer = null;
+
+            // Generic arguments restored; they were dropped in the patch text.
+            var getSrc = GetSrcGetter<DvText>(input, iinfo);
+            DvText src = default;
+            ValueGetter<Bitmap> del =
+                (ref Bitmap dst) =>
+                {
+                    if (dst != null)
+                    {
+                        dst.Dispose();
+                        dst = null;
+                    }
+
+                    getSrc(ref src);
+
+                    // An empty path yields a null bitmap (dst was cleared above).
+                    if (src.Length > 0)
+                    {
+                        try
+                        {
+                            string path = src.ToString();
+                            if (!string.IsNullOrWhiteSpace(_imageFolder))
+                                path = Path.Combine(_imageFolder, path);
+                            dst = new Bitmap(path);
+                        }
+                        catch (Exception e)
+                        {
+                            // REVIEW: We catch everything since the documentation for new Bitmap(string)
+                            // appears to be incorrect. When the file isn't found, it throws an ArgumentException,
+                            // while the documentation says FileNotFoundException. Not sure what it will throw
+                            // in other cases, like corrupted file, etc.
+                            ch.Info(e.Message);
+                            ch.Info(e.StackTrace);
+                            dst = null;
+                        }
+                    }
+                };
+            return del;
+        }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
new file mode 100644
index 0000000000..de0aa98124
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs
@@ -0,0 +1,541 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Drawing;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), typeof(ImagePixelExtractorTransform.Arguments), typeof(SignatureDataTransform),
+ ImagePixelExtractorTransform.UserName, "ImagePixelExtractorTransform", "ImagePixelExtractor")]
+
+[assembly: LoadableClass(ImagePixelExtractorTransform.Summary, typeof(ImagePixelExtractorTransform), null, typeof(SignatureLoadDataTransform),
+ ImagePixelExtractorTransform.UserName, ImagePixelExtractorTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+ // REVIEW: Rewrite as LambdaTransform to simplify.
+ ///
+ /// Transform which takes one or many columns of and convert them into vector representation.
+ ///
+ public sealed class ImagePixelExtractorTransform : OneToOneTransformBase
+ {
+ public class Column : OneToOneColumn
+ {
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use alpha channel", ShortName = "alpha")]
+ public bool? UseAlpha;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use red channel", ShortName = "red")]
+ public bool? UseRed;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use green channel", ShortName = "green")]
+ public bool? UseGreen;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use blue channel", ShortName = "blue")]
+ public bool? UseBlue;
+
+ // REVIEW: Consider turning this into an enum that allows for pixel, line, or planar interleaving.
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to separate each channel or interleave in ARGB order", ShortName = "interleave")]
+ public bool? InterleaveArgb;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to convert to floating point", ShortName = "conv")]
+ public bool? Convert;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Offset (pre-scale)")]
+ public Single? Offset;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Scale factor")]
+ public Single? Scale;
+
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ if (UseAlpha != null || UseRed != null || UseGreen != null || UseBlue != null || Convert != null ||
+ Offset != null || Scale != null || InterleaveArgb != null)
+ {
+ return false;
+ }
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use alpha channel", ShortName = "alpha")]
+ public bool UseAlpha = false;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use red channel", ShortName = "red")]
+ public bool UseRed = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use green channel", ShortName = "green")]
+ public bool UseGreen = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use blue channel", ShortName = "blue")]
+ public bool UseBlue = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to separate each channel or interleave in ARGB order", ShortName = "interleave")]
+ public bool InterleaveArgb = false;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to convert to floating point", ShortName = "conv")]
+ public bool Convert = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Offset (pre-scale)")]
+ public Single? Offset;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Scale factor")]
+ public Single? Scale;
+ }
+
+ ///
+ /// Which color channels are extracted. Note that these values are serialized so should not be modified.
+ ///
+ [Flags]
+ private enum ColorBits : byte
+ {
+ Alpha = 0x01,
+ Red = 0x02,
+ Green = 0x04,
+ Blue = 0x08,
+
+ All = Alpha | Red | Green | Blue
+ }
+
+ private sealed class ColInfoEx
+ {
+ public readonly ColorBits Colors;
+ public readonly byte Planes;
+
+ public readonly bool Convert;
+ public readonly Single Offset;
+ public readonly Single Scale;
+ public readonly bool Interleave;
+
+ public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } }
+ public bool Red { get { return (Colors & ColorBits.Red) != 0; } }
+ public bool Green { get { return (Colors & ColorBits.Green) != 0; } }
+ public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } }
+
+ public ColInfoEx(Column item, Arguments args)
+ {
+ if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; }
+ if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; }
+ if (item.UseGreen ?? args.UseGreen) { Colors |= ColorBits.Green; Planes++; }
+ if (item.UseBlue ?? args.UseBlue) { Colors |= ColorBits.Blue; Planes++; }
+ Contracts.CheckUserArg(Planes > 0, nameof(item.UseRed), "Need to use at least one color plane");
+
+ Interleave = item.InterleaveArgb ?? args.InterleaveArgb;
+
+ Convert = item.Convert ?? args.Convert;
+ if (!Convert)
+ {
+ Offset = 0;
+ Scale = 1;
+ }
+ else
+ {
+ Offset = item.Offset ?? args.Offset ?? 0;
+ Scale = item.Scale ?? args.Scale ?? 1;
+ Contracts.CheckUserArg(FloatUtils.IsFinite(Offset), nameof(item.Offset));
+ Contracts.CheckUserArg(FloatUtils.IsFiniteNonZero(Scale), nameof(item.Scale));
+ }
+ }
+
+ public ColInfoEx(ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+ // *** Binary format ***
+ // byte: colors
+ // byte: convert
+ // Float: offset
+ // Float: scale
+ // byte: separateChannels
+ Colors = (ColorBits)ctx.Reader.ReadByte();
+ Contracts.CheckDecode(Colors != 0);
+ Contracts.CheckDecode((Colors & ColorBits.All) == Colors);
+
+ // Count the planes.
+ int planes = (int)Colors;
+ planes = (planes & 0x05) + ((planes >> 1) & 0x05);
+ planes = (planes & 0x03) + ((planes >> 2) & 0x03);
+ Planes = (byte)planes;
+ Contracts.Assert(0 < Planes & Planes <= 4);
+
+ Convert = ctx.Reader.ReadBoolByte();
+ Offset = ctx.Reader.ReadFloat();
+ Contracts.CheckDecode(FloatUtils.IsFinite(Offset));
+ Scale = ctx.Reader.ReadFloat();
+ Contracts.CheckDecode(FloatUtils.IsFiniteNonZero(Scale));
+ Contracts.CheckDecode(Convert || Offset == 0 && Scale == 1);
+ Interleave = ctx.Reader.ReadBoolByte();
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+#if DEBUG
+ // This code is used in deserialization - assert that it matches what we computed above.
+ int planes = (int)Colors;
+ planes = (planes & 0x05) + ((planes >> 1) & 0x05);
+ planes = (planes & 0x03) + ((planes >> 2) & 0x03);
+ Contracts.Assert(planes == Planes);
+#endif
+
+ // *** Binary format ***
+ // byte: colors
+ // byte: convert
+ // Float: offset
+ // Float: scale
+ // byte: separateChannels
+ Contracts.Assert(Colors != 0);
+ Contracts.Assert((Colors & ColorBits.All) == Colors);
+ ctx.Writer.Write((byte)Colors);
+ ctx.Writer.WriteBoolByte(Convert);
+ Contracts.Assert(FloatUtils.IsFinite(Offset));
+ ctx.Writer.Write(Offset);
+ Contracts.Assert(FloatUtils.IsFiniteNonZero(Scale));
+ Contracts.Assert(Convert || Offset == 0 && Scale == 1);
+ ctx.Writer.Write(Scale);
+ ctx.Writer.WriteBoolByte(Interleave);
+ }
+ }
+
+ internal const string Summary = "Extract color plane(s) from an image. Options include scaling, offset and conversion to floating point.";
+ internal const string UserName = "Image Pixel Extractor Transform";
+ public const string LoaderSignature = "ImagePixelExtractor";
+        // Model-file versioning; 0x00010002 marks the move of the image backend
+        // from OpenCV to System.Drawing.Bitmap.
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "IMGPXEXT",
+                //verWrittenCur: 0x00010001, // Initial
+                verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap
+                verReadableCur: 0x00010002,
+                verWeCanReadBack: 0x00010002,
+                loaderSignature: LoaderSignature);
+        }
+
+ private const string RegistrationName = "ImagePixelExtractor";
+
+ private readonly ColInfoEx[] _exes;
+ private readonly VectorType[] _types;
+
+ // Public constructor corresponding to SignatureDataTransform.
+ public ImagePixelExtractorTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input,
+ t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertNonEmpty(Infos);
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+
+ _exes = new ColInfoEx[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ var item = args.Column[i];
+ _exes[i] = new ColInfoEx(item, args);
+ }
+
+ _types = ConstructTypes(true);
+ }
+
+ private ImagePixelExtractorTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ //
+ //
+ // foreach added column
+ // ColInfoEx
+ Host.AssertNonEmpty(Infos);
+ _exes = new ColInfoEx[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ _exes[i] = new ColInfoEx(ctx);
+
+ _types = ConstructTypes(false);
+ }
+
+ public static ImagePixelExtractorTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return h.Apply("Loading Model",
+ ch =>
+ {
+ // *** Binary format ***
+ // int: sizeof(Float)
+ //
+ int cbFloat = ctx.Reader.ReadInt32();
+ ch.CheckDecode(cbFloat == sizeof(Single));
+ return new ImagePixelExtractorTransform(h, ctx, input);
+ });
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int: sizeof(Float)
+ //
+ // foreach added column
+ // ColInfoEx
+ ctx.Writer.Write(sizeof(Single));
+ SaveBase(ctx);
+
+ Host.Assert(_exes.Length == Infos.Length);
+ for (int i = 0; i < _exes.Length; i++)
+ _exes[i].Save(ctx);
+ }
+
+ private VectorType[] ConstructTypes(bool user)
+ {
+ var types = new VectorType[Infos.Length];
+ for (int i = 0; i < Infos.Length; i++)
+ {
+ var info = Infos[i];
+ var ex = _exes[i];
+ Host.Assert(ex.Planes > 0);
+
+ var type = Source.Schema.GetColumnType(info.Source) as ImageType;
+ Host.Assert(type != null);
+ if (type.Height <= 0 || type.Width <= 0)
+ {
+ // REVIEW: Could support this case by making the destination column be variable sized.
+ // However, there's no mechanism to communicate the dimensions through with the pixel data.
+ string name = Source.Schema.GetColumnName(info.Source);
+ throw user ?
+ Host.ExceptUserArg(nameof(Arguments.Column), "Column '{0}' does not have known size", name) :
+ Host.Except("Column '{0}' does not have known size", name);
+ }
+ int height = type.Height;
+ int width = type.Width;
+ Host.Assert(height > 0);
+ Host.Assert(width > 0);
+ Host.Assert((long)height * width <= int.MaxValue / 4);
+
+ if (ex.Interleave)
+ types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, height, width, ex.Planes);
+ else
+ types[i] = new VectorType(ex.Convert ? NumberType.Float : NumberType.U1, ex.Planes, height, width);
+ }
+ Metadata.Seal();
+ return types;
+ }
+
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Assert(0 <= iinfo & iinfo < Infos.Length);
+ return _types[iinfo];
+ }
+
+        /// <summary>
+        /// Dispatches to the generic pixel-extraction getter: Single output when
+        /// float conversion/scaling was requested, raw byte output otherwise.
+        /// </summary>
+        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+        {
+            Host.AssertValueOrNull(ch);
+            Host.AssertValue(input);
+            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+
+            // Generic arguments restored; they were dropped in the patch text, which
+            // left both branches identical (and non-compiling).
+            if (_exes[iinfo].Convert)
+                return GetGetterCore<Single>(input, iinfo, out disposer);
+            return GetGetterCore<byte>(input, iinfo, out disposer);
+        }
+
+ //REVIEW Rewrite it to where TValue : IConvertible
+        /// <summary>
+        /// Core getter: copies the pixels of the source bitmap into a dense vector,
+        /// either interleaved (pixel-major, ARGB channel order) or planar
+        /// (channel-major). TValue is Single when offset/scale conversion is
+        /// requested, byte for raw channel values.
+        /// </summary>
+        private ValueGetter<VBuffer<TValue>> GetGetterCore<TValue>(IRow input, int iinfo, out Action disposer)
+        {
+            var type = _types[iinfo];
+            Host.Assert(type.DimCount == 3);
+
+            var ex = _exes[iinfo];
+
+            // Dimension order is (h, w, planes) when interleaved, (planes, h, w) when planar.
+            int planes = ex.Interleave ? type.GetDim(2) : type.GetDim(0);
+            int height = ex.Interleave ? type.GetDim(0) : type.GetDim(1);
+            int width = ex.Interleave ? type.GetDim(1) : type.GetDim(2);
+
+            int size = type.ValueCount;
+            Host.Assert(size > 0);
+            Host.Assert(size == planes * height * width);
+            int cpix = height * width;
+
+            var getSrc = GetSrcGetter<Bitmap>(input, iinfo);
+            var src = default(Bitmap);
+
+            // Dispose the cached source bitmap when the cursor is done with this getter.
+            disposer =
+                () =>
+                {
+                    if (src != null)
+                    {
+                        src.Dispose();
+                        src = null;
+                    }
+                };
+
+            return
+                (ref VBuffer<TValue> dst) =>
+                {
+                    getSrc(ref src);
+                    Contracts.AssertValueOrNull(src);
+
+                    if (src == null)
+                    {
+                        // Missing image maps to an empty vector of the right logical length.
+                        dst = new VBuffer<TValue>(size, 0, dst.Values, dst.Indices);
+                        return;
+                    }
+
+                    Host.Check(src.PixelFormat == System.Drawing.Imaging.PixelFormat.Format32bppArgb);
+                    Host.Check(src.Height == height && src.Width == width);
+
+                    var values = dst.Values;
+                    if (Utils.Size(values) < size)
+                        values = new TValue[size];
+
+                    Single offset = ex.Offset;
+                    Single scale = ex.Scale;
+                    Host.Assert(scale != 0);
+
+                    // Exactly one of these is non-null, depending on TValue.
+                    var vf = values as Single[];
+                    var vb = values as byte[];
+                    Host.Assert(vf != null || vb != null);
+                    bool needScale = offset != 0 || scale != 1;
+                    Host.Assert(!needScale || vf != null);
+
+                    bool a = ex.Alpha;
+                    bool r = ex.Red;
+                    bool g = ex.Green;
+                    bool b = ex.Blue;
+
+                    int h = height;
+                    int w = width;
+
+                    if (ex.Interleave)
+                    {
+                        int idst = 0;
+                        for (int y = 0; y < h; ++y)
+                            for (int x = 0; x < w; x++)
+                            {
+                                // BUGFIX: Bitmap.GetPixel takes (x, y); the arguments were
+                                // swapped, transposing square images and throwing an
+                                // ArgumentOutOfRangeException on non-square ones. The planar
+                                // branch below already used the correct order.
+                                var pb = src.GetPixel(x, y);
+                                if (vb != null)
+                                {
+                                    // NOTE(review): alpha is faked as 0 here but as 0xFF in
+                                    // the planar branch - confirm which is intended.
+                                    if (a) { vb[idst++] = (byte)0; }
+                                    if (r) { vb[idst++] = pb.R; }
+                                    if (g) { vb[idst++] = pb.G; }
+                                    if (b) { vb[idst++] = pb.B; }
+                                }
+                                else if (!needScale)
+                                {
+                                    if (a) { vf[idst++] = 0.0f; }
+                                    if (r) { vf[idst++] = pb.R; }
+                                    if (g) { vf[idst++] = pb.G; }
+                                    if (b) { vf[idst++] = pb.B; }
+                                }
+                                else
+                                {
+                                    // BUGFIX: green and blue were swapped here (pb.B was
+                                    // written for 'g' and pb.G for 'b'), corrupting scaled
+                                    // interleaved output.
+                                    if (a) { vf[idst++] = 0.0f; }
+                                    if (r) { vf[idst++] = (pb.R - offset) * scale; }
+                                    if (g) { vf[idst++] = (pb.G - offset) * scale; }
+                                    if (b) { vf[idst++] = (pb.B - offset) * scale; }
+                                }
+                            }
+                        Host.Assert(idst == size);
+                    }
+                    else
+                    {
+                        int idstMin = 0;
+                        if (ex.Alpha)
+                        {
+                            // The image only has rgb but we need to supply alpha as well, so fake it up,
+                            // assuming that it is 0xFF.
+                            if (vf != null)
+                            {
+                                Single v = (0xFF - offset) * scale;
+                                for (int i = 0; i < cpix; i++)
+                                    vf[i] = v;
+                            }
+                            else
+                            {
+                                for (int i = 0; i < cpix; i++)
+                                    vb[i] = 0xFF;
+                            }
+                            idstMin = cpix;
+
+                            // We've preprocessed alpha, avoid it in the
+                            // scan operation below.
+                            a = false;
+                        }
+
+                        for (int y = 0; y < h; ++y)
+                        {
+                            int idstBase = idstMin + y * w;
+
+                            // Note that the bytes are in order BGR[A]. We arrange the layers in order ARGB.
+                            if (vb != null)
+                            {
+                                for (int x = 0; x < w; x++, idstBase++)
+                                {
+                                    var pb = src.GetPixel(x, y);
+                                    int idst = idstBase;
+                                    if (a) { vb[idst] = pb.A; idst += cpix; }
+                                    if (r) { vb[idst] = pb.R; idst += cpix; }
+                                    if (g) { vb[idst] = pb.G; idst += cpix; }
+                                    if (b) { vb[idst] = pb.B; idst += cpix; }
+                                }
+                            }
+                            else if (!needScale)
+                            {
+                                for (int x = 0; x < w; x++, idstBase++)
+                                {
+                                    var pb = src.GetPixel(x, y);
+                                    int idst = idstBase;
+                                    if (a) { vf[idst] = pb.A; idst += cpix; }
+                                    if (r) { vf[idst] = pb.R; idst += cpix; }
+                                    if (g) { vf[idst] = pb.G; idst += cpix; }
+                                    if (b) { vf[idst] = pb.B; idst += cpix; }
+                                }
+                            }
+                            else
+                            {
+                                for (int x = 0; x < w; x++, idstBase++)
+                                {
+                                    var pb = src.GetPixel(x, y);
+                                    int idst = idstBase;
+                                    if (a) { vf[idst] = (pb.A - offset) * scale; idst += cpix; }
+                                    if (r) { vf[idst] = (pb.R - offset) * scale; idst += cpix; }
+                                    if (g) { vf[idst] = (pb.G - offset) * scale; idst += cpix; }
+                                    if (b) { vf[idst] = (pb.B - offset) * scale; idst += cpix; }
+                                }
+                            }
+                        }
+                    }
+
+                    dst = new VBuffer<TValue>(size, values, dst.Indices);
+                };
+        }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
new file mode 100644
index 0000000000..dd1abc9181
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs
@@ -0,0 +1,370 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Drawing;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), typeof(ImageResizerTransform.Arguments),
+ typeof(SignatureDataTransform), ImageResizerTransform.UserName, "ImageResizerTransform", "ImageResizer")]
+
+[assembly: LoadableClass(ImageResizerTransform.Summary, typeof(ImageResizerTransform), null, typeof(SignatureLoadDataTransform),
+ ImageResizerTransform.UserName, ImageResizerTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+ // REVIEW: Rewrite as LambdaTransform to simplify.
+ ///
+ /// Transform which takes one or many columns of and resize them to provided height and width.
+ ///
+ public sealed class ImageResizerTransform : OneToOneTransformBase
+ {
+ public enum ResizingKind : byte
+ {
+ [TGUI(Label = "Isotropic with Padding")]
+ IsoPad = 0,
+
+ [TGUI(Label = "Isotropic with Cropping")]
+ IsoCrop = 1
+ }
+
+ public enum Anchor : byte
+ {
+ Right = 0,
+ Left = 1,
+ Top = 2,
+ Bottom = 3,
+ Center = 4
+ }
+
+ public sealed class Column : OneToOneColumn
+ {
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Width of the resized image", ShortName = "width")]
+ public int? ImageWidth;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Height of the resized image", ShortName = "height")]
+ public int? ImageHeight;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Resizing method", ShortName = "scale")]
+ public ResizingKind? Resizing;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Anchor for cropping", ShortName = "anchor")]
+ public Anchor? CropAnchor;
+
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ if (ImageWidth != null || ImageHeight != null || Resizing != null || CropAnchor != null)
+ return false;
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.Required, HelpText = "Resized width of the image", ShortName = "width")]
+ public int ImageWidth;
+
+ [Argument(ArgumentType.Required, HelpText = "Resized height of the image", ShortName = "height")]
+ public int ImageHeight;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Resizing method", ShortName = "scale")]
+ public ResizingKind Resizing = ResizingKind.IsoCrop;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Anchor for cropping", ShortName = "anchor")]
+ public Anchor CropAnchor = Anchor.Center;
+ }
+
+ ///
+ /// Extra information for each column (in addition to ColumnInfo).
+ ///
+ private sealed class ColInfoEx
+ {
+ public readonly int Width;
+ public readonly int Height;
+ public readonly ResizingKind Scale;
+ public readonly Anchor Anchor;
+ public readonly ColumnType Type;
+
+ public ColInfoEx(int width, int height, ResizingKind scale, Anchor anchor)
+ {
+ Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth));
+ Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight));
+ Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), scale), nameof(Column.Resizing));
+ Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor));
+
+ Width = width;
+ Height = height;
+ Scale = scale;
+ Anchor = anchor;
+ Type = new ImageType(Height, Width);
+ }
+ }
+
+ internal const string Summary = "Scales an image to specified dimensions using one of the three scale types: isotropic with padding, "
+ + "isotropic with cropping or anisotropic. In case of isotropic padding, transparent color is used to pad resulting image.";
+
+ internal const string UserName = "Image Resizer Transform";
+ public const string LoaderSignature = "ImageScalerTransform";
+        // Model-file versioning; 0x00010002 marks the move of the image backend
+        // from OpenCV to System.Drawing.Bitmap.
+        private static VersionInfo GetVersionInfo()
+        {
+            return new VersionInfo(
+                modelSignature: "IMGSCALF",
+                //verWrittenCur: 0x00010001, // Initial
+                verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap
+                verReadableCur: 0x00010002,
+                verWeCanReadBack: 0x00010002,
+                loaderSignature: LoaderSignature);
+        }
+
+ private const string RegistrationName = "ImageScaler";
+
+ // This is parallel to Infos.
+ private readonly ColInfoEx[] _exes;
+
+ // Public constructor corresponding to SignatureDataTransform.
+ public ImageResizerTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertNonEmpty(Infos);
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+
+ _exes = new ColInfoEx[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ var item = args.Column[i];
+ _exes[i] = new ColInfoEx(
+ item.ImageWidth ?? args.ImageWidth,
+ item.ImageHeight ?? args.ImageHeight,
+ item.Resizing ?? args.Resizing,
+ item.CropAnchor ?? args.CropAnchor);
+ }
+ Metadata.Seal();
+ }
+
+ private ImageResizerTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, t => t is ImageType ? null : "Expected Image type")
+ {
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ //
+ //
+ // for each added column
+ // int: width
+ // int: height
+ // byte: scaling kind
+ Host.AssertNonEmpty(Infos);
+
+ _exes = new ColInfoEx[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ int width = ctx.Reader.ReadInt32();
+ Host.CheckDecode(width > 0);
+ int height = ctx.Reader.ReadInt32();
+ Host.CheckDecode(height > 0);
+ var scale = (ResizingKind)ctx.Reader.ReadByte();
+ Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale));
+ var anchor = (Anchor)ctx.Reader.ReadByte();
+ Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor));
+ _exes[i] = new ColInfoEx(width, height, scale, anchor);
+ }
+ Metadata.Seal();
+ }
+
+ public static ImageResizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+ return h.Apply("Loading Model",
+ ch =>
+ {
+ // *** Binary format ***
+ // int: sizeof(Float)
+ //
+ int cbFloat = ctx.Reader.ReadInt32();
+ ch.CheckDecode(cbFloat == sizeof(Single));
+ return new ImageResizerTransform(h, ctx, input);
+ });
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // int: sizeof(Float)
+ //
+ // for each added column
+ // int: width
+ // int: height
+ // byte: scaling kind
+ ctx.Writer.Write(sizeof(Single));
+ SaveBase(ctx);
+
+ Host.Assert(_exes.Length == Infos.Length);
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ var ex = _exes[i];
+ ctx.Writer.Write(ex.Width);
+ ctx.Writer.Write(ex.Height);
+ Host.Assert((ResizingKind)(byte)ex.Scale == ex.Scale);
+ ctx.Writer.Write((byte)ex.Scale);
+ Host.Assert((Anchor)(byte)ex.Anchor == ex.Anchor);
+ ctx.Writer.Write((byte)ex.Anchor);
+ }
+ }
+
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Check(0 <= iinfo && iinfo < Infos.Length);
+ return _exes[iinfo].Type;
+ }
+
+        /// <summary>
+        /// Produces a getter that resizes the source bitmap to the configured
+        /// dimensions, either isotropically with transparent padding (IsoPad) or
+        /// isotropically with cropping anchored per <see cref="Anchor"/> (IsoCrop).
+        /// </summary>
+        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+        {
+            Host.AssertValueOrNull(ch);
+            Host.AssertValue(input);
+            Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+
+            var src = default(Bitmap);
+            // Generic argument restored; it was dropped in the patch text.
+            var getSrc = GetSrcGetter<Bitmap>(input, iinfo);
+            var ex = _exes[iinfo];
+
+            // Dispose the cached source bitmap when the cursor is done with this getter.
+            disposer =
+                () =>
+                {
+                    if (src != null)
+                    {
+                        src.Dispose();
+                        src = null;
+                    }
+                };
+
+            ValueGetter<Bitmap> del =
+                (ref Bitmap dst) =>
+                {
+                    if (dst != null)
+                    {
+                        dst.Dispose();
+                        // BUGFIX: clear the reference so a missing/empty source doesn't
+                        // leave the caller holding a disposed bitmap.
+                        dst = null;
+                    }
+
+                    getSrc(ref src);
+                    if (src == null || src.Height <= 0 || src.Width <= 0)
+                        return;
+                    if (src.Height == ex.Height && src.Width == ex.Width)
+                    {
+                        // BUGFIX: was 'dst = src', aliasing the cached source bitmap;
+                        // the disposer (or the dst.Dispose() on the next invocation)
+                        // would then dispose the object still held by the consumer.
+                        dst = (Bitmap)src.Clone();
+                        return;
+                    }
+
+                    int sourceWidth = src.Width;
+                    int sourceHeight = src.Height;
+                    int sourceX = 0;
+                    int sourceY = 0;
+                    int destX = 0;
+                    int destY = 0;
+                    int destWidth = 0;
+                    int destHeight = 0;
+                    float aspect = 0;
+                    float widthAspect = 0;
+                    float heightAspect = 0;
+
+                    widthAspect = (float)ex.Width / sourceWidth;
+                    heightAspect = (float)ex.Height / sourceHeight;
+
+                    if (ex.Scale == ResizingKind.IsoPad)
+                    {
+                        // Fit the whole image inside the target, centering it along the
+                        // slack dimension; the remainder is left transparent.
+                        if (heightAspect < widthAspect)
+                        {
+                            aspect = heightAspect;
+                            destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
+                        }
+                        else
+                        {
+                            aspect = widthAspect;
+                            destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
+                        }
+
+                        destWidth = (int)(sourceWidth * aspect);
+                        destHeight = (int)(sourceHeight * aspect);
+                    }
+                    else
+                    {
+                        // IsoCrop: scale by the larger aspect so the target is fully
+                        // covered, and shift by the anchor to choose the cropped region.
+                        if (heightAspect < widthAspect)
+                        {
+                            aspect = widthAspect;
+                            switch (ex.Anchor)
+                            {
+                                case Anchor.Top:
+                                    destY = 0;
+                                    break;
+                                case Anchor.Bottom:
+                                    destY = (int)(ex.Height - (sourceHeight * aspect));
+                                    break;
+                                default:
+                                    destY = (int)((ex.Height - (sourceHeight * aspect)) / 2);
+                                    break;
+                            }
+                        }
+                        else
+                        {
+                            aspect = heightAspect;
+                            switch (ex.Anchor)
+                            {
+                                case Anchor.Left:
+                                    destX = 0;
+                                    break;
+                                case Anchor.Right:
+                                    destX = (int)(ex.Width - (sourceWidth * aspect));
+                                    break;
+                                default:
+                                    destX = (int)((ex.Width - (sourceWidth * aspect)) / 2);
+                                    break;
+                            }
+                        }
+
+                        destWidth = (int)(sourceWidth * aspect);
+                        destHeight = (int)(sourceHeight * aspect);
+                    }
+                    dst = new Bitmap(ex.Width, ex.Height);
+                    var srcRectangle = new Rectangle(sourceX, sourceY, sourceWidth, sourceHeight);
+                    var destRectangle = new Rectangle(destX, destY, destWidth, destHeight);
+                    using (var g = Graphics.FromImage(dst))
+                    {
+                        g.DrawImage(src, destRectangle, srcRectangle, GraphicsUnit.Pixel);
+                    }
+                    Host.Assert(dst.Width == ex.Width && dst.Height == ex.Height);
+                };
+
+            return del;
+        }
+ }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/ImageType.cs b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
new file mode 100644
index 0000000000..bc822f80b0
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/ImageType.cs
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Drawing;
+using Microsoft.ML.Runtime.Data;
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+    /// <summary>
+    /// Column type for image-valued columns, backed by System.Drawing.Bitmap.
+    /// A type built with explicit dimensions represents an image of exactly
+    /// Height x Width pixels; the parameterless overload represents an image of
+    /// unknown size (Height == Width == 0).
+    /// </summary>
+    public sealed class ImageType : StructuredType
+    {
+        public readonly int Height;
+        public readonly int Width;
+
+        public ImageType(int height, int width)
+            : base(typeof(Bitmap))
+        {
+            Contracts.CheckParam(height > 0, nameof(height));
+            Contracts.CheckParam(width > 0, nameof(width));
+            // Guard against int overflow downstream: pixel buffers are height * width * 4 bytes.
+            Contracts.CheckParam((long)height * width <= int.MaxValue / 4, nameof(height), nameof(height) + " * " + nameof(width) + " is too large");
+            Height = height;
+            Width = width;
+        }
+
+        // Unknown-size image type; Height and Width remain 0.
+        public ImageType() : base(typeof(Image))
+        {
+        }
+
+        public override bool Equals(ColumnType other)
+        {
+            if (other == this)
+                return true;
+            if (!(other is ImageType tmp))
+                return false;
+            if (Height != tmp.Height)
+                return false;
+            // BUGFIX: was 'return Width != tmp.Width;', which inverted the result:
+            // two types with equal Height and equal Width compared as NOT equal,
+            // and equal-Height types with different Width compared as equal.
+            return Width == tmp.Width;
+        }
+
+        public override string ToString()
+        {
+            if (Height == 0 && Width == 0)
+                return "Image";
+            return string.Format("Image<{0}, {1}>", Height, Width);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj b/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj
new file mode 100644
index 0000000000..1a4fa6b66d
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.csproj
@@ -0,0 +1,18 @@
+
+
+
+ netstandard2.0
+ Microsoft.ML.Runtime.ImageAnalytics
+ Microsoft.ML.Runtime.ImageAnalytics
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
new file mode 100644
index 0000000000..b9d35a6cdc
--- /dev/null
+++ b/src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
@@ -0,0 +1,419 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Drawing;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.ImageAnalytics;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(VectorToImageTransform.Summary, typeof(VectorToImageTransform), typeof(VectorToImageTransform.Arguments),
+ typeof(SignatureDataTransform), VectorToImageTransform.UserName, "VectorToImageTransform", "VectorToImage")]
+
+[assembly: LoadableClass(VectorToImageTransform.Summary, typeof(VectorToImageTransform), null, typeof(SignatureLoadDataTransform),
+ VectorToImageTransform.UserName, VectorToImageTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.ImageAnalytics
+{
+ // REVIEW: Rewrite as LambdaTransform to simplify.
+
+ ///
+ /// Transform which takes one or many columns with vectors in them and transform them to representation.
+ ///
+ public sealed class VectorToImageTransform : OneToOneTransformBase
+ {
+ // Per-column options. Every nullable field that is left unset falls back to
+ // the corresponding transform-wide value on Arguments (resolution happens in
+ // the ColInfoEx constructor).
+ public class Column : OneToOneColumn
+ {
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use alpha channel", ShortName = "alpha")]
+ public bool? ContainsAlpha;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use red channel", ShortName = "red")]
+ public bool? ContainsRed;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use green channel", ShortName = "green")]
+ public bool? ContainsGreen;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use blue channel", ShortName = "blue")]
+ public bool? ContainsBlue;
+
+ // REVIEW: Consider turning this into an enum that allows for pixel, line, or planar interleaving.
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to separate each channel or interleave in ARGB order", ShortName = "interleave")]
+ public bool? InterleaveArgb;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Width of the image", ShortName = "width")]
+ public int? ImageWidth;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Height of the image", ShortName = "height")]
+ public int? ImageHeight;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Offset (pre-scale)")]
+ public Single? Offset;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Scale factor")]
+ public Single? Scale;
+
+ // Parses the short "name:src" command-line form; returns null on failure.
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ // Only the compact name:src form is round-trippable; returning false here
+ // forces long-form serialization whenever any per-column override is set.
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ if (ContainsAlpha != null || ContainsRed != null || ContainsGreen != null || ContainsBlue != null || ImageWidth != null ||
+ ImageHeight != null || Offset != null || Scale != null || InterleaveArgb != null)
+ {
+ return false;
+ }
+ return TryUnparseCore(sb);
+ }
+ }
+
+ // Transform-wide defaults; each value can be overridden per column via
+ // Column's nullable fields. Pixel values are mapped as (value - Offset) * Scale
+ // when producing the image (see GetterFromType).
+ public class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use alpha channel", ShortName = "alpha")]
+ public bool ContainsAlpha = false;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use red channel", ShortName = "red")]
+ public bool ContainsRed = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use green channel", ShortName = "green")]
+ public bool ContainsGreen = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use blue channel", ShortName = "blue")]
+ public bool ContainsBlue = true;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to separate each channel or interleave in ARGB order", ShortName = "interleave")]
+ public bool InterleaveArgb = false;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Width of the image", ShortName = "width")]
+ public int ImageWidth;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Height of the image", ShortName = "height")]
+ public int ImageHeight;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Offset (pre-scale)")]
+ public Single? Offset;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Scale factor")]
+ public Single? Scale;
+ }
+
+ ///
+ /// Which color channels are extracted. Note that these values are serialized so should not be modified.
+ ///
+ [Flags]
+ private enum ColorBits : byte
+ {
+ Alpha = 0x01,
+ Red = 0x02,
+ Green = 0x04,
+ Blue = 0x08,
+
+ // Union of all channel flags; used to validate deserialized values.
+ All = Alpha | Red | Green | Blue
+ }
+
+ // Resolved, immutable per-column configuration: merges a Column's overrides
+ // with the transform-wide Arguments defaults, and handles (de)serialization
+ // of that configuration.
+ private sealed class ColInfoEx
+ {
+ public readonly ColorBits Colors;
+ public readonly byte Planes;
+
+ public readonly int Width;
+ public readonly int Height;
+ public readonly Single Offset;
+ public readonly Single Scale;
+ public readonly bool Interleave;
+
+ public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } }
+ public bool Red { get { return (Colors & ColorBits.Red) != 0; } }
+ public bool Green { get { return (Colors & ColorBits.Green) != 0; } }
+ public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } }
+
+ public ColInfoEx(Column item, Arguments args)
+ {
+ if (item.ContainsAlpha ?? args.ContainsAlpha) { Colors |= ColorBits.Alpha; Planes++; }
+ if (item.ContainsRed ?? args.ContainsRed) { Colors |= ColorBits.Red; Planes++; }
+ if (item.ContainsGreen ?? args.ContainsGreen) { Colors |= ColorBits.Green; Planes++; }
+ if (item.ContainsBlue ?? args.ContainsBlue) { Colors |= ColorBits.Blue; Planes++; }
+ Contracts.CheckUserArg(Planes > 0, nameof(item.ContainsRed), "Need to use at least one color plane")
+
+ Interleave = item.InterleaveArgb ?? args.InterleaveArgb;
+
+ Width = item.ImageWidth ?? args.ImageWidth;
+ Height = item.ImageHeight ?? args.ImageHeight;
+ Offset = item.Offset ?? args.Offset ?? 0;
+ Scale = item.Scale ?? args.Scale ?? 1;
+ Contracts.CheckUserArg(FloatUtils.IsFinite(Offset), nameof(item.Offset));
+ Contracts.CheckUserArg(FloatUtils.IsFiniteNonZero(Scale), nameof(item.Scale));
+ }
+
+ // Deserialization constructor; must read exactly what Save writes.
+ public ColInfoEx(ModelLoadContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+ // *** Binary format ***
+ // byte: colors
+ // int: width
+ // int: height
+ // Float: offset
+ // Float: scale
+ // byte: separateChannels
+ Colors = (ColorBits)ctx.Reader.ReadByte();
+ Contracts.CheckDecode(Colors != 0);
+ Contracts.CheckDecode((Colors & ColorBits.All) == Colors);
+
+ // Count the planes (SWAR popcount of the four color bits).
+ int planes = (int)Colors;
+ planes = (planes & 0x05) + ((planes >> 1) & 0x05);
+ planes = (planes & 0x03) + ((planes >> 2) & 0x03);
+ Planes = (byte)planes;
+ Contracts.Assert(0 < Planes & Planes <= 4);
+
+ Width = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(Width > 0);
+ Height = ctx.Reader.ReadInt32();
+ Contracts.CheckDecode(Height > 0);
+ Offset = ctx.Reader.ReadFloat();
+ Contracts.CheckDecode(FloatUtils.IsFinite(Offset));
+ Scale = ctx.Reader.ReadFloat();
+ Contracts.CheckDecode(FloatUtils.IsFiniteNonZero(Scale));
+ Interleave = ctx.Reader.ReadBoolByte();
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ Contracts.AssertValue(ctx);
+
+#if DEBUG
+ // This code is used in deserialization - assert that it matches what we computed above.
+ int planes = (int)Colors;
+ planes = (planes & 0x05) + ((planes >> 1) & 0x05);
+ planes = (planes & 0x03) + ((planes >> 2) & 0x03);
+ Contracts.Assert(planes == Planes);
+#endif
+
+ // *** Binary format *** (must mirror the ColInfoEx(ModelLoadContext) reader)
+ // byte: colors
+ // int: width
+ // int: height
+ // Float: offset
+ // Float: scale
+ // byte: separateChannels
+ Contracts.Assert(Colors != 0);
+ Contracts.Assert((Colors & ColorBits.All) == Colors);
+ ctx.Writer.Write((byte)Colors);
+ ctx.Writer.Write(Width);
+ ctx.Writer.Write(Height);
+ Contracts.Assert(FloatUtils.IsFinite(Offset));
+ ctx.Writer.Write(Offset);
+ Contracts.Assert(FloatUtils.IsFiniteNonZero(Scale));
+ ctx.Writer.Write(Scale);
+ ctx.Writer.WriteBoolByte(Interleave);
+ }
+ }
+
+ public const string Summary = "Converts vector array into image type.";
+ public const string UserName = "Vector To Image Transform";
+ public const string LoaderSignature = "VectorToImageConverter";
+ // Model versioning for (de)serialization; bump verWrittenCur on format changes.
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "VECTOIMG",
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Switch from OpenCV to Bitmap
+ verReadableCur: 0x00010002,
+ verWeCanReadBack: 0x00010002,
+ loaderSignature: LoaderSignature);
+ }
+
+ private const string RegistrationName = "VectorToImageConverter";
+
+ private readonly ColInfoEx[] _exes;
+ private readonly ImageType[] _types;
+
+ // Public constructor corresponding to SignatureDataTransform.
+ public VectorToImageTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input,
+ t => t is VectorType ? null : "Expected VectorType type")
+ {
+ Host.AssertNonEmpty(Infos);
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+
+ _exes = new ColInfoEx[Infos.Length];
+ _types = new ImageType[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ var item = args.Column[i];
+ // Resolve per-column overrides against the global defaults, and cache
+ // the resulting output ImageType (sized Height x Width).
+ _exes[i] = new ColInfoEx(item, args);
+ _types[i] = new ImageType(_exes[i].Height, _exes[i].Width);
+ }
+ Metadata.Seal();
+ }
+
+ // Deserialization constructor; the sizeof(Float) prefix has already been
+ // consumed by the static Create method before this runs.
+ private VectorToImageTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, t => t is VectorType ? null : "Expected VectorType type")
+ {
+ Host.AssertValue(ctx);
+
+ // *** Binary format ***
+ // (prefix handled in the static Create method)
+ // (base class state read by the base constructor)
+ // foreach added column
+ // ColInfoEx
+ Host.AssertNonEmpty(Infos);
+ _exes = new ColInfoEx[Infos.Length];
+ _types = new ImageType[Infos.Length];
+ for (int i = 0; i < _exes.Length; i++)
+ {
+ _exes[i] = new ColInfoEx(ctx);
+ _types[i] = new ImageType(_exes[i].Height, _exes[i].Width);
+ }
+ Metadata.Seal();
+ }
+
+ // Factory method corresponding to SignatureLoadDataTransform: validates the
+ // model header, consumes the sizeof(Float) prefix, then defers to the
+ // deserialization constructor.
+ public static VectorToImageTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+
+ return h.Apply("Loading Model",
+ ch =>
+ {
+ // *** Binary format ***
+ // int: sizeof(Float)
+ //
+ int cbFloat = ctx.Reader.ReadInt32();
+ ch.CheckDecode(cbFloat == sizeof(Single));
+ return new VectorToImageTransform(h, ctx, input);
+ });
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format *** (must mirror Create + the deserialization ctor)
+ // int: sizeof(Float)
+ // (base class state written by SaveBase)
+ // foreach added column
+ // ColInfoEx
+ ctx.Writer.Write(sizeof(Single));
+ SaveBase(ctx);
+
+ Host.Assert(_exes.Length == Infos.Length);
+ for (int i = 0; i < _exes.Length; i++)
+ _exes[i].Save(ctx);
+ }
+
+ // Output column type: the ImageType built from the column's configured
+ // height/width (cached in the constructors).
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Assert(0 <= iinfo & iinfo < Infos.Length);
+ return _types[iinfo];
+ }
+
+ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ {
+ Host.AssertValueOrNull(ch);
+ Host.AssertValue(input);
+ Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+
+ var type = _types[iinfo];
+ var ex = _exes[iinfo];
+ // Scaling is skipped entirely when (v - Offset) * Scale would be an identity map.
+ bool needScale = ex.Offset != 0 || ex.Scale != 1;
+ disposer = null;
+ var sourceType = Schema.GetColumnType(Infos[iinfo].Source);
+ // NOTE(review): the generic type arguments on the GetterFromType calls below
+ // appear to have been stripped by extraction (likely <Single>/<byte> etc.);
+ // confirm against the original source.
+ if (sourceType.ItemType == NumberType.R4 || sourceType.ItemType == NumberType.R8)
+ return GetterFromType(input, iinfo, ex, needScale);
+ else
+ if (sourceType.ItemType == NumberType.U1)
+ // Offset/scale is explicitly disabled for byte input (needScale forced false).
+ return GetterFromType(input, iinfo, ex, false);
+ else
+ throw Contracts.Except("We only support float or byte arrays");
+
+ }
+
+ // Builds the per-row getter that decodes a numeric vector into a Bitmap,
+ // reading channels either interleaved (A)RGB or as separate planes.
+ // NOTE(review): the generic parameter list appears stripped by extraction
+ // (the trailing 'where TValue : IConvertible' implies a <TValue> parameter);
+ // confirm against the original source.
+ private ValueGetter GetterFromType(IRow input, int iinfo, ColInfoEx ex, bool needScale) where TValue : IConvertible
+ {
+ var getSrc = GetSrcGetter>(input, iinfo);
+ var src = default(VBuffer);
+ int width = ex.Width;
+ int height = ex.Height;
+ float offset = ex.Offset;
+ float scale = ex.Scale;
+
+ return
+ (ref Bitmap dst) =>
+ {
+ getSrc(ref src);
+ // An empty vector yields a null image rather than throwing.
+ if (src.Count == 0)
+ {
+ dst = null;
+ return;
+ }
+ VBuffer dense = default;
+ src.CopyToDense(ref dense);
+ var values = dense.Values;
+ dst = new Bitmap(width, height);
+ // NOTE(review): SetResolution sets DPI, not pixel dimensions; passing
+ // width/height as DPI here looks unintended -- confirm.
+ dst.SetResolution(width, height);
+ int cpix = height * width;
+ // NOTE(review): 'planes' is computed but never used below.
+ int planes = dense.Count / cpix;
+ int position = 0;
+
+ // Pixels are visited column-major (x outer, y inner); the interleaved
+ // branch assumes the source vector was written in the same order --
+ // TODO confirm against the producing transform.
+ for (int x = 0; x < width; x++)
+ for (int y = 0; y < height; ++y)
+ {
+ float red = 0;
+ float green = 0;
+ float blue = 0;
+ float alpha = 0;
+ if (ex.Interleave)
+ {
+ // NOTE(review): the alpha value is skipped (position++) rather than
+ // read, so 'alpha' stays 0 and the pixel is fully transparent when
+ // the alpha channel is enabled -- confirm intended.
+ if (ex.Alpha) position++;
+ if (ex.Red) red = Convert.ToSingle(values[position++]);
+ if (ex.Green) green = Convert.ToSingle(values[position++]);
+ if (ex.Blue) blue = Convert.ToSingle(values[position++]);
+ }
+ else
+ {
+ // Planar layout: each enabled channel occupies a full cpix-sized plane.
+ position = y * width + x;
+ if (ex.Alpha) { alpha = Convert.ToSingle(values[position]); position += cpix; }
+ if (ex.Red) { red = Convert.ToSingle(values[position]); position += cpix; }
+ if (ex.Green) { green = Convert.ToSingle(values[position]); position += cpix; }
+ if (ex.Blue) { blue = Convert.ToSingle(values[position]); position += cpix; }
+ }
+ Color pixel;
+ // NOTE(review): Color.FromArgb throws ArgumentException for components
+ // outside [0, 255]; neither branch clamps -- confirm inputs are in range.
+ if (!needScale)
+ pixel = Color.FromArgb((int)alpha, (int)red, (int)green, (int)blue);
+ else
+ {
+ pixel = Color.FromArgb(
+ (int)((alpha - offset) * scale),
+ (int)((red - offset) * scale),
+ (int)((green - offset) * scale),
+ (int)((blue - offset) * scale));
+ }
+ dst.SetPixel(x, y, pixel);
+ }
+ };
+ }
+ }
+}
+
diff --git a/src/Microsoft.ML.KMeansClustering/doc.xml b/src/Microsoft.ML.KMeansClustering/doc.xml
index a1590595dc..b4318de334 100644
--- a/src/Microsoft.ML.KMeansClustering/doc.xml
+++ b/src/Microsoft.ML.KMeansClustering/doc.xml
@@ -13,7 +13,7 @@
YYK-Means observes that there is a lot of redundancy across iterations in the KMeans algorithms and most points do not change their clusters during an iteration.
It uses various bounding techniques to identify this redundancy and eliminate many distance computations and optimize centroid computations.
For more information on K-means, and K-means++ see:
-
+
- K-means
- K-means++
diff --git a/src/Microsoft.ML.PCA/doc.xml b/src/Microsoft.ML.PCA/doc.xml
index c4f0be7758..5054950c2d 100644
--- a/src/Microsoft.ML.PCA/doc.xml
+++ b/src/Microsoft.ML.PCA/doc.xml
@@ -11,7 +11,7 @@
Its training is done using the technique described in the paper: Combining Structured and Unstructured Randomness in Large Scale PCA,
and the paper Finding Structure with Randomness: Probabilistic Algorithms for Constructing Approximate Matrix Decompositions
For more information, see also:
-
+
-
Randomized Methods for Computing the Singular Value Decomposition (SVD) of very large matrices
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
index 62270763de..0980cb22dc 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs
@@ -103,7 +103,7 @@ public FieldAwareFactorizationMachineTrainer(IHostEnvironment env, Arguments arg
_shuffle = args.Shuffle;
_verbose = args.Verbose;
_radius = args.Radius;
- Info = new TrainerInfo();
+ Info = new TrainerInfo(supportValid: true, supportIncrementalTrain: true);
}
private void InitializeTrainingState(int fieldCount, int featureCount, FieldAwareFactorizationMachinePredictor predictor, out float[] linearWeights,
diff --git a/src/Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml b/src/Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml
index f18bf60990..bdcb973439 100644
--- a/src/Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml
+++ b/src/Microsoft.ML.StandardLearners/FactorizationMachine/doc.xml
@@ -15,14 +15,14 @@
See references below for more details.
This trainer is essentially faster the one introduced in [2] because of some implemtation tricks[3].
-
+
-
- [1] Field-aware Factorization Machines for CTR Prediction
+ Field-aware Factorization Machines for CTR Prediction
-
- [2] Adaptive Subgradient Methods for Online Learning and Stochastic Optimization
+ Adaptive Subgradient Methods for Online Learning and Stochastic Optimization
-
- [3] An Improved Stochastic Gradient Method for Training Large-scale Field-aware Factorization Machine.
+ An Improved Stochastic Gradient Method for Training Large-scale Field-aware Factorization Machine.
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
index 04d3e2d9e4..cc1e090bd0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs
@@ -207,10 +207,9 @@ internal virtual void Check(IHostEnvironment env)
{
using (var ch = env.Start("SDCA arguments checking"))
{
- ch.Warning("The specified l2Const = {0} is too small. SDCA optimizes the dual objective function. " +
- "The dual formulation is only valid with a positive L2 regularization. Also, an l2Const less than {1} " +
- "could drastically slow down the convergence. So using l2Const = {1} instead.", L2Const);
-
+ ch.Warning($"The L2 regularization constant must be at least {L2LowerBound}. In SDCA, the dual formulation " +
+ $"is only valid with a positive constant, and values below {L2LowerBound} cause very slow convergence. " +
+ $"The original {nameof(L2Const)} = {L2Const}, was replaced with {nameof(L2Const)} = {L2LowerBound}.");
L2Const = L2LowerBound;
ch.Done();
}
@@ -1752,4 +1751,4 @@ public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironm
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
index 8f81324768..82069d413d 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs
@@ -553,7 +553,7 @@ public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICali
public abstract class RegressionPredictor : LinearPredictor
{
- internal RegressionPredictor(IHostEnvironment env, string name, ref VBuffer weights, Float bias)
+ protected RegressionPredictor(IHostEnvironment env, string name, ref VBuffer weights, Float bias)
: base(env, name, ref weights, bias)
{
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
index 09e2bbbcc4..47b08c586a 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LogisticRegression.cs
@@ -28,7 +28,6 @@
namespace Microsoft.ML.Runtime.Learners
{
- using Mkl = Microsoft.ML.Runtime.Learners.OlsLinearRegressionTrainer.Mkl;
///
///
@@ -282,64 +281,7 @@ protected override void ComputeTrainingStatistics(IChannel ch, FloatLabelCursor.
}
}
- // Apply Cholesky Decomposition to find the inverse of the Hessian.
- Double[] invHessian = null;
- try
- {
- // First, find the Cholesky decomposition LL' of the Hessian.
- Mkl.Pptrf(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
- // Note that hessian is already modified at this point. It is no longer the original Hessian,
- // but instead represents the Cholesky decomposition L.
- // Also note that the following routine is supposed to consume the Cholesky decomposition L instead
- // of the original information matrix.
- Mkl.Pptri(Mkl.Layout.RowMajor, Mkl.UpLo.Lo, numParams, hessian);
- // At this point, hessian should contain the inverse of the original Hessian matrix.
- // Swap hessian with invHessian to avoid confusion in the following context.
- Utils.Swap(ref hessian, ref invHessian);
- Contracts.Assert(hessian == null);
- }
- catch (DllNotFoundException)
- {
- throw ch.ExceptNotSupp("The MKL library (Microsoft.ML.MklImports.dll) or one of its dependencies is missing.");
- }
-
- Float[] stdErrorValues = new Float[numParams];
- stdErrorValues[0] = (Float)Math.Sqrt(invHessian[0]);
-
- for (int i = 1; i < numParams; i++)
- {
- // Initialize with inverse Hessian.
- stdErrorValues[i] = (Single)invHessian[i * (i + 1) / 2 + i];
- }
-
- if (L2Weight > 0)
- {
- // Iterate through all entries of inverse Hessian to make adjustment to variance.
- // A discussion on ridge regularized LR coefficient covariance matrix can be found here:
- // http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3228544/
- // http://www.inf.unibz.it/dis/teaching/DWDM/project2010/LogisticRegression.pdf
- int ioffset = 1;
- for (int iRow = 1; iRow < numParams; iRow++)
- {
- for (int iCol = 0; iCol <= iRow; iCol++)
- {
- var entry = (Single)invHessian[ioffset];
- var adjustment = -L2Weight * entry * entry;
- stdErrorValues[iRow] -= adjustment;
- if (0 < iCol && iCol < iRow)
- stdErrorValues[iCol] -= adjustment;
- ioffset++;
- }
- }
-
- Contracts.Assert(ioffset == invHessian.Length);
- }
-
- for (int i = 1; i < numParams; i++)
- stdErrorValues[i] = (Float)Math.Sqrt(stdErrorValues[i]);
-
- VBuffer stdErrors = new VBuffer(CurrentWeights.Length, numParams, stdErrorValues, weightIndices);
- _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance, ref stdErrors);
+ _stats = new LinearModelStatistics(Host, NumGoodRows, numParams, deviance, nullDeviance);
}
protected override void ProcessPriorDistribution(Float label, Float weight)
@@ -382,7 +324,7 @@ protected override ParameterMixingCalibratedPredictor CreatePredictor()
CurrentWeights.GetItemOrDefault(0, ref bias);
CurrentWeights.CopyTo(ref weights, 1, CurrentWeights.Length - 1);
return new ParameterMixingCalibratedPredictor(Host,
- new LinearBinaryPredictor(Host, ref weights, bias, _stats),
+ new LinearBinaryPredictor(Host, ref weights, bias),
new PlattCalibrator(Host, -1, 0));
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
index 52cd025370..23a81f78ee 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs
@@ -10,6 +10,7 @@
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Training;
+using Microsoft.ML.Runtime.EntryPoints;
namespace Microsoft.ML.Runtime.Learners
{
@@ -21,13 +22,12 @@ public abstract class MetaMulticlassTrainer : TrainerBase
{
public abstract class ArgumentsBase
{
- [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1)]
+ [Argument(ArgumentType.Multiple, HelpText = "Base predictor", ShortName = "p", SortOrder = 1, SignatureType = typeof(SignatureBinaryClassifierTrainer))]
[TGUI(Label = "Predictor Type", Description = "Type of underlying binary predictor")]
- public SubComponent PredictorType =
- new SubComponent(LinearSvm.LoadNameValue);
+ public IComponentFactory PredictorType;
- [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "")]
- public SubComponent Calibrator = new SubComponent("PlattCalibration");
+ [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", NullName = "", SignatureType = typeof(SignatureCalibrator))]
+ public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory();
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", ShortName = "numcali")]
public int MaxCalibrationExamples = 1000000000;
@@ -47,14 +47,20 @@ internal MetaMulticlassTrainer(IHostEnvironment env, TArgs args, string name)
{
Host.CheckValue(args, nameof(args));
Args = args;
- Host.CheckUserArg(Args.PredictorType.IsGood(), nameof(Args.PredictorType));
// Create the first trainer so errors in the args surface early.
- _trainer = Args.PredictorType.CreateInstance(Host);
+ _trainer = CreateTrainer();
// Regarding caching, no matter what the internal predictor, we're performing many passes
// simply by virtue of this being a meta-trainer, so we will still cache.
Info = new TrainerInfo(normalization: _trainer.Info.NeedNormalization);
}
+ private TScalarTrainer CreateTrainer()
+ {
+ return Args.PredictorType != null ?
+ Args.PredictorType.CreateComponent(Host) :
+ new LinearSvm(Host, new LinearSvm.Arguments());
+ }
+
protected IDataView MapLabelsCore(ColumnType type, RefPredicate equalsTarget, RoleMappedData data, string dstName)
{
Host.AssertValue(type);
@@ -84,7 +90,7 @@ protected TScalarTrainer GetTrainer()
{
// We may have instantiated the first trainer to use already, from the constructor.
// If so capture it and set the retained trainer to null; otherwise create a new one.
- var train = _trainer ?? Args.PredictorType.CreateInstance(Host);
+ var train = _trainer ?? CreateTrainer();
_trainer = null;
return train;
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 0a73d55395..8c96ee1e0b 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -123,8 +123,8 @@ public override MultiClassNaiveBayesPredictor Train(TrainContext context)
Desc = "Train a MultiClassNaiveBayesTrainer.",
UserName = UserName,
ShortName = ShortName,
- XmlInclude = new[] { @"",
- @"" })]
+ XmlInclude = new[] { @"",
+ @"" })]
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments input)
{
Contracts.CheckValue(env, nameof(env));
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index c123411edd..24359feca3 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -89,10 +89,10 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
if (Args.UseProbabilities)
{
ICalibratorTrainer calibrator;
- if (!Args.Calibrator.IsGood())
+ if (Args.Calibrator == null)
calibrator = null;
else
- calibrator = Args.Calibrator.CreateInstance(Host);
+ calibrator = Args.Calibrator.CreateComponent(Host);
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
trainer, predictor, td);
predictor = res as TScalarPredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
index 193c8f0290..6434038384 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Pkpd.cs
@@ -104,10 +104,10 @@ private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedD
var predictor = trainer.Train(td);
ICalibratorTrainer calibrator;
- if (!Args.Calibrator.IsGood())
+ if (Args.Calibrator == null)
calibrator = null;
else
- calibrator = Args.Calibrator.CreateInstance(Host);
+ calibrator = Args.Calibrator.CreateComponent(Host);
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
trainer, predictor, td);
var dist = res as TDistPredictor;
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
index 60fe7f9705..4976bf20d3 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs
@@ -84,7 +84,7 @@ protected OnlineLinearTrainer(TArguments args, IHostEnvironment env, string name
Args = args;
// REVIEW: Caching could be false for one iteration, if we got around the whole shuffling issue.
- Info = new TrainerInfo(calibration: NeedCalibration);
+ Info = new TrainerInfo(calibration: NeedCalibration, supportIncrementalTrain: true);
}
///
diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
index 0ace721221..2ad6e77aa0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
+++ b/src/Microsoft.ML.StandardLearners/Standard/Online/doc.xml
@@ -13,8 +13,8 @@
and an option to update the weight vector using the average of the vectors seen over time (averaged argument is set to True by default).
-
-
+
+
new OnlineGradientDescentRegressor
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/doc.xml b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/doc.xml
index ec14c9446b..975f1eb2ff 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/doc.xml
+++ b/src/Microsoft.ML.StandardLearners/Standard/PoissonRegression/doc.xml
@@ -12,8 +12,8 @@
Assuming that the dependent variable follows a Poisson distribution, the parameters of the regressor can be estimated by maximizing the likelihood of the obtained observations.
-
-
+
+
new PoissonRegressor
{
diff --git a/src/Microsoft.ML.StandardLearners/Standard/doc.xml b/src/Microsoft.ML.StandardLearners/Standard/doc.xml
index a704827b88..eb87605232 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/doc.xml
+++ b/src/Microsoft.ML.StandardLearners/Standard/doc.xml
@@ -22,7 +22,7 @@
In general, the larger the 'L2Const', the faster SDCA converges.
For more information, see also:
-
+
-
Scaling Up Stochastic Dual Coordinate Ascent.
diff --git a/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs b/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs
index c3ab4ea5e0..5733f84b6b 100644
--- a/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs
+++ b/src/Microsoft.ML.Transforms/EntryPoints/SelectFeatures.cs
@@ -14,8 +14,8 @@ public static class SelectFeatures
[TlcModule.EntryPoint(Name = "Transforms.FeatureSelectorByCount",
Desc = CountFeatureSelectionTransform.Summary,
UserName = CountFeatureSelectionTransform.UserName,
- XmlInclude = new[] { @"",
- @""})]
+ XmlInclude = new[] { @"",
+ @""})]
public static CommonOutputs.TransformOutput CountSelect(IHostEnvironment env, CountFeatureSelectionTransform.Arguments input)
{
Contracts.CheckValue(env, nameof(env));
@@ -31,8 +31,8 @@ public static CommonOutputs.TransformOutput CountSelect(IHostEnvironment env, Co
Desc = MutualInformationFeatureSelectionTransform.Summary,
UserName = MutualInformationFeatureSelectionTransform.UserName,
ShortName = MutualInformationFeatureSelectionTransform.ShortName,
- XmlInclude = new[] { @"",
- @""})]
+ XmlInclude = new[] { @"",
+ @""})]
public static CommonOutputs.TransformOutput MutualInformationSelect(IHostEnvironment env, MutualInformationFeatureSelectionTransform.Arguments input)
{
Contracts.CheckValue(env, nameof(env));
diff --git a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
index 543b997e6d..3d29bce613 100644
--- a/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
+++ b/src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs
@@ -137,5 +137,25 @@ public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTr
OutputData = view
};
}
+
+ [TlcModule.EntryPoint(Name = "Transforms.WordEmbeddings",
+ Desc = WordEmbeddingsTransform.Summary,
+ UserName = WordEmbeddingsTransform.UserName,
+ ShortName = WordEmbeddingsTransform.ShortName,
+ XmlInclude = new[] { @"",
+ @"" })]
+ public static CommonOutputs.TransformOutput WordEmbeddings(IHostEnvironment env, WordEmbeddingsTransform.Arguments input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(input, nameof(input));
+
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, "WordEmbeddings", input);
+ var view = new WordEmbeddingsTransform(h, input, input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModel(h, view, input.Data),
+ OutputData = view
+ };
+ }
}
}
diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
index 55330cb6fb..0af833a046 100644
--- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
+++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
@@ -21,7 +21,7 @@
namespace Microsoft.ML.Runtime.Data
{
- ///
+ ///
public static class MutualInformationFeatureSelectionTransform
{
public const string Summary =
diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
index fb56fa2e0a..9d6f5564cf 100644
--- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
+++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
@@ -451,8 +451,8 @@ public sealed class TermLoaderArguments
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Data file containing the terms", ShortName = "data", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string DataFile;
- [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
- public SubComponent Loader;
+ [Argument(ArgumentType.Multiple, HelpText = "Data loader", NullName = "", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, SignatureType = typeof(SignatureDataLoader))]
+ public IComponentFactory Loader;
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the text column containing the terms", ShortName = "termCol", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly)]
public string TermsColumn;
diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
new file mode 100644
index 0000000000..bf85ddf42f
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsTransform.cs
@@ -0,0 +1,444 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+
+[assembly: LoadableClass(WordEmbeddingsTransform.Summary, typeof(IDataTransform), typeof(WordEmbeddingsTransform), typeof(WordEmbeddingsTransform.Arguments),
+ typeof(SignatureDataTransform), WordEmbeddingsTransform.UserName, "WordEmbeddingsTransform", WordEmbeddingsTransform.ShortName, DocName = "transform/WordEmbeddingsTransform.md")]
+
+[assembly: LoadableClass(typeof(WordEmbeddingsTransform), null, typeof(SignatureLoadDataTransform),
+ WordEmbeddingsTransform.UserName, WordEmbeddingsTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ public sealed class WordEmbeddingsTransform : OneToOneTransformBase
+ {
+ public sealed class Column : OneToOneColumn
+ {
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public sealed class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 0)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Pre-trained model used to create the vocabulary", ShortName = "model", SortOrder = 1)]
+ public PretrainedModelKind? ModelKind = PretrainedModelKind.Sswe;
+
+ [Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Filename for custom word embedding model",
+ ShortName = "dataFile", SortOrder = 2)]
+ public string CustomLookupTable;
+ }
+
+ internal const string Summary = "Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence " +
+ "vectors using a pre-trained model";
+ internal const string UserName = "Word Embeddings Transform";
+ internal const string ShortName = "WordEmbeddings";
+ public const string LoaderSignature = "WordEmbeddingsTransform";
+
+ public static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "W2VTRANS",
+ verWrittenCur: 0x00010001, //Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ private readonly PretrainedModelKind? _modelKind;
+ private readonly string _modelFileNameWithPath;
+ private readonly Model _currentVocab;
+ private static object _embeddingsLock = new object();
+ private readonly VectorType _outputType;
+ private readonly bool _customLookup;
+ private readonly int _linesToSkip;
+ private static Dictionary> _vocab = new Dictionary>();
+
+ private sealed class Model
+ {
+ private readonly BigArray _wordVectors;
+ private readonly NormStr.Pool _pool;
+ public readonly int Dimension;
+
+ public Model(int dimension)
+ {
+ Dimension = dimension;
+ _wordVectors = new BigArray();
+ _pool = new NormStr.Pool();
+ }
+
+ public void AddWordVector(IChannel ch, string word, float[] wordVector)
+ {
+ ch.Assert(wordVector.Length == Dimension);
+ if (_pool.Get(word) == null)
+ {
+ _pool.Add(word);
+ _wordVectors.AddRange(wordVector, Dimension);
+ }
+ }
+
+ public bool GetWordVector(ref DvText word, float[] wordVector)
+ {
+ if (word.IsNA)
+ return false;
+ string rawWord = word.GetRawUnderlyingBufferInfo(out int ichMin, out int ichLim);
+ NormStr str = _pool.Get(rawWord, ichMin, ichLim);
+ if (str != null)
+ {
+ _wordVectors.CopyTo(str.Id * Dimension, wordVector, Dimension);
+ return true;
+ }
+ return false;
+ }
+ }
+
+ private const string RegistrationName = "WordEmbeddings";
+
+ private const int Timeout = 10 * 60 * 1000;
+
+ ///
+ /// Public constructor corresponding to .
+ ///
+ public WordEmbeddingsTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column,
+ input, TestIsTextVector)
+ {
+ if (args.ModelKind == null)
+ args.ModelKind = PretrainedModelKind.Sswe;
+ Host.CheckUserArg(!args.ModelKind.HasValue || Enum.IsDefined(typeof(PretrainedModelKind), args.ModelKind), nameof(args.ModelKind));
+ Host.AssertNonEmpty(Infos);
+ Host.Assert(Infos.Length == Utils.Size(args.Column));
+
+ _customLookup = !string.IsNullOrWhiteSpace(args.CustomLookupTable);
+ if (_customLookup)
+ {
+ _modelKind = null;
+ _modelFileNameWithPath = args.CustomLookupTable;
+ }
+ else
+ {
+ _modelKind = args.ModelKind;
+ _modelFileNameWithPath = EnsureModelFile(env, out _linesToSkip, (PretrainedModelKind)_modelKind);
+ }
+
+ Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
+ _currentVocab = GetVocabularyDictionary();
+ _outputType = new VectorType(NumberType.R4, 3 * _currentVocab.Dimension);
+ Metadata.Seal();
+ }
+
+ private WordEmbeddingsTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, ctx, input, TestIsTextVector)
+ {
+ Host.AssertValue(ctx);
+ Host.AssertNonEmpty(Infos);
+ _customLookup = ctx.Reader.ReadBoolByte();
+
+ if (_customLookup)
+ {
+ _modelFileNameWithPath = ctx.LoadNonEmptyString();
+ _modelKind = null;
+ }
+ else
+ {
+ _modelKind = (PretrainedModelKind)ctx.Reader.ReadUInt32();
+ _modelFileNameWithPath = EnsureModelFile(Host, out _linesToSkip, (PretrainedModelKind)_modelKind);
+ }
+
+ Host.CheckNonWhiteSpace(_modelFileNameWithPath, nameof(_modelFileNameWithPath));
+ _currentVocab = GetVocabularyDictionary();
+ _outputType = new VectorType(NumberType.R4, 3 * _currentVocab.Dimension);
+ Metadata.Seal();
+ }
+
+ public static WordEmbeddingsTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ IHost h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ return h.Apply("Loading Model",
+ ch => new WordEmbeddingsTransform(h, ctx, input));
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ SaveBase(ctx);
+ ctx.Writer.WriteBoolByte(_customLookup);
+ if (_customLookup)
+ ctx.SaveString(_modelFileNameWithPath);
+ else
+ ctx.Writer.Write((uint)_modelKind);
+ }
+
+ protected override ColumnType GetColumnTypeCore(int iinfo)
+ {
+ Host.Assert(0 <= iinfo && iinfo < Infos.Length);
+ return _outputType;
+ }
+
+ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
+ {
+ Host.AssertValue(ch);
+ ch.AssertValue(input);
+ ch.Assert(0 <= iinfo && iinfo < Infos.Length);
+ disposer = null;
+
+ var info = Infos[iinfo];
+ if (!info.TypeSrc.IsVector)
+ {
+ throw Host.ExceptParam(nameof(input),
+ "Text input given, expects a text vector");
+ }
+ return GetGetterVec(ch, input, iinfo);
+ }
+
+ private ValueGetter> GetGetterVec(IChannel ch, IRow input, int iinfo)
+ {
+ Host.AssertValue(ch);
+ ch.AssertValue(input);
+ ch.Assert(0 <= iinfo && iinfo < Infos.Length);
+
+ var info = Infos[iinfo];
+ ch.Assert(info.TypeSrc.IsVector);
+ ch.Assert(info.TypeSrc.ItemType.IsText);
+
+ var srcGetter = input.GetGetter>(info.Source);
+ var src = default(VBuffer);
+ int dimension = _currentVocab.Dimension;
+ float[] wordVector = new float[_currentVocab.Dimension];
+
+ return
+ (ref VBuffer dst) =>
+ {
+ int deno = 0;
+ srcGetter(ref src);
+ var values = dst.Values;
+ if (Utils.Size(values) != 3 * dimension)
+ values = new float[3 * dimension];
+ int offset = 2 * dimension;
+ for (int i = 0; i < dimension; i++)
+ {
+ values[i] = float.MaxValue;
+ values[i + dimension] = 0;
+ values[i + offset] = float.MinValue;
+ }
+ for (int word = 0; word < src.Count; word++)
+ {
+ if (_currentVocab.GetWordVector(ref src.Values[word], wordVector))
+ {
+ deno++;
+ for (int i = 0; i < dimension; i++)
+ {
+ float currentTerm = wordVector[i];
+ if (values[i] > currentTerm)
+ values[i] = currentTerm;
+ values[dimension + i] += currentTerm;
+ if (values[offset + i] < currentTerm)
+ values[offset + i] = currentTerm;
+ }
+ }
+ }
+
+ if (deno != 0)
+ for (int index = 0; index < dimension; index++)
+ values[index + dimension] /= deno;
+
+ dst = new VBuffer(values.Length, values, dst.Indices);
+ };
+ }
+
+ public enum PretrainedModelKind
+ {
+ [TGUI(Label = "GloVe 50D")]
+ GloVe50D = 0,
+
+ [TGUI(Label = "GloVe 100D")]
+ GloVe100D = 1,
+
+ [TGUI(Label = "GloVe 200D")]
+ GloVe200D = 2,
+
+ [TGUI(Label = "GloVe 300D")]
+ GloVe300D = 3,
+
+ [TGUI(Label = "GloVe Twitter 25D")]
+ GloVeTwitter25D = 4,
+
+ [TGUI(Label = "GloVe Twitter 50D")]
+ GloVeTwitter50D = 5,
+
+ [TGUI(Label = "GloVe Twitter 100D")]
+ GloVeTwitter100D = 6,
+
+ [TGUI(Label = "GloVe Twitter 200D")]
+ GloVeTwitter200D = 7,
+
+ [TGUI(Label = "fastText Wikipedia 300D")]
+ FastTextWikipedia300D = 8,
+
+ [TGUI(Label = "Sentiment-Specific Word Embedding")]
+ Sswe = 9
+ }
+
+ private static Dictionary _modelsMetaData = new Dictionary()
+ {
+ { PretrainedModelKind.GloVe50D, "glove.6B.50d.txt" },
+ { PretrainedModelKind.GloVe100D, "glove.6B.100d.txt" },
+ { PretrainedModelKind.GloVe200D, "glove.6B.200d.txt" },
+ { PretrainedModelKind.GloVe300D, "glove.6B.300d.txt" },
+ { PretrainedModelKind.GloVeTwitter25D, "glove.twitter.27B.25d.txt" },
+ { PretrainedModelKind.GloVeTwitter50D, "glove.twitter.27B.50d.txt" },
+ { PretrainedModelKind.GloVeTwitter100D, "glove.twitter.27B.100d.txt" },
+ { PretrainedModelKind.GloVeTwitter200D, "glove.twitter.27B.200d.txt" },
+ { PretrainedModelKind.FastTextWikipedia300D, "wiki.en.vec" },
+ { PretrainedModelKind.Sswe, "sentiment.emd" }
+ };
+
+ private static Dictionary _linesToSkipInModels = new Dictionary()
+ { { PretrainedModelKind.FastTextWikipedia300D, 1 } };
+
+ private string EnsureModelFile(IHostEnvironment env, out int linesToSkip, PretrainedModelKind kind)
+ {
+ linesToSkip = 0;
+ if (_modelsMetaData.ContainsKey(kind))
+ {
+ var modelFileName = _modelsMetaData[kind];
+ if (_linesToSkipInModels.ContainsKey(kind))
+ linesToSkip = _linesToSkipInModels[kind];
+ using (var ch = Host.Start("Ensuring resources"))
+ {
+ string dir = kind == PretrainedModelKind.Sswe ? Path.Combine("Text", "Sswe") : "WordVectors";
+ var url = $"{dir}/{modelFileName}";
+ var ensureModel = ResourceManagerUtils.Instance.EnsureResource(Host, ch, url, modelFileName, dir, Timeout);
+ ensureModel.Wait();
+ var errorResult = ResourceManagerUtils.GetErrorMessage(out var errorMessage, ensureModel.Result);
+ if (errorResult != null)
+ {
+ var directory = Path.GetDirectoryName(errorResult.FileName);
+ var name = Path.GetFileName(errorResult.FileName);
+ throw ch.Except($"{errorMessage}\nModel file for Word Embedding transform could not be found! " +
+ $@"Please copy the model file '{name}' from '{url}' to '{directory}'.");
+ }
+ return ensureModel.Result.FileName;
+ }
+ }
+ throw Host.Except($"Can't map model kind = {kind} to specific file, please refer to https://aka.ms/MLNetIssue for assistance");
+ }
+
+ private Model GetVocabularyDictionary()
+ {
+ int dimension = 0;
+ if (!File.Exists(_modelFileNameWithPath))
+ throw Host.Except("Custom word embedding model file '{0}' could not be found for Word Embeddings transform.", _modelFileNameWithPath);
+
+ if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null)
+ {
+ if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model model))
+ {
+ dimension = model.Dimension;
+ return model;
+ }
+ }
+
+ lock (_embeddingsLock)
+ {
+ if (_vocab.ContainsKey(_modelFileNameWithPath) && _vocab[_modelFileNameWithPath] != null)
+ {
+ if (_vocab[_modelFileNameWithPath].TryGetTarget(out Model modelObject))
+ {
+ dimension = modelObject.Dimension;
+ return modelObject;
+ }
+ }
+
+ Model model = null;
+ using (StreamReader sr = File.OpenText(_modelFileNameWithPath))
+ {
+ string line;
+ int lineNumber = 1;
+ char[] delimiters = { ' ', '\t' };
+ using (var ch = Host.Start(LoaderSignature))
+ using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform"))
+ {
+ var header = new ProgressHeader(new[] { "lines" });
+ pch.SetHeader(header, e => e.SetProgress(0, lineNumber));
+ string firstLine = sr.ReadLine();
+ while ((line = sr.ReadLine()) != null)
+ {
+ if (lineNumber >= _linesToSkip)
+ {
+ string[] words = line.TrimEnd().Split(delimiters);
+ dimension = words.Length - 1;
+ if (model == null)
+ model = new Model(dimension);
+ if (model.Dimension != dimension)
+ ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}, expected dimension = {model.Dimension}, received dimension = {dimension}");
+ else
+ {
+ float tmp;
+ string key = words[0];
+ float[] value = words.Skip(1).Select(x => float.TryParse(x, out tmp) ? tmp : Single.NaN).ToArray();
+ if (!value.Contains(Single.NaN))
+ model.AddWordVector(ch, key, value);
+ else
+ ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}");
+ }
+ }
+ lineNumber++;
+ }
+
+ // Handle first line of the embedding file separately since some embedding files including fastText have a single-line header
+ string[] wordsInFirstLine = firstLine.TrimEnd().Split(delimiters);
+ dimension = wordsInFirstLine.Length - 1;
+ if (model == null)
+ model = new Model(dimension);
+ float temp;
+ string firstKey = wordsInFirstLine[0];
+ float[] firstValue = wordsInFirstLine.Skip(1).Select(x => float.TryParse(x, out temp) ? temp : Single.NaN).ToArray();
+ if (!firstValue.Contains(Single.NaN))
+ model.AddWordVector(ch, firstKey, firstValue);
+ else
+ ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number 1");
+ pch.Checkpoint(lineNumber);
+ }
+ }
+ _vocab[_modelFileNameWithPath] = new WeakReference(model, false);
+ return model;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/Text/doc.xml b/src/Microsoft.ML.Transforms/Text/doc.xml
index 5f734e1cfd..2d077dc3ed 100644
--- a/src/Microsoft.ML.Transforms/Text/doc.xml
+++ b/src/Microsoft.ML.Transforms/Text/doc.xml
@@ -179,7 +179,49 @@
- pipeline.Add(new LightLda(("InTextCol" , "OutTextCol")));
+ pipeline.Add(new LightLda(("InTextCol" , "OutTextCol")));
+
+
+
+
+
+
+ Word Embeddings transform is a text featurizer which converts vectors of text tokens into sentence vectors using a pre-trained model.
+
+
+ WordEmbeddings wrap different embedding models, such as GloVe. Users can specify which embedding to use.
+ The available options are various versions of GloVe Models, fastText, and SSWE.
+
+ Note: As WordEmbedding requires a column with text vector, e.g. %3C%27this%27, %27is%27, %27good%27%3E, users need to create an input column by
+ using the output_tokens=True for TextTransform to convert a column with sentences like "This is good" into %3C%27this%27, %27is%27, %27good%27 %3E.
+ The suffix of %27_TransformedText%27 is added to the original column name to create the output token column. For instance if the input column is %27body%27,
+ the output tokens column is named %27body_TransformedText%27.
+
+
+ License attributes for pretrained models:
+
+ -
+
+ "fastText Wikipedia 300D" by Facebook, Inc. is licensed under CC-BY-SA 3.0 based on:
+ P. Bojanowski*, E. Grave*, A. Joulin, T. Mikolov,Enriching Word Vectors with Subword Information
+ %40article%7Bbojanowski2016enriching%2C%0A%20%20title%3D%7BEnriching%20Word%20Vectors%20with%20Subword%20Information%7D%2C%0A%20%20author%3D%7BBojanowski%2C%20Piotr%20and%20Grave%2C%20Edouard%20and%20Joulin%2C%20Armand%20and%20Mikolov%2C%20Tomas%7D%2C%0A%20%20journal%3D%7BarXiv%20preprint%20arXiv%3A1607.04606%7D%2C%0A%20%20year%3D%7B2016%7D%0A%7D
+ More information can be found here.
+
+
+ -
+
+ GloVe models by Stanford University, or (Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation) is licensed under PDDL.
+ More information can be found here. Repository can be found here.
+
+
+
+
+
+
+
+
+
+ pipeline.Add(new WordEmbeddings(("InVectorTextCol" , "OutTextCol")));
diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
index 6854157f31..a46a764352 100644
--- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs
+++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
@@ -597,7 +597,7 @@ private static Float DotProduct(Float[] a, int aOffset, Float[] b, int[] indices
private static class Mkl
{
- private const string DllName = "Microsoft.ML.MklImports.dll";
+ private const string DllName = "MklImports";
public enum Layout
{
diff --git a/src/Microsoft.ML.Transforms/doc.xml b/src/Microsoft.ML.Transforms/doc.xml
index cb6ef6af25..63d2765afc 100644
--- a/src/Microsoft.ML.Transforms/doc.xml
+++ b/src/Microsoft.ML.Transforms/doc.xml
@@ -7,8 +7,7 @@
Encodes the categorical variable with hash-based encoding.
- CategoricalHashOneHotVectorizer converts a categorical value into an indicator array by hashing the
- value and using the hash as an index in the bag.
+ CategoricalHashOneHotVectorizer converts a categorical value into an indicator array by hashing the value and using the hash as an index in the bag.
If the input column is a vector, a single indicator bag is returned for it.
@@ -33,16 +32,16 @@
The CategoricalOneHotVectorizer transform passes through a data set, operating on text columns, to
build a dictionary of categories.
For each row, the entire text string appearing in the input column is defined as a category.
- The output of this transform is an indicator vector.
+ The output of this transform is an indicator vector.
Each slot in this vector corresponds to a category in the dictionary, so its length is the size of the built dictionary.
- The CategoricalOneHotVectorizer can be applied to one or more columns, in which case it builds and uses a separate dictionary
+ The CategoricalOneHotVectorizer can be applied to one or more columns, in which case it builds and uses a separate dictionary
for each column that it is applied to.
- The produces integer values and columns.
+ The produces integer values and KeyType columns.
The Key value is the one-based index of the slot set in the Ind/Bag options.
If the Key option is not found, it is assigned the value zero.
- In the , options are not found, they result in an all zero bit vector.
- and differ simply in how the bit-vectors generated from individual slots are aggregated:
+ In the , options are not found, they result in an all zero bit vector.
+ and differ simply in how the bit-vectors generated from individual slots are aggregated:
for Ind they are concatenated and for Bag they are added.
When the source column is a singleton, the Ind and Bag options are identical.
@@ -117,8 +116,7 @@
Creates a new column with the specified type and default values.
- If the user wish to create additional columns with a particular type and default values,
- or replicated the values from one column to another, changing their type, they can do so using this transform.
+ If the user wish to create additional columns with a particular type and default values, or replicated the values from one column to another, changing their type, they can do so using this transform.
This transform can be used as a workaround to create a Label column after deserializing a model, for prediction.
Some transforms in the serialized model operate on the Label column, and would throw errors during prediction if such a column is not found.
@@ -206,53 +204,7 @@
-
-
-
- Handle missing values by replacing them with either the default value or the indicated value.
-
-
- This transform handles missing values in the input columns. For each input column, it creates an output column
- where the missing values are replaced by one of these specified values:
-
- -
- The default value of the appropriate type.
-
- -
- The mean value of the appropriate type.
-
- -
- The max value of the appropriate type.
-
- -
- The min value of the appropriate type.
-
-
- The last three work only for numeric/TimeSpan/DateTime kind columns.
-
- The output column can also optionally include an indicator vector for which slots were missing in the input column.
- This can be done only when the indicator vector type can be converted to the input column type, i.e. only for numeric columns.
-
-
- When computing the mean/max/min value, there is also an option to compute it over the whole column instead of per slot.
- This option has a default value of true for variable length vectors, and false for known length vectors.
- It can be changed to true for known length vectors, but it results in an error if changed to false for variable length vectors.
-
-
-
-
-
-
-
-
- pipeline.Add(new MissingValueHandler("FeatureCol", "CleanFeatureCol")
- {
- ReplaceWith = NAHandleTransformReplacementKind.Mean
- });
-
-
-
-
+
The LpNormalizer transforms, normalizes vectors (rows) individually by rescaling them to unit norm (L2, L1 or LInf).
@@ -325,8 +277,8 @@
be ignored, and the missing slots will be 'padded' with default values.
- All metadata is preserved for the retained columns. For 'unrolled' columns, all known metadata
- except slot names is preserved.
+ All metadata are preserved for the retained columns. For 'unrolled' columns, all known metadata
+ except slot names are preserved.
diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs
index b2181fb256..bf753fe486 100644
--- a/src/Microsoft.ML/CSharpApi.cs
+++ b/src/Microsoft.ML/CSharpApi.cs
@@ -754,6 +754,18 @@ public void Add(Microsoft.ML.Trainers.OnlineGradientDescentRegressor input, Micr
_jsonNodes.Add(Serialize("Trainers.OnlineGradientDescentRegressor", input, output));
}
+ public Microsoft.ML.Trainers.OrdinaryLeastSquaresRegressor.Output Add(Microsoft.ML.Trainers.OrdinaryLeastSquaresRegressor input)
+ {
+ var output = new Microsoft.ML.Trainers.OrdinaryLeastSquaresRegressor.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Trainers.OrdinaryLeastSquaresRegressor input, Microsoft.ML.Trainers.OrdinaryLeastSquaresRegressor.Output output)
+ {
+ _jsonNodes.Add(Serialize("Trainers.OrdinaryLeastSquaresRegressor", input, output));
+ }
+
public Microsoft.ML.Trainers.PcaAnomalyDetector.Output Add(Microsoft.ML.Trainers.PcaAnomalyDetector input)
{
var output = new Microsoft.ML.Trainers.PcaAnomalyDetector.Output();
@@ -1090,6 +1102,54 @@ public void Add(Microsoft.ML.Transforms.HashConverter input, Microsoft.ML.Transf
_jsonNodes.Add(Serialize("Transforms.HashConverter", input, output));
}
+ public Microsoft.ML.Transforms.ImageGrayscale.Output Add(Microsoft.ML.Transforms.ImageGrayscale input)
+ {
+ var output = new Microsoft.ML.Transforms.ImageGrayscale.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.ImageGrayscale input, Microsoft.ML.Transforms.ImageGrayscale.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.ImageGrayscale", input, output));
+ }
+
+ public Microsoft.ML.Transforms.ImageLoader.Output Add(Microsoft.ML.Transforms.ImageLoader input)
+ {
+ var output = new Microsoft.ML.Transforms.ImageLoader.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.ImageLoader input, Microsoft.ML.Transforms.ImageLoader.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.ImageLoader", input, output));
+ }
+
+ public Microsoft.ML.Transforms.ImagePixelExtractor.Output Add(Microsoft.ML.Transforms.ImagePixelExtractor input)
+ {
+ var output = new Microsoft.ML.Transforms.ImagePixelExtractor.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.ImagePixelExtractor input, Microsoft.ML.Transforms.ImagePixelExtractor.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.ImagePixelExtractor", input, output));
+ }
+
+ public Microsoft.ML.Transforms.ImageResizer.Output Add(Microsoft.ML.Transforms.ImageResizer input)
+ {
+ var output = new Microsoft.ML.Transforms.ImageResizer.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.ImageResizer input, Microsoft.ML.Transforms.ImageResizer.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.ImageResizer", input, output));
+ }
+
public Microsoft.ML.Transforms.KeyToTextConverter.Output Add(Microsoft.ML.Transforms.KeyToTextConverter input)
{
var output = new Microsoft.ML.Transforms.KeyToTextConverter.Output();
@@ -1522,6 +1582,30 @@ public void Add(Microsoft.ML.Transforms.TwoHeterogeneousModelCombiner input, Mic
_jsonNodes.Add(Serialize("Transforms.TwoHeterogeneousModelCombiner", input, output));
}
+ public Microsoft.ML.Transforms.VectorToImage.Output Add(Microsoft.ML.Transforms.VectorToImage input)
+ {
+ var output = new Microsoft.ML.Transforms.VectorToImage.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.VectorToImage input, Microsoft.ML.Transforms.VectorToImage.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.VectorToImage", input, output));
+ }
+
+ public Microsoft.ML.Transforms.WordEmbeddings.Output Add(Microsoft.ML.Transforms.WordEmbeddings input)
+ {
+ var output = new Microsoft.ML.Transforms.WordEmbeddings.Output();
+ Add(input, output);
+ return output;
+ }
+
+ public void Add(Microsoft.ML.Transforms.WordEmbeddings input, Microsoft.ML.Transforms.WordEmbeddings.Output output)
+ {
+ _jsonNodes.Add(Serialize("Transforms.WordEmbeddings", input, output));
+ }
+
public Microsoft.ML.Transforms.WordTokenizer.Output Add(Microsoft.ML.Transforms.WordTokenizer input)
{
var output = new Microsoft.ML.Transforms.WordTokenizer.Output();
@@ -3080,7 +3164,7 @@ public sealed partial class OneVersusAllMacroSubGraphOutput
}
- ///
+ ///
public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
{
@@ -8601,8 +8685,8 @@ public LogisticRegressionClassifierPipelineStep(Output output)
namespace Trainers
{
- ///
- ///
+ ///
+ ///
public sealed partial class NaiveBayesClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
{
@@ -8824,6 +8908,93 @@ public OnlineGradientDescentRegressorPipelineStep(Output output)
}
}
+ namespace Trainers
+ {
+
+ ///
+ public sealed partial class OrdinaryLeastSquaresRegressor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+
+ ///
+ /// L2 regularization weight
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[]{1E-06f, 0.1f, 1f})]
+ public float L2Weight { get; set; } = 1E-06f;
+
+ ///
+ /// Whether to calculate per parameter significance statistics
+ ///
+ public bool PerParameterSignificance { get; set; } = true;
+
+ ///
+ /// Column to use for example weight
+ ///
+ public Microsoft.ML.Runtime.EntryPoints.Optional WeightColumn { get; set; }
+
+ ///
+ /// Column to use for labels
+ ///
+ public string LabelColumn { get; set; } = "Label";
+
+ ///
+ /// The data to be used for training
+ ///
+ public Var TrainingData { get; set; } = new Var();
+
+ ///
+ /// Column to use for features
+ ///
+ public string FeatureColumn { get; set; } = "Features";
+
+ ///
+ /// Normalize option for the feature column
+ ///
+ public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto;
+
+ ///
+ /// Whether learner should cache input training data
+ ///
+ public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto;
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRegressionOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput
+ {
+ ///
+ /// The trained model
+ ///
+ public Var PredictorModel { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => TrainingData;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(OrdinaryLeastSquaresRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ TrainingData = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new OrdinaryLeastSquaresRegressorPipelineStep(output);
+ }
+
+ private class OrdinaryLeastSquaresRegressorPipelineStep : ILearningPipelinePredictorStep
+ {
+ public OrdinaryLeastSquaresRegressorPipelineStep(Output output)
+ {
+ Model = output.PredictorModel;
+ }
+
+ public Var Model { get; }
+ }
+ }
+ }
+
namespace Trainers
{
@@ -11417,8 +11588,8 @@ public FeatureCombinerPipelineStep(Output output)
namespace Transforms
{
- ///
- ///
+ ///
+ ///
public sealed partial class FeatureSelectorByCount : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
@@ -11486,8 +11657,8 @@ public FeatureSelectorByCountPipelineStep(Output output)
namespace Transforms
{
- ///
- ///
+ ///
+ ///
public sealed partial class FeatureSelectorByMutualInformation : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
@@ -11875,7 +12046,7 @@ public HashConverterPipelineStep(Output output)
namespace Transforms
{
- public sealed partial class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn
+ public sealed partial class ImageGrayscaleTransformColumn : OneToOneColumn, IOneToOneColumn
{
///
/// Name of the new column
@@ -11889,15 +12060,17 @@ public sealed partial class KeyToValueTransformColumn : OneToOneColumn
- public sealed partial class KeyToTextConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ ///
+ /// Convert image into grayscale.
+ ///
+ public sealed partial class ImageGrayscale : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
- public KeyToTextConverter()
+ public ImageGrayscale()
{
}
- public KeyToTextConverter(params string[] inputColumns)
+ public ImageGrayscale(params string[] inputColumns)
{
if (inputColumns != null)
{
@@ -11908,7 +12081,7 @@ public KeyToTextConverter(params string[] inputColumns)
}
}
- public KeyToTextConverter(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ public ImageGrayscale(params (string inputColumn, string outputColumn)[] inputOutputColumns)
{
if (inputOutputColumns != null)
{
@@ -11921,15 +12094,15 @@ public KeyToTextConverter(params (string inputColumn, string outputColumn)[] inp
public void AddColumn(string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
Column = list.ToArray();
}
public void AddColumn(string outputColumn, string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
Column = list.ToArray();
}
@@ -11937,7 +12110,7 @@ public void AddColumn(string outputColumn, string inputColumn)
///
/// New column definition(s) (optional form: name:src)
///
- public KeyToValueTransformColumn[] Column { get; set; }
+ public ImageGrayscaleTransformColumn[] Column { get; set; }
///
/// Input dataset
@@ -11966,18 +12139,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(KeyToTextConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(ImageGrayscale)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new KeyToTextConverterPipelineStep(output);
+ return new ImageGrayscalePipelineStep(output);
}
- private class KeyToTextConverterPipelineStep : ILearningPipelineDataStep
+ private class ImageGrayscalePipelineStep : ILearningPipelineDataStep
{
- public KeyToTextConverterPipelineStep(Output output)
+ public ImageGrayscalePipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -11992,22 +12165,76 @@ public KeyToTextConverterPipelineStep(Output output)
namespace Transforms
{
+ public sealed partial class ImageLoaderTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
///
- /// Transforms the label to either key or bool (if needed) to make it suitable for classification.
+ /// Load images from files.
///
- public sealed partial class LabelColumnKeyBooleanConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ public sealed partial class ImageLoader : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
+ public ImageLoader()
+ {
+ }
+
+ public ImageLoader(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public ImageLoader(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+ public void AddColumn(string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
+ Column = list.ToArray();
+ }
+
+ public void AddColumn(string outputColumn, string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ Column = list.ToArray();
+ }
+
///
- /// Convert the key values to text
+ /// New column definition(s) (optional form: name:src)
///
- public bool TextKeyValues { get; set; } = true;
+ public ImageLoaderTransformColumn[] Column { get; set; }
///
- /// The label column
+ /// Folder where to search for images
///
- public string LabelColumn { get; set; }
+ public string ImageFolder { get; set; }
///
/// Input dataset
@@ -12036,18 +12263,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(LabelColumnKeyBooleanConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(ImageLoader)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new LabelColumnKeyBooleanConverterPipelineStep(output);
+ return new ImageLoaderPipelineStep(output);
}
- private class LabelColumnKeyBooleanConverterPipelineStep : ILearningPipelineDataStep
+ private class ImageLoaderPipelineStep : ILearningPipelineDataStep
{
- public LabelColumnKeyBooleanConverterPipelineStep(Output output)
+ public ImageLoaderPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -12062,12 +12289,47 @@ public LabelColumnKeyBooleanConverterPipelineStep(Output output)
namespace Transforms
{
- public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn
+ public sealed partial class ImagePixelExtractorTransformColumn : OneToOneColumn, IOneToOneColumn
{
///
- /// The positive example class for binary classification.
+ /// Whether to use alpha channel
///
- public int? ClassIndex { get; set; }
+ public bool? UseAlpha { get; set; }
+
+ ///
+ /// Whether to use red channel
+ ///
+ public bool? UseRed { get; set; }
+
+ ///
+ /// Whether to use green channel
+ ///
+ public bool? UseGreen { get; set; }
+
+ ///
+ /// Whether to use blue channel
+ ///
+ public bool? UseBlue { get; set; }
+
+ ///
+ /// Whether to separate each channel or interleave in ARGB order
+ ///
+ public bool? InterleaveArgb { get; set; }
+
+ ///
+ /// Whether to convert to floating point
+ ///
+ public bool? Convert { get; set; }
+
+ ///
+ /// Offset (pre-scale)
+ ///
+ public float? Offset { get; set; }
+
+ ///
+ /// Scale factor
+ ///
+ public float? Scale { get; set; }
///
/// Name of the new column
@@ -12082,16 +12344,16 @@ public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn
- public sealed partial class LabelIndicator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ public sealed partial class ImagePixelExtractor : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
- public LabelIndicator()
+ public ImagePixelExtractor()
{
}
- public LabelIndicator(params string[] inputColumns)
+ public ImagePixelExtractor(params string[] inputColumns)
{
if (inputColumns != null)
{
@@ -12102,7 +12364,7 @@ public LabelIndicator(params string[] inputColumns)
}
}
- public LabelIndicator(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ public ImagePixelExtractor(params (string inputColumn, string outputColumn)[] inputOutputColumns)
{
if (inputOutputColumns != null)
{
@@ -12115,15 +12377,15 @@ public LabelIndicator(params (string inputColumn, string outputColumn)[] inputOu
public void AddColumn(string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
Column = list.ToArray();
}
public void AddColumn(string outputColumn, string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
Column = list.ToArray();
}
@@ -12131,77 +12393,47 @@ public void AddColumn(string outputColumn, string inputColumn)
///
/// New column definition(s) (optional form: name:src)
///
- public LabelIndicatorTransformColumn[] Column { get; set; }
+ public ImagePixelExtractorTransformColumn[] Column { get; set; }
///
- /// Label of the positive class.
+ /// Whether to use alpha channel
///
- public int ClassIndex { get; set; }
+ public bool UseAlpha { get; set; } = false;
///
- /// Input dataset
+ /// Whether to use red channel
///
- public Var Data { get; set; } = new Var();
+ public bool UseRed { get; set; } = true;
+ ///
+ /// Whether to use green channel
+ ///
+ public bool UseGreen { get; set; } = true;
- public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
- {
- ///
- /// Transformed dataset
- ///
- public Var OutputData { get; set; } = new Var();
+ ///
+ /// Whether to use blue channel
+ ///
+ public bool UseBlue { get; set; } = true;
- ///
- /// Transform model
- ///
- public Var Model { get; set; } = new Var();
-
- }
- public Var GetInputData() => Data;
-
- public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
- {
- if (previousStep != null)
- {
- if (!(previousStep is ILearningPipelineDataStep dataStep))
- {
- throw new InvalidOperationException($"{ nameof(LabelIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
- }
-
- Data = dataStep.Data;
- }
- Output output = experiment.Add(this);
- return new LabelIndicatorPipelineStep(output);
- }
-
- private class LabelIndicatorPipelineStep : ILearningPipelineDataStep
- {
- public LabelIndicatorPipelineStep(Output output)
- {
- Data = output.OutputData;
- Model = output.Model;
- }
-
- public Var Data { get; }
- public Var Model { get; }
- }
- }
- }
-
- namespace Transforms
- {
+ ///
+ /// Whether to separate each channel or interleave in ARGB order
+ ///
+ public bool InterleaveArgb { get; set; } = false;
- ///
- /// Transforms the label to float to make it suitable for regression.
- ///
- public sealed partial class LabelToFloatConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
- {
+ ///
+ /// Whether to convert to floating point
+ ///
+ public bool Convert { get; set; } = true;
+ ///
+ /// Offset (pre-scale)
+ ///
+ public float? Offset { get; set; }
///
- /// The label column
+ /// Scale factor
///
- public string LabelColumn { get; set; }
+ public float? Scale { get; set; }
///
/// Input dataset
@@ -12230,18 +12462,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(LabelToFloatConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(ImagePixelExtractor)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new LabelToFloatConverterPipelineStep(output);
+ return new ImagePixelExtractorPipelineStep(output);
}
- private class LabelToFloatConverterPipelineStep : ILearningPipelineDataStep
+ private class ImagePixelExtractorPipelineStep : ILearningPipelineDataStep
{
- public LabelToFloatConverterPipelineStep(Output output)
+ public ImagePixelExtractorPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -12255,63 +12487,43 @@ public LabelToFloatConverterPipelineStep(Output output)
namespace Transforms
{
-
- public sealed partial class LdaTransformColumn : OneToOneColumn, IOneToOneColumn
+ public enum ImageResizerTransformResizingKind : byte
{
- ///
- /// The number of topics in the LDA
- ///
- public int? NumTopic { get; set; }
-
- ///
- /// Dirichlet prior on document-topic vectors
- ///
- public float? AlphaSum { get; set; }
-
- ///
- /// Dirichlet prior on vocab-topic vectors
- ///
- public float? Beta { get; set; }
-
- ///
- /// Number of Metropolis Hasting step
- ///
- public int? Mhstep { get; set; }
-
- ///
- /// Number of iterations
- ///
- public int? NumIterations { get; set; }
+ IsoPad = 0,
+ IsoCrop = 1
+ }
- ///
- /// Compute log likelihood over local dataset on this iteration interval
- ///
- public int? LikelihoodInterval { get; set; }
+ public enum ImageResizerTransformAnchor : byte
+ {
+ Right = 0,
+ Left = 1,
+ Top = 2,
+ Bottom = 3,
+ Center = 4
+ }
- ///
- /// The number of training threads
- ///
- public int? NumThreads { get; set; }
+ public sealed partial class ImageResizerTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
///
- /// The threshold of maximum count of tokens per doc
+ /// Width of the resized image
///
- public int? NumMaxDocToken { get; set; }
+ public int? ImageWidth { get; set; }
///
- /// The number of words to summarize the topic
+ /// Height of the resized image
///
- public int? NumSummaryTermPerTopic { get; set; }
+ public int? ImageHeight { get; set; }
///
- /// The number of burn-in iterations
+ /// Resizing method
///
- public int? NumBurninIterations { get; set; } = 10;
+ public ImageResizerTransformResizingKind? Resizing { get; set; }
///
- /// Reset the random number generator for each document
+ /// Anchor for cropping
///
- public bool? ResetRandomGenerator { get; set; }
+ public ImageResizerTransformAnchor? CropAnchor { get; set; }
///
/// Name of the new column
@@ -12325,16 +12537,17 @@ public sealed partial class LdaTransformColumn : OneToOneColumn
- ///
- public sealed partial class LightLda : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ ///
+ /// Scales an image to specified dimensions using one of the three scale types: isotropic with padding, isotropic with cropping or anisotropic. In case of isotropic padding, transparent color is used to pad resulting image.
+ ///
+ public sealed partial class ImageResizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
- public LightLda()
+ public ImageResizer()
{
}
- public LightLda(params string[] inputColumns)
+ public ImageResizer(params string[] inputColumns)
{
if (inputColumns != null)
{
@@ -12345,7 +12558,7 @@ public LightLda(params string[] inputColumns)
}
}
- public LightLda(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ public ImageResizer(params (string inputColumn, string outputColumn)[] inputOutputColumns)
{
if (inputOutputColumns != null)
{
@@ -12358,89 +12571,43 @@ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputCo
public void AddColumn(string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
Column = list.ToArray();
}
public void AddColumn(string outputColumn, string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
Column = list.ToArray();
}
///
- /// New column definition(s) (optional form: name:srcs)
- ///
- public LdaTransformColumn[] Column { get; set; }
-
- ///
- /// The number of topics in the LDA
- ///
- [TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})]
- public int NumTopic { get; set; } = 100;
-
- ///
- /// Dirichlet prior on document-topic vectors
- ///
- [TlcModule.SweepableDiscreteParamAttribute("AlphaSum", new object[]{1, 10, 100, 200})]
- public float AlphaSum { get; set; } = 100f;
-
- ///
- /// Dirichlet prior on vocab-topic vectors
- ///
- [TlcModule.SweepableDiscreteParamAttribute("Beta", new object[]{0.01f, 0.015f, 0.07f, 0.02f})]
- public float Beta { get; set; } = 0.01f;
-
- ///
- /// Number of Metropolis Hasting step
- ///
- [TlcModule.SweepableDiscreteParamAttribute("Mhstep", new object[]{2, 4, 8, 16})]
- public int Mhstep { get; set; } = 4;
-
- ///
- /// Number of iterations
- ///
- [TlcModule.SweepableDiscreteParamAttribute("NumIterations", new object[]{100, 200, 300, 400})]
- public int NumIterations { get; set; } = 200;
-
- ///
- /// Compute log likelihood over local dataset on this iteration interval
- ///
- public int LikelihoodInterval { get; set; } = 5;
-
- ///
- /// The threshold of maximum count of tokens per doc
- ///
- public int NumMaxDocToken { get; set; } = 512;
-
- ///
- /// The number of training threads. Default value depends on number of logical processors.
+ /// New column definition(s) (optional form: name:src)
///
- public int? NumThreads { get; set; }
+ public ImageResizerTransformColumn[] Column { get; set; }
///
- /// The number of words to summarize the topic
+ /// Resized width of the image
///
- public int NumSummaryTermPerTopic { get; set; } = 10;
+ public int ImageWidth { get; set; }
///
- /// The number of burn-in iterations
+ /// Resized height of the image
///
- [TlcModule.SweepableDiscreteParamAttribute("NumBurninIterations", new object[]{10, 20, 30, 40})]
- public int NumBurninIterations { get; set; } = 10;
+ public int ImageHeight { get; set; }
///
- /// Reset the random number generator for each document
+ /// Resizing method
///
- public bool ResetRandomGenerator { get; set; } = false;
+ public ImageResizerTransformResizingKind Resizing { get; set; } = ImageResizerTransformResizingKind.IsoCrop;
///
- /// Whether to output the topic-word summary in text format
+ /// Anchor for cropping
///
- public bool OutputTopicWordSummary { get; set; } = false;
+ public ImageResizerTransformAnchor CropAnchor { get; set; } = ImageResizerTransformAnchor.Center;
///
/// Input dataset
@@ -12469,18 +12636,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(LightLda)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(ImageResizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new LightLdaPipelineStep(output);
+ return new ImageResizerPipelineStep(output);
}
- private class LightLdaPipelineStep : ILearningPipelineDataStep
+ private class ImageResizerPipelineStep : ILearningPipelineDataStep
{
- public LightLdaPipelineStep(Output output)
+ public ImageResizerPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -12495,13 +12662,8 @@ public LightLdaPipelineStep(Output output)
namespace Transforms
{
- public sealed partial class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn
+ public sealed partial class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn
{
- ///
- /// Max number of examples used to train the normalizer
- ///
- public long? MaxTrainingExamples { get; set; }
-
///
/// Name of the new column
///
@@ -12514,17 +12676,15 @@ public sealed partial class NormalizeTransformLogNormalColumn : OneToOneColumn
- /// Normalizes the data based on the computed mean and variance of the logarithm of the data.
- ///
- public sealed partial class LogMeanVarianceNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ ///
+ public sealed partial class KeyToTextConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
- public LogMeanVarianceNormalizer()
+ public KeyToTextConverter()
{
}
- public LogMeanVarianceNormalizer(params string[] inputColumns)
+ public KeyToTextConverter(params string[] inputColumns)
{
if (inputColumns != null)
{
@@ -12535,7 +12695,7 @@ public LogMeanVarianceNormalizer(params string[] inputColumns)
}
}
- public LogMeanVarianceNormalizer(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ public KeyToTextConverter(params (string inputColumn, string outputColumn)[] inputOutputColumns)
{
if (inputOutputColumns != null)
{
@@ -12548,33 +12708,23 @@ public LogMeanVarianceNormalizer(params (string inputColumn, string outputColumn
public void AddColumn(string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
Column = list.ToArray();
}
public void AddColumn(string outputColumn, string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
Column = list.ToArray();
}
- ///
- /// Whether to use CDF as the output
- ///
- public bool UseCdf { get; set; } = true;
-
///
/// New column definition(s) (optional form: name:src)
///
- public NormalizeTransformLogNormalColumn[] Column { get; set; }
-
- ///
- /// Max number of examples used to train the normalizer
- ///
- public long MaxTrainingExamples { get; set; } = 1000000000;
+ public KeyToValueTransformColumn[] Column { get; set; }
///
/// Input dataset
@@ -12603,18 +12753,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(LogMeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(KeyToTextConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new LogMeanVarianceNormalizerPipelineStep(output);
+ return new KeyToTextConverterPipelineStep(output);
}
- private class LogMeanVarianceNormalizerPipelineStep : ILearningPipelineDataStep
+ private class KeyToTextConverterPipelineStep : ILearningPipelineDataStep
{
- public LogMeanVarianceNormalizerPipelineStep(Output output)
+ public KeyToTextConverterPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -12628,41 +12778,678 @@ public LogMeanVarianceNormalizerPipelineStep(Output output)
namespace Transforms
{
- public enum LpNormNormalizerTransformNormalizerKind : byte
- {
- L2Norm = 0,
- StdDev = 1,
- L1Norm = 2,
- LInf = 3
- }
-
- public sealed partial class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn
+ ///
+ /// Transforms the label to either key or bool (if needed) to make it suitable for classification.
+ ///
+ public sealed partial class LabelColumnKeyBooleanConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
- ///
- /// The norm to use to normalize each sample
- ///
- public LpNormNormalizerTransformNormalizerKind? NormKind { get; set; }
+
///
- /// Subtract mean from each value before normalizing
+ /// Convert the key values to text
///
- public bool? SubMean { get; set; }
+ public bool TextKeyValues { get; set; } = true;
///
- /// Name of the new column
+ /// The label column
///
- public string Name { get; set; }
+ public string LabelColumn { get; set; }
///
- /// Name of the source column
+ /// Input dataset
///
- public string Source { get; set; }
+            public Var<IDataView> Data { get; set; } = new Var<IDataView>();
- }
- ///
- public sealed partial class LpNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+            public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+            {
+                /// <summary>
+                /// Transformed dataset
+                /// </summary>
+                public Var<IDataView> OutputData { get; set; } = new Var<IDataView>();
+
+                /// <summary>
+                /// Transform model
+                /// </summary>
+                public Var<ITransformModel> Model { get; set; } = new Var<ITransformModel>();
+
+            }
+            public Var<IDataView> GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(LabelColumnKeyBooleanConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new LabelColumnKeyBooleanConverterPipelineStep(output);
+ }
+
+ private class LabelColumnKeyBooleanConverterPipelineStep : ILearningPipelineDataStep
+ {
+ public LabelColumnKeyBooleanConverterPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+                public Var<IDataView> Data { get; }
+                public Var<ITransformModel> Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// The positive example class for binary classification.
+ ///
+ public int? ClassIndex { get; set; }
+
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
+ ///
+ /// Label remapper used by OVA
+ ///
+ public sealed partial class LabelIndicator : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+ public LabelIndicator()
+ {
+ }
+
+ public LabelIndicator(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public LabelIndicator(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+            public void AddColumn(string inputColumn)
+            {
+                var list = Column == null ? new List<LabelIndicatorTransformColumn>() : new List<LabelIndicatorTransformColumn>(Column);
+                list.Add(OneToOneColumn<LabelIndicatorTransformColumn>.Create(inputColumn));
+                Column = list.ToArray();
+            }
+
+            public void AddColumn(string outputColumn, string inputColumn)
+            {
+                var list = Column == null ? new List<LabelIndicatorTransformColumn>() : new List<LabelIndicatorTransformColumn>(Column);
+                list.Add(OneToOneColumn<LabelIndicatorTransformColumn>.Create(outputColumn, inputColumn));
+                Column = list.ToArray();
+            }
+
+
+ ///
+ /// New column definition(s) (optional form: name:src)
+ ///
+ public LabelIndicatorTransformColumn[] Column { get; set; }
+
+ ///
+ /// Label of the positive class.
+ ///
+ public int ClassIndex { get; set; }
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(LabelIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new LabelIndicatorPipelineStep(output);
+ }
+
+ private class LabelIndicatorPipelineStep : ILearningPipelineDataStep
+ {
+ public LabelIndicatorPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ ///
+ /// Transforms the label to float to make it suitable for regression.
+ ///
+ public sealed partial class LabelToFloatConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+
+ ///
+ /// The label column
+ ///
+ public string LabelColumn { get; set; }
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(LabelToFloatConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new LabelToFloatConverterPipelineStep(output);
+ }
+
+ private class LabelToFloatConverterPipelineStep : ILearningPipelineDataStep
+ {
+ public LabelToFloatConverterPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ public sealed partial class LdaTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// The number of topics in the LDA
+ ///
+ public int? NumTopic { get; set; }
+
+ ///
+ /// Dirichlet prior on document-topic vectors
+ ///
+ public float? AlphaSum { get; set; }
+
+ ///
+ /// Dirichlet prior on vocab-topic vectors
+ ///
+ public float? Beta { get; set; }
+
+ ///
+ /// Number of Metropolis Hasting step
+ ///
+ public int? Mhstep { get; set; }
+
+ ///
+ /// Number of iterations
+ ///
+ public int? NumIterations { get; set; }
+
+ ///
+ /// Compute log likelihood over local dataset on this iteration interval
+ ///
+ public int? LikelihoodInterval { get; set; }
+
+ ///
+ /// The number of training threads
+ ///
+ public int? NumThreads { get; set; }
+
+ ///
+ /// The threshold of maximum count of tokens per doc
+ ///
+ public int? NumMaxDocToken { get; set; }
+
+ ///
+ /// The number of words to summarize the topic
+ ///
+ public int? NumSummaryTermPerTopic { get; set; }
+
+ ///
+ /// The number of burn-in iterations
+ ///
+ public int? NumBurninIterations { get; set; } = 10;
+
+ ///
+ /// Reset the random number generator for each document
+ ///
+ public bool? ResetRandomGenerator { get; set; }
+
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
+ ///
+ ///
+ public sealed partial class LightLda : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+ public LightLda()
+ {
+ }
+
+ public LightLda(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public LightLda(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+ public void AddColumn(string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
+ Column = list.ToArray();
+ }
+
+ public void AddColumn(string outputColumn, string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ Column = list.ToArray();
+ }
+
+
+ ///
+ /// New column definition(s) (optional form: name:srcs)
+ ///
+ public LdaTransformColumn[] Column { get; set; }
+
+ ///
+ /// The number of topics in the LDA
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("NumTopic", new object[]{20, 40, 100, 200})]
+ public int NumTopic { get; set; } = 100;
+
+ ///
+ /// Dirichlet prior on document-topic vectors
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("AlphaSum", new object[]{1, 10, 100, 200})]
+ public float AlphaSum { get; set; } = 100f;
+
+ ///
+ /// Dirichlet prior on vocab-topic vectors
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("Beta", new object[]{0.01f, 0.015f, 0.07f, 0.02f})]
+ public float Beta { get; set; } = 0.01f;
+
+ ///
+ /// Number of Metropolis Hasting step
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("Mhstep", new object[]{2, 4, 8, 16})]
+ public int Mhstep { get; set; } = 4;
+
+ ///
+ /// Number of iterations
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("NumIterations", new object[]{100, 200, 300, 400})]
+ public int NumIterations { get; set; } = 200;
+
+ ///
+ /// Compute log likelihood over local dataset on this iteration interval
+ ///
+ public int LikelihoodInterval { get; set; } = 5;
+
+ ///
+ /// The threshold of maximum count of tokens per doc
+ ///
+ public int NumMaxDocToken { get; set; } = 512;
+
+ ///
+ /// The number of training threads. Default value depends on number of logical processors.
+ ///
+ public int? NumThreads { get; set; }
+
+ ///
+ /// The number of words to summarize the topic
+ ///
+ public int NumSummaryTermPerTopic { get; set; } = 10;
+
+ ///
+ /// The number of burn-in iterations
+ ///
+ [TlcModule.SweepableDiscreteParamAttribute("NumBurninIterations", new object[]{10, 20, 30, 40})]
+ public int NumBurninIterations { get; set; } = 10;
+
+ ///
+ /// Reset the random number generator for each document
+ ///
+ public bool ResetRandomGenerator { get; set; } = false;
+
+ ///
+ /// Whether to output the topic-word summary in text format
+ ///
+ public bool OutputTopicWordSummary { get; set; } = false;
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(LightLda)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new LightLdaPipelineStep(output);
+ }
+
+ private class LightLdaPipelineStep : ILearningPipelineDataStep
+ {
+ public LightLdaPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ public sealed partial class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// Max number of examples used to train the normalizer
+ ///
+ public long? MaxTrainingExamples { get; set; }
+
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
+ ///
+ /// Normalizes the data based on the computed mean and variance of the logarithm of the data.
+ ///
+ public sealed partial class LogMeanVarianceNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+ public LogMeanVarianceNormalizer()
+ {
+ }
+
+ public LogMeanVarianceNormalizer(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public LogMeanVarianceNormalizer(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+ public void AddColumn(string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
+ Column = list.ToArray();
+ }
+
+ public void AddColumn(string outputColumn, string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ Column = list.ToArray();
+ }
+
+
+ ///
+ /// Whether to use CDF as the output
+ ///
+ public bool UseCdf { get; set; } = true;
+
+ ///
+ /// New column definition(s) (optional form: name:src)
+ ///
+ public NormalizeTransformLogNormalColumn[] Column { get; set; }
+
+ ///
+ /// Max number of examples used to train the normalizer
+ ///
+ public long MaxTrainingExamples { get; set; } = 1000000000;
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(LogMeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new LogMeanVarianceNormalizerPipelineStep(output);
+ }
+
+ private class LogMeanVarianceNormalizerPipelineStep : ILearningPipelineDataStep
+ {
+ public LogMeanVarianceNormalizerPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+ public enum LpNormNormalizerTransformNormalizerKind : byte
+ {
+ L2Norm = 0,
+ StdDev = 1,
+ L1Norm = 2,
+ LInf = 3
+ }
+
+
+ public sealed partial class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// The norm to use to normalize each sample
+ ///
+ public LpNormNormalizerTransformNormalizerKind? NormKind { get; set; }
+
+ ///
+ /// Subtract mean from each value before normalizing
+ ///
+ public bool? SubMean { get; set; }
+
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
+ ///
+ public sealed partial class LpNormalizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
{
public LpNormalizer()
@@ -15088,27 +15875,266 @@ public void AddColumn(string name, params string[] source)
///
public bool OutputTokens { get; set; } = false;
- ///
- /// A dictionary of whitelisted terms.
- ///
- public TermLoaderArguments Dictionary { get; set; }
+ ///
+ /// A dictionary of whitelisted terms.
+ ///
+ public TermLoaderArguments Dictionary { get; set; }
+
+ ///
+ /// Ngram feature extractor to use for words (WordBag/WordHashBag).
+ ///
+ [JsonConverter(typeof(ComponentSerializer))]
+ public NgramExtractor WordFeatureExtractor { get; set; } = new NGramNgramExtractor();
+
+ ///
+ /// Ngram feature extractor to use for characters (WordBag/WordHashBag).
+ ///
+ [JsonConverter(typeof(ComponentSerializer))]
+ public NgramExtractor CharFeatureExtractor { get; set; } = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false };
+
+ ///
+ /// Normalize vectors (rows) individually by rescaling them to unit norm.
+ ///
+ public TextTransformTextNormKind VectorNormalizer { get; set; } = TextTransformTextNormKind.L2;
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new TextFeaturizerPipelineStep(output);
+ }
+
+ private class TextFeaturizerPipelineStep : ILearningPipelineDataStep
+ {
+ public TextFeaturizerPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ ///
+ ///
+ public sealed partial class TextToKeyConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+ public TextToKeyConverter()
+ {
+ }
+
+ public TextToKeyConverter(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public TextToKeyConverter(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+ public void AddColumn(string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
+ Column = list.ToArray();
+ }
+
+ public void AddColumn(string outputColumn, string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ Column = list.ToArray();
+ }
+
+
+ ///
+ /// New column definition(s) (optional form: name:src)
+ ///
+ public TermTransformColumn[] Column { get; set; }
+
+ ///
+ /// Maximum number of terms to keep per column when auto-training
+ ///
+ public int MaxNumTerms { get; set; } = 1000000;
+
+ ///
+ /// List of terms
+ ///
+ public string[] Term { get; set; }
+
+ ///
+ /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').
+ ///
+ public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence;
+
+ ///
+ /// Whether key value metadata should be text, regardless of the actual input type
+ ///
+ public bool TextKeyValues { get; set; } = false;
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+
+ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput
+ {
+ ///
+ /// Transformed dataset
+ ///
+ public Var OutputData { get; set; } = new Var();
+
+ ///
+ /// Transform model
+ ///
+ public Var Model { get; set; } = new Var();
+
+ }
+ public Var GetInputData() => Data;
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(TextToKeyConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ Data = dataStep.Data;
+ }
+ Output output = experiment.Add(this);
+ return new TextToKeyConverterPipelineStep(output);
+ }
+
+ private class TextToKeyConverterPipelineStep : ILearningPipelineDataStep
+ {
+ public TextToKeyConverterPipelineStep(Output output)
+ {
+ Data = output.OutputData;
+ Model = output.Model;
+ }
+
+ public Var Data { get; }
+ public Var Model { get; }
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ ///
+ /// Split the dataset into train and test sets
+ ///
+ public sealed partial class TrainTestDatasetSplitter
+ {
+
+
+ ///
+ /// Input dataset
+ ///
+ public Var Data { get; set; } = new Var();
+
+ ///
+ /// Fraction of training data
+ ///
+ public float Fraction { get; set; } = 0.8f;
+
+ ///
+ /// Stratification column
+ ///
+ public string StratificationColumn { get; set; }
+
+
+ public sealed class Output
+ {
+ ///
+ /// Training data
+ ///
+ public Var TrainData { get; set; } = new Var();
+
+ ///
+ /// Testing data
+ ///
+ public Var TestData { get; set; } = new Var();
+
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ ///
+ public sealed partial class TreeLeafFeaturizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IFeaturizerInput, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
///
- /// Ngram feature extractor to use for words (WordBag/WordHashBag).
+ /// Output column: The suffix to append to the default column names
///
- [JsonConverter(typeof(ComponentSerializer))]
- public NgramExtractor WordFeatureExtractor { get; set; } = new NGramNgramExtractor();
+ public string Suffix { get; set; }
///
- /// Ngram feature extractor to use for characters (WordBag/WordHashBag).
+ /// If specified, determines the permutation seed for applying this featurizer to a multiclass problem.
///
- [JsonConverter(typeof(ComponentSerializer))]
- public NgramExtractor CharFeatureExtractor { get; set; } = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false };
+ public int LabelPermutationSeed { get; set; }
///
- /// Normalize vectors (rows) individually by rescaling them to unit norm.
+ /// Trainer to use
///
- public TextTransformTextNormKind VectorNormalizer { get; set; } = TextTransformTextNormKind.L2;
+ public Var<IPredictorModel> PredictorModel { get; set; } = new Var<IPredictorModel>();
///
/// Input dataset
@@ -15137,18 +16163,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(TreeLeafFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new TextFeaturizerPipelineStep(output);
+ return new TreeLeafFeaturizerPipelineStep(output);
}
- private class TextFeaturizerPipelineStep : ILearningPipelineDataStep
+ private class TreeLeafFeaturizerPipelineStep : ILearningPipelineDataStep
{
- public TextFeaturizerPipelineStep(Output output)
+ public TreeLeafFeaturizerPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -15163,16 +16189,108 @@ public TextFeaturizerPipelineStep(Output output)
namespace Transforms
{
- ///
- ///
- public sealed partial class TextToKeyConverter : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ ///
+ /// Combines a TransformModel and a PredictorModel into a single PredictorModel.
+ ///
+ public sealed partial class TwoHeterogeneousModelCombiner
{
- public TextToKeyConverter()
+
+ ///
+ /// Transform model
+ ///
+        public Var<ITransformModel> TransformModel { get; set; } = new Var<ITransformModel>();
+
+        /// <summary>
+        /// Predictor model
+        /// </summary>
+        public Var<IPredictorModel> PredictorModel { get; set; } = new Var<IPredictorModel>();
+
+
+ public sealed class Output
+ {
+ ///
+ /// Predictor model
+ ///
+ public Var PredictorModel { get; set; } = new Var();
+
+ }
+ }
+ }
+
+ namespace Transforms
+ {
+
+ public sealed partial class VectorToImageTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
+ ///
+ /// Whether to use alpha channel
+ ///
+ public bool? ContainsAlpha { get; set; }
+
+ ///
+ /// Whether to use red channel
+ ///
+ public bool? ContainsRed { get; set; }
+
+ ///
+ /// Whether to use green channel
+ ///
+ public bool? ContainsGreen { get; set; }
+
+ ///
+ /// Whether to use blue channel
+ ///
+ public bool? ContainsBlue { get; set; }
+
+ ///
+ /// Whether to separate each channel or interleave in ARGB order
+ ///
+ public bool? InterleaveArgb { get; set; }
+
+ ///
+ /// Width of the image
+ ///
+ public int? ImageWidth { get; set; }
+
+ ///
+ /// Height of the image
+ ///
+ public int? ImageHeight { get; set; }
+
+ ///
+ /// Offset (pre-scale)
+ ///
+ public float? Offset { get; set; }
+
+ ///
+ /// Scale factor
+ ///
+ public float? Scale { get; set; }
+
+ ///
+ /// Name of the new column
+ ///
+ public string Name { get; set; }
+
+ ///
+ /// Name of the source column
+ ///
+ public string Source { get; set; }
+
+ }
+
+ ///
+ /// Converts vector array into image type.
+ ///
+ public sealed partial class VectorToImage : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
+
+ public VectorToImage()
{
}
- public TextToKeyConverter(params string[] inputColumns)
+ public VectorToImage(params string[] inputColumns)
{
if (inputColumns != null)
{
@@ -15183,7 +16301,7 @@ public TextToKeyConverter(params string[] inputColumns)
}
}
- public TextToKeyConverter(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ public VectorToImage(params (string inputColumn, string outputColumn)[] inputOutputColumns)
{
if (inputOutputColumns != null)
{
@@ -15196,15 +16314,15 @@ public TextToKeyConverter(params (string inputColumn, string outputColumn)[] inp
public void AddColumn(string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
Column = list.ToArray();
}
public void AddColumn(string outputColumn, string inputColumn)
{
- var list = Column == null ? new List() : new List(Column);
- list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
Column = list.ToArray();
}
@@ -15212,27 +16330,52 @@ public void AddColumn(string outputColumn, string inputColumn)
///
/// New column definition(s) (optional form: name:src)
///
- public TermTransformColumn[] Column { get; set; }
+ public VectorToImageTransformColumn[] Column { get; set; }
///
- /// Maximum number of terms to keep per column when auto-training
+ /// Whether to use alpha channel
///
- public int MaxNumTerms { get; set; } = 1000000;
+ public bool ContainsAlpha { get; set; } = false;
///
- /// List of terms
+ /// Whether to use red channel
///
- public string[] Term { get; set; }
+ public bool ContainsRed { get; set; } = true;
///
- /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').
+ /// Whether to use green channel
///
- public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence;
+ public bool ContainsGreen { get; set; } = true;
///
- /// Whether key value metadata should be text, regardless of the actual input type
+ /// Whether to use blue channel
///
- public bool TextKeyValues { get; set; } = false;
+ public bool ContainsBlue { get; set; } = true;
+
+ ///
+ /// Whether to separate each channel or interleave in ARGB order
+ ///
+ public bool InterleaveArgb { get; set; } = false;
+
+ ///
+ /// Width of the image
+ ///
+ public int ImageWidth { get; set; }
+
+ ///
+ /// Height of the image
+ ///
+ public int ImageHeight { get; set; }
+
+ ///
+ /// Offset (pre-scale)
+ ///
+ public float? Offset { get; set; }
+
+ ///
+ /// Scale factor
+ ///
+ public float? Scale { get; set; }
///
/// Input dataset
@@ -15261,18 +16404,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(TextToKeyConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(VectorToImage)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new TextToKeyConverterPipelineStep(output);
+ return new VectorToImagePipelineStep(output);
}
- private class TextToKeyConverterPipelineStep : ILearningPipelineDataStep
+ private class VectorToImagePipelineStep : ILearningPipelineDataStep
{
- public TextToKeyConverterPipelineStep(Output output)
+ public VectorToImagePipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -15286,68 +16429,95 @@ public TextToKeyConverterPipelineStep(Output output)
namespace Transforms
{
-
- ///
- /// Split the dataset into train and test sets
- ///
- public sealed partial class TrainTestDatasetSplitter
+ public enum WordEmbeddingsTransformPretrainedModelKind
{
+ GloVe50D = 0,
+ GloVe100D = 1,
+ GloVe200D = 2,
+ GloVe300D = 3,
+ GloVeTwitter25D = 4,
+ GloVeTwitter50D = 5,
+ GloVeTwitter100D = 6,
+ GloVeTwitter200D = 7,
+ FastTextWikipedia300D = 8,
+ Sswe = 9
+ }
+ public sealed partial class WordEmbeddingsTransformColumn : OneToOneColumn, IOneToOneColumn
+ {
///
- /// Input dataset
+ /// Name of the new column
///
- public Var Data { get; set; } = new Var();
+ public string Name { get; set; }
///
- /// Fraction of training data
+ /// Name of the source column
///
- public float Fraction { get; set; } = 0.8f;
+ public string Source { get; set; }
- ///
- /// Stratification column
- ///
- public string StratificationColumn { get; set; }
+ }
+ ///
+ ///
+ public sealed partial class WordEmbeddings : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
+ {
- public sealed class Output
+ public WordEmbeddings()
{
- ///
- /// Training data
- ///
- public Var TrainData { get; set; } = new Var();
-
- ///
- /// Testing data
- ///
- public Var TestData { get; set; } = new Var();
-
}
- }
- }
-
- namespace Transforms
- {
+
+ public WordEmbeddings(params string[] inputColumns)
+ {
+ if (inputColumns != null)
+ {
+ foreach (string input in inputColumns)
+ {
+ AddColumn(input);
+ }
+ }
+ }
+
+ public WordEmbeddings(params (string inputColumn, string outputColumn)[] inputOutputColumns)
+ {
+ if (inputOutputColumns != null)
+ {
+ foreach (var inputOutput in inputOutputColumns)
+ {
+ AddColumn(inputOutput.outputColumn, inputOutput.inputColumn);
+ }
+ }
+ }
+
+ public void AddColumn(string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(inputColumn));
+ Column = list.ToArray();
+ }
- ///
- public sealed partial class TreeLeafFeaturizer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.IFeaturizerInput, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem
- {
+ public void AddColumn(string outputColumn, string inputColumn)
+ {
+ var list = Column == null ? new List() : new List(Column);
+ list.Add(OneToOneColumn.Create(outputColumn, inputColumn));
+ Column = list.ToArray();
+ }
///
- /// Output column: The suffix to append to the default column names
+ /// New column definition(s) (optional form: name:src)
///
- public string Suffix { get; set; }
+ public WordEmbeddingsTransformColumn[] Column { get; set; }
///
- /// If specified, determines the permutation seed for applying this featurizer to a multiclass problem.
+ /// Pre-trained model used to create the vocabulary
///
- public int LabelPermutationSeed { get; set; }
+ public WordEmbeddingsTransformPretrainedModelKind? ModelKind { get; set; } = WordEmbeddingsTransformPretrainedModelKind.Sswe;
///
- /// Trainer to use
+ /// Filename for custom word embedding model
///
- public Var PredictorModel { get; set; } = new Var();
+ public string CustomLookupTable { get; set; }
///
/// Input dataset
@@ -15376,18 +16546,18 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper
{
if (!(previousStep is ILearningPipelineDataStep dataStep))
{
- throw new InvalidOperationException($"{ nameof(TreeLeafFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ throw new InvalidOperationException($"{ nameof(WordEmbeddings)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
}
Data = dataStep.Data;
}
Output output = experiment.Add(this);
- return new TreeLeafFeaturizerPipelineStep(output);
+ return new WordEmbeddingsPipelineStep(output);
}
- private class TreeLeafFeaturizerPipelineStep : ILearningPipelineDataStep
+ private class WordEmbeddingsPipelineStep : ILearningPipelineDataStep
{
- public TreeLeafFeaturizerPipelineStep(Output output)
+ public WordEmbeddingsPipelineStep(Output output)
{
Data = output.OutputData;
Model = output.Model;
@@ -15399,38 +16569,6 @@ public TreeLeafFeaturizerPipelineStep(Output output)
}
}
- namespace Transforms
- {
-
- ///
- /// Combines a TransformModel and a PredictorModel into a single PredictorModel.
- ///
- public sealed partial class TwoHeterogeneousModelCombiner
- {
-
-
- ///
- /// Transform model
- ///
- public Var TransformModel { get; set; } = new Var();
-
- ///
- /// Predictor model
- ///
- public Var PredictorModel { get; set; } = new Var();
-
-
- public sealed class Output
- {
- ///
- /// Predictor model
- ///
- public Var PredictorModel { get; set; } = new Var();
-
- }
- }
- }
-
namespace Transforms
{
diff --git a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
index 05688cd2af..3da05f1fbf 100644
--- a/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
+++ b/src/Microsoft.ML/Runtime/EntryPoints/OneVersusAllMacro.cs
@@ -136,7 +136,7 @@ private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out
[TlcModule.EntryPoint(Desc = "One-vs-All macro (OVA)",
Name = "Models.OneVersusAll",
- XmlInclude = new[] { @"" })]
+ XmlInclude = new[] { @"" })]
public static CommonOutputs.MacroOutput