Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/Microsoft.ML.LightGBM/LightGbmArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,40 @@
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.LightGBM;

[assembly: LoadableClass(typeof(LightGbmArguments.CpuExecutionDevice), typeof(LightGbmArguments.CpuExecutionDevice.Arguments),
typeof(SignatureLightGBMExecutionDevice), LightGbmArguments.CpuExecutionDevice.FriendlyName, LightGbmArguments.CpuExecutionDevice.Name)]
[assembly: LoadableClass(typeof(LightGbmArguments.GpuExecutionDevice), typeof(LightGbmArguments.GpuExecutionDevice.Arguments),
typeof(SignatureLightGBMExecutionDevice), LightGbmArguments.GpuExecutionDevice.FriendlyName, LightGbmArguments.GpuExecutionDevice.Name)]

[assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)]
[assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.DartBooster.FriendlyName, LightGbmArguments.DartBooster.Name)]
[assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments),
typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)]

[assembly: EntryPointModule(typeof(LightGbmArguments.CpuExecutionDevice.Arguments))]
[assembly: EntryPointModule(typeof(LightGbmArguments.GpuExecutionDevice.Arguments))]

[assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))]
[assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))]
[assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))]

namespace Microsoft.ML.Runtime.LightGBM
{
public delegate void SignatureLightGBMExecutionDevice();

public delegate void SignatureLightGBMBooster();

[TlcModule.ComponentKind("LightGbmExecutionDevice")]
public interface ISupportExecutionDeviceParameterFactory : IComponentFactory<IExecutionDeviceParameter>
{
}
public interface IExecutionDeviceParameter
{
void UpdateParameters(Dictionary<string, object> res);
}

[TlcModule.ComponentKind("BoosterParameterFunction")]
public interface ISupportBoosterParameterFactory : IComponentFactory<IBoosterParameter>
{
Expand Down Expand Up @@ -62,6 +81,22 @@ public virtual void UpdateParameters(Dictionary<string, object> res)
}
}

public abstract class ExecutionDeviceParameter<TArgs> : IExecutionDeviceParameter
where TArgs : class, new()
{
protected TArgs Args { get; }

protected ExecutionDeviceParameter(TArgs args)
{
Args = args;
}

/// <summary>
/// Update the parameters by specific ExecutionDevice, will update parameters into "res" directly.
/// </summary>
public abstract void UpdateParameters(Dictionary<string, object> res);
}

private static string GetArgName(string name)
{
StringBuilder strBuf = new StringBuilder();
Expand All @@ -82,6 +117,65 @@ private static string GetArgName(string name)
return strBuf.ToString();
}


public sealed class CpuExecutionDevice : ExecutionDeviceParameter<CpuExecutionDevice.Arguments>
{
public const string Name = "cpu_device";
public const string FriendlyName = "CPU Device";

[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "LightGBM CPU device.")]
public class Arguments : ISupportExecutionDeviceParameterFactory
{
public virtual IExecutionDeviceParameter CreateComponent(IHostEnvironment env) => new CpuExecutionDevice(this);
}

public CpuExecutionDevice(Arguments args)
: base(args)
{
}

public override void UpdateParameters(Dictionary<string, object> res)
{
return;
}
}

public sealed class GpuExecutionDevice : ExecutionDeviceParameter<GpuExecutionDevice.Arguments>
{
public const string Name = "gpu_device";
public const string FriendlyName = "GPU Device";

[TlcModule.Component(Name = Name, FriendlyName = FriendlyName, Desc = "LightGBM GPU device.")]
public class Arguments : ISupportExecutionDeviceParameterFactory
{
[Argument(ArgumentType.AtMostOnce, HelpText = "OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform. -1 means the system-wide default platform.", ShortName = "gpu_platform_id")]
public int PlatformId = -1;

[Argument(ArgumentType.AtMostOnce, HelpText = "OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID. -1 means the default device in the selected platform.", ShortName = "gpu_device_id")]
public int DeviceId = -1;

[Argument(ArgumentType.AtMostOnce, HelpText = "Use double precision math on GPU?", ShortName = "gpu_use_dp")]
public bool UseDoublePrecision = false;

public virtual IExecutionDeviceParameter CreateComponent(IHostEnvironment env) => new GpuExecutionDevice(this);
}

public GpuExecutionDevice(Arguments args)
: base(args)
{
Contracts.CheckUserArg(Args.PlatformId >= -1, nameof(Args.PlatformId), "must be >= -1.");
Contracts.CheckUserArg(Args.DeviceId >= -1, nameof(Args.DeviceId), "must be >= -1.");
}

public override void UpdateParameters(Dictionary<string, object> res)
{
res["device"] = "gpu";
if (Args.PlatformId != -1) res["gpu_platform_id"] = Args.PlatformId.ToString();
if (Args.DeviceId != -1) res["gpu_device_id"] = Args.DeviceId.ToString();
if (Args.UseDoublePrecision) res["gpu_use_dp"] = Args.UseDoublePrecision.ToString();
}
}

