diff --git a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs index 0612135ce3..2ce8da42ab 100644 --- a/src/Microsoft.ML.LightGBM/LightGbmArguments.cs +++ b/src/Microsoft.ML.LightGBM/LightGbmArguments.cs @@ -11,6 +11,11 @@ using Microsoft.ML.Runtime.Internal.Internallearn; using Microsoft.ML.Runtime.LightGBM; +[assembly: LoadableClass(typeof(LightGbmArguments.CpuExecutionDevice), typeof(LightGbmArguments.CpuExecutionDevice.Arguments), + typeof(SignatureLightGBMExecutionDevice), LightGbmArguments.CpuExecutionDevice.FriendlyName, LightGbmArguments.CpuExecutionDevice.Name)] +[assembly: LoadableClass(typeof(LightGbmArguments.GpuExecutionDevice), typeof(LightGbmArguments.GpuExecutionDevice.Arguments), + typeof(SignatureLightGBMExecutionDevice), LightGbmArguments.GpuExecutionDevice.FriendlyName, LightGbmArguments.GpuExecutionDevice.Name)] + [assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments), typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)] [assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments), @@ -18,14 +23,28 @@ [assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments), typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)] +[assembly: EntryPointModule(typeof(LightGbmArguments.CpuExecutionDevice.Arguments))] +[assembly: EntryPointModule(typeof(LightGbmArguments.GpuExecutionDevice.Arguments))] + [assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))] [assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))] [assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))] namespace Microsoft.ML.Runtime.LightGBM { + public delegate void SignatureLightGBMExecutionDevice(); + public delegate void SignatureLightGBMBooster(); + [TlcModule.ComponentKind("LightGbmExecutionDevice")] + public interface ISupportExecutionDeviceParameterFactory : IComponentFactory + { + } + public interface IExecutionDeviceParameter + { + void UpdateParameters(Dictionary res); + } + [TlcModule.ComponentKind("BoosterParameterFunction")] public interface ISupportBoosterParameterFactory : IComponentFactory { @@ -62,6 +81,22 @@ public virtual void UpdateParameters(Dictionary res) } } + public abstract class ExecutionDeviceParameter : IExecutionDeviceParameter + where TArgs : class, new() + { + protected TArgs Args { get; } + + protected ExecutionDeviceParameter(TArgs args) + { + Args = args; + } + + /// + /// Update the parameters by specific ExecutionDevice, will update parameters into "res" directly. + /// + public abstract void UpdateParameters(Dictionary res); + } + private static string GetArgName(string name) { StringBuilder strBuf = new StringBuilder(); @@ -82,6 +117,65 @@ private static string GetArgName(string name) return strBuf.ToString(); } + + public sealed class CpuExecutionDevice : ExecutionDeviceParameter + { + public const string Name = "cpu_device"; + public const string FriendlyName = "CPU Device"; + + [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "LightGBM CPU device.")] + public class Arguments : ISupportExecutionDeviceParameterFactory + { + public virtual IExecutionDeviceParameter CreateComponent(IHostEnvironment env) => new CpuExecutionDevice(this); + } + + public CpuExecutionDevice(Arguments args) + : base(args) + { + } + + public override void UpdateParameters(Dictionary res) + { + return; + } + } + + public sealed class GpuExecutionDevice : ExecutionDeviceParameter + { + public const string Name = "gpu_device"; + public const string FriendlyName = "GPU Device"; + + [TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "LightGBM GPU device.")] + public class Arguments : ISupportExecutionDeviceParameterFactory + { + [Argument(ArgumentType.AtMostOnce, HelpText = "OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform. -1 means the system-wide default platform.", ShortName = "gpu_platform_id")] + public int PlatformId = -1; + + [Argument(ArgumentType.AtMostOnce, HelpText = "OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID. -1 means the default device in the selected platform.", ShortName = "gpu_device_id")] + public int DeviceId = -1; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Use double precision math on GPU?", ShortName = "gpu_use_dp")] + public bool UseDoublePrecision = false; + + public virtual IExecutionDeviceParameter CreateComponent(IHostEnvironment env) => new GpuExecutionDevice(this); + } + + public GpuExecutionDevice(Arguments args) + : base(args) + { + Contracts.CheckUserArg(Args.PlatformId >= -1, nameof(Args.PlatformId), "must be >= -1."); + Contracts.CheckUserArg(Args.DeviceId >= -1, nameof(Args.DeviceId), "must be >= -1."); + } + + public override void UpdateParameters(Dictionary res) + { + res["device"] = "gpu"; + if (Args.PlatformId != -1) res["gpu_platform_id"] = Args.PlatformId.ToString(); + if (Args.DeviceId != -1) res["gpu_device_id"] = Args.DeviceId.ToString(); + if (Args.UseDoublePrecision) res["gpu_use_dp"] = Args.UseDoublePrecision.ToString(); + } + } + public sealed class TreeBooster : BoosterParameter { public const string Name = "gbdt"; @@ -355,6 +449,9 @@ public enum EvalMetricType [Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")] public ISupportParallel ParallelTrainer = new SingleTrainerFactory(); + [Argument(ArgumentType.Multiple, HelpText = "Which execution device to use, can be cpu_device or gpu_device. Note: GPU device requires compatible build of LightGBM dll.", SortOrder = 3, ShortName = "device")] + public ISupportExecutionDeviceParameterFactory ExecutionDevice = new CpuExecutionDevice.Arguments(); + internal Dictionary ToDictionary(IHost host) { Contracts.CheckValue(host, nameof(host)); @@ -408,6 +505,10 @@ internal Dictionary ToDictionary(IHost host) res[GetArgName(nameof(MaxCatThreshold))] = MaxCatThreshold.ToString(); res[GetArgName(nameof(CatSmooth))] = CatSmooth.ToString(); res[GetArgName(nameof(CatL2))] = CatL2.ToString(); + + var executionDeviceParams = ExecutionDevice.CreateComponent(host); + executionDeviceParams.UpdateParameters(res); + return res; } } diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index a0f7f17975..7151452d01 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -7599,6 +7599,14 @@ public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.Entr [TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})] public double CatL2 { get; set; } = 10d; + /// + /// Execution device for training (CPU or GPU). + /// NOTE: GPU training requires compatible build of LightGBM as described here: + /// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version + /// + [JsonConverter(typeof(ComponentSerializer))] + public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice(); + /// /// Parallel LightGBM Learning Algorithm /// @@ -7806,6 +7814,14 @@ public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoint [TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})] public double CatL2 { get; set; } = 10d; + /// + /// Execution device for training (CPU or GPU). + /// NOTE: GPU training requires compatible build of LightGBM as described here: + /// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version + /// + [JsonConverter(typeof(ComponentSerializer))] + public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice(); + /// /// Parallel LightGBM Learning Algorithm /// @@ -8013,6 +8029,14 @@ public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.Co [TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})] public double CatL2 { get; set; } = 10d; + /// + /// Execution device for training (CPU or GPU). + /// NOTE: GPU training requires compatible build of LightGBM as described here: + /// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version + /// + [JsonConverter(typeof(ComponentSerializer))] + public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice(); + /// /// Parallel LightGBM Learning Algorithm /// @@ -8220,6 +8244,14 @@ public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints [TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})] public double CatL2 { get; set; } = 10d; + /// + /// Execution device for training (CPU or GPU). + /// NOTE: GPU training requires compatible build of LightGBM as described here: + /// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version + /// + [JsonConverter(typeof(ComponentSerializer))] + public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice(); + /// /// Parallel LightGBM Learning Algorithm /// @@ -15985,8 +16017,42 @@ public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase internal override string ComponentName => "AutoMlState"; } - public abstract class BoosterParameterFunction : ComponentKind {} + public abstract class LightGbmExecutionDevice : ComponentKind { } + /// + /// LightGBM CPU execution device + /// + public sealed class LightGbmCpuExecutionDevice : LightGbmExecutionDevice + { + internal override string ComponentName => "cpu_device"; + } + + /// + /// LightGBM GPU execution device. + /// NOTE: Requires build of LightGBM that supports GPU as described here: + /// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version + /// + public sealed class LightGbmGpuExecutionDevice : LightGbmExecutionDevice + { + /// + /// OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform. -1 means the system-wide default platform. + /// + public int PlatformId { get; set; } = -1; + + /// + /// OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID. -1 means the default device in the selected platform. + /// + public int DeviceId { get; set; } = -1; + + /// + /// Use double precision math on GPU? + /// + public bool UseDoublePrecision { get; set; } = false; + + internal override string ComponentName => "gpu_device"; + } + + public abstract class BoosterParameterFunction : ComponentKind {} ///