public sealed class TreeBooster : BoosterParameter<TreeBooster.Arguments>
{
public const string Name = "gbdt";
Expand Down Expand Up @@ -355,6 +449,9 @@ public enum EvalMetricType
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
public ISupportParallel ParallelTrainer = new SingleTrainerFactory();

[Argument(ArgumentType.Multiple, HelpText = "Which execution device to use, can be cpu_device or gpu_device. Note: GPU device requires compatible build of LightGBM dll.", SortOrder = 3, ShortName = "device")]
public ISupportExecutionDeviceParameterFactory ExecutionDevice = new CpuExecutionDevice.Arguments();

internal Dictionary<string, object> ToDictionary(IHost host)
{
Contracts.CheckValue(host, nameof(host));
Expand Down Expand Up @@ -408,6 +505,10 @@ internal Dictionary<string, object> ToDictionary(IHost host)
res[GetArgName(nameof(MaxCatThreshold))] = MaxCatThreshold.ToString();
res[GetArgName(nameof(CatSmooth))] = CatSmooth.ToString();
res[GetArgName(nameof(CatL2))] = CatL2.ToString();

var executionDeviceParams = ExecutionDevice.CreateComponent(host);
executionDeviceParams.UpdateParameters(res);

return res;
}
}
Expand Down
68 changes: 67 additions & 1 deletion src/Microsoft.ML/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7599,6 +7599,14 @@ public sealed partial class LightGbmBinaryClassifier : Microsoft.ML.Runtime.Entr
[TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})]
public double CatL2 { get; set; } = 10d;

/// <summary>
/// Execution device for training (CPU or GPU).
/// NOTE: GPU training requires compatible build of LightGBM as described here:
/// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version
/// </summary>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be worth moving this comment on the actual LightGBMArguments file, since this file is generated. Mabye on the ExecutionDevice classes.

[JsonConverter(typeof(ComponentSerializer))]
public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice();

/// <summary>
/// Parallel LightGBM Learning Algorithm
/// </summary>
Expand Down Expand Up @@ -7806,6 +7814,14 @@ public sealed partial class LightGbmClassifier : Microsoft.ML.Runtime.EntryPoint
[TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})]
public double CatL2 { get; set; } = 10d;

/// <summary>
/// Execution device for training (CPU or GPU).
/// NOTE: GPU training requires compatible build of LightGBM as described here:
/// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version
/// </summary>
[JsonConverter(typeof(ComponentSerializer))]
public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice();

/// <summary>
/// Parallel LightGBM Learning Algorithm
/// </summary>
Expand Down Expand Up @@ -8013,6 +8029,14 @@ public sealed partial class LightGbmRanker : Microsoft.ML.Runtime.EntryPoints.Co
[TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})]
public double CatL2 { get; set; } = 10d;

/// <summary>
/// Execution device for training (CPU or GPU).
/// NOTE: GPU training requires compatible build of LightGBM as described here:
/// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version
/// </summary>
[JsonConverter(typeof(ComponentSerializer))]
public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice();

/// <summary>
/// Parallel LightGBM Learning Algorithm
/// </summary>
Expand Down Expand Up @@ -8220,6 +8244,14 @@ public sealed partial class LightGbmRegressor : Microsoft.ML.Runtime.EntryPoints
[TlcModule.SweepableDiscreteParamAttribute("CatL2", new object[]{0.1f, 0.5f, 1, 5, 10})]
public double CatL2 { get; set; } = 10d;

/// <summary>
/// Execution device for training (CPU or GPU).
/// NOTE: GPU training requires compatible build of LightGBM as described here:
/// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version
/// </summary>
[JsonConverter(typeof(ComponentSerializer))]
public LightGbmExecutionDevice ExecutionDevice { get; set; } = new LightGbmCpuExecutionDevice();

/// <summary>
/// Parallel LightGBM Learning Algorithm
/// </summary>
Expand Down Expand Up @@ -15985,8 +16017,42 @@ public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase
internal override string ComponentName => "AutoMlState";
}

public abstract class BoosterParameterFunction : ComponentKind {}
public abstract class LightGbmExecutionDevice : ComponentKind { }

/// <summary>
/// LightGBM CPU execution device
/// </summary>
public sealed class LightGbmCpuExecutionDevice : LightGbmExecutionDevice
{
internal override string ComponentName => "cpu_device";
}

/// <summary>
/// LightGBM GPU execution device.
/// NOTE: Requires build of LightGBM that supports GPU as described here:
/// https://github.com/Microsoft/LightGBM/blob/master/docs/Installation-Guide.rst#build-gpu-version
/// </summary>
public sealed class LightGbmGpuExecutionDevice : LightGbmExecutionDevice
{
/// <summary>
/// OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform. -1 means the system-wide default platform.
/// </summary>
public int PlatformId { get; set; } = -1;

/// <summary>
/// OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID. -1 means the default device in the selected platform.
/// </summary>
public int DeviceId { get; set; } = -1;

/// <summary>
/// Use double precision math on GPU?
/// </summary>
public bool UseDoublePrecision { get; set; } = false;

internal override string ComponentName => "gpu_device";
}

public abstract class BoosterParameterFunction : ComponentKind {}


/// <summary>
Expand Down