GBM bugfix: matching predictions LightGBM, hummingbird #2574

Merged
11 commits, merged Oct 3, 2022
92 changes: 64 additions & 28 deletions ludwig/models/gbm.py
@@ -1,19 +1,21 @@
import copy
import os
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import lightgbm as lgb
import numpy as np
import torch
import torchmetrics
from hummingbird.ml import convert
from hummingbird.ml.operator_converters import constants as hb_constants

from ludwig.constants import BINARY, CATEGORY, LOGITS, MODEL_GBM, NAME, NUMBER
from ludwig.constants import BINARY, CATEGORY, LOGITS, MODEL_GBM, NAME, NUMBER, TYPE
from ludwig.features.base_feature import OutputFeature
from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.utils import output_feature_utils
from ludwig.utils.torch_utils import get_torch_device
from ludwig.utils.types import TorchDevice


class GBM(BaseModel):
@@ -28,6 +30,16 @@ def __init__(
random_seed: int = None,
**_kwargs,
):
if len(output_features) > 1:
raise ValueError("Only single task currently supported")
feat_types = {f[TYPE] for f in output_features + input_features}
unsupported_types = feat_types - {NUMBER, CATEGORY, BINARY}
if len(unsupported_types) != 0:
raise ValueError(
"Model type GBM only supports numerical, categorical, or binary features "
f"but got unsupported types {unsupported_types}"
)

super().__init__(random_seed=random_seed)

self._input_features_def = copy.deepcopy(input_features)
@@ -48,7 +60,7 @@ def __init__(
self.eval_loss_metric = torchmetrics.MeanMetric()
self.eval_additional_losses_metrics = torchmetrics.MeanMetric()

self.lgb_booster: lgb.Booster = None
self.lgbm_model: lgb.LGBMModel = None
self.compiled_model: torch.nn.Module = None

@classmethod
@@ -69,24 +81,16 @@ def build_outputs(cls, output_features_def: List[Dict[str, Any]], input_size: in

def compile(self):
"""Convert the LightGBM model to a PyTorch model and store internally."""
if self.lgb_booster is None:
if self.lgbm_model is None:
raise ValueError("Model has not been trained yet.")

output_feature_name = self.output_features.keys()[0]
output_feature = self.output_features[output_feature_name]

# https://github.com/microsoft/LightGBM/issues/1942#issuecomment-453975607
gbm_sklearn_cls = lgb.LGBMRegressor if output_feature.type() == NUMBER else lgb.LGBMClassifier
gbm_sklearn = gbm_sklearn_cls(feature_name=list(self.input_features.keys())) # , **params)
gbm_sklearn._Booster = self.lgb_booster
gbm_sklearn.fitted_ = True
gbm_sklearn._n_features = len(self.input_features)
if isinstance(gbm_sklearn, lgb.LGBMClassifier):
gbm_sklearn._n_classes = output_feature.num_classes if output_feature.type() == CATEGORY else 2

hb_model = convert(gbm_sklearn, "torch", extra_config={"tree_implementation": "gemm"})

self.compiled_model = hb_model.model
# explicitly use sigmoid for classification, so we can invert to logits at inference time
extra_config = (
{hb_constants.POST_TRANSFORM: hb_constants.SIGMOID}
if isinstance(self.lgbm_model, lgb.LGBMClassifier)
else {}
)
self.compiled_model = convert(self.lgbm_model, "torch", extra_config=extra_config)

def forward(
self,
@@ -133,40 +137,72 @@ def forward(
output_feature_name = self.output_features.keys()[0]
output_feature = self.output_features[output_feature_name]

preds = self.compiled_model(inputs)
assert (
type(inputs) is torch.Tensor
and inputs.dtype == torch.float32
and inputs.ndim == 2
and inputs.shape[1] == len(self.input_features)
), (
f"Expected inputs to be a 2D tensor of shape (batch_size, {len(self.input_features)}) of type float32, "
f"but got {inputs.shape} of type {inputs.dtype}"
)
# Predict using PyTorch module, so it is included when converting to TorchScript.
preds = self.compiled_model.model(inputs)

if output_feature.type() == NUMBER:
# regression
if len(preds.shape) == 2:
preds = preds.squeeze(1)
logits = preds
logits = preds.view(-1)
else:
# classification
_, probs = preds
# keep positive class only for binary feature
probs = probs[:, 1] if output_feature.type() == BINARY else probs

if output_feature.type() == BINARY:
# keep positive class only for binary feature
probs = probs[:, 1] # shape (batch_size,)
elif output_feature.num_classes > 2:
probs = probs.view(-1, 2, output_feature.num_classes) # shape (batch_size, 2, num_classes)
probs = probs.transpose(2, 1) # shape (batch_size, num_classes, 2)

# probabilities for belonging to each class
probs = probs[:, :, 1] # shape (batch_size, num_classes)

# invert sigmoid to get back logits and use Ludwig's output feature prediction functionality
logits = torch.logit(probs)

output_feature_utils.set_output_feature_tensor(output_logits, output_feature_name, LOGITS, logits)

return output_logits

def save(self, save_path):
"""Saves the model to the given path."""
if self.lgb_booster is None:
if self.lgbm_model is None:
raise ValueError("Model has not been trained yet.")

import joblib

weights_save_path = os.path.join(save_path, MODEL_WEIGHTS_FILE_NAME)
self.lgb_booster.save_model(weights_save_path, num_iteration=self.lgb_booster.best_iteration)
joblib.dump(self.lgbm_model, weights_save_path)

def load(self, save_path):
"""Loads the model from the given path."""
import joblib

weights_save_path = os.path.join(save_path, MODEL_WEIGHTS_FILE_NAME)
self.lgb_booster = lgb.Booster(model_file=weights_save_path)
self.lgbm_model = joblib.load(weights_save_path)

self.compile()

device = torch.device(get_torch_device())
self.compiled_model.to(device)

def to_torchscript(self, device: Optional[TorchDevice] = None):
"""Converts the ECD model as a TorchScript model."""

# Disable gradient calculation for hummingbird Parameter nodes.
self.compiled_model.model.requires_grad_(False)

return super().to_torchscript(device)

def get_args(self):
"""Returns init arguments for constructing this model."""
return (self._input_features_df, self._output_features_df, self._random_seed)
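
For context on the `compile()`/`forward()` changes above, here is a minimal sketch (not part of the PR; the synthetic data and tolerances are illustrative) of the idea they rely on: a hummingbird-converted `LGBMClassifier` should reproduce LightGBM's probabilities, and `torch.logit` inverts the sigmoid post-transform to recover raw scores that Ludwig can expose as logits.

```python
# Sketch only: compare LightGBM and hummingbird probabilities on a binary task,
# then invert the sigmoid to recover the raw (pre-sigmoid) scores.
import lightgbm as lgb
import numpy as np
import torch
from hummingbird.ml import convert

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4)).astype(np.float32)
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)

clf = lgb.LGBMClassifier(n_estimators=20).fit(X, y)
hb = convert(clf, "torch")  # hb.model is the underlying torch.nn.Module

probs_lgb = clf.predict_proba(X)[:, 1]
probs_hb = hb.predict_proba(X)[:, 1]
print("max prob diff:", np.abs(probs_lgb - probs_hb).max())  # expected to be tiny

# Inverting the sigmoid recovers LightGBM's raw margin scores up to float error,
# which is what forward() sets as the LOGITS tensor.
logits = torch.logit(torch.from_numpy(probs_hb).double())
raw = clf.predict(X, raw_score=True)
print("max logit diff:", np.abs(logits.numpy() - raw).max())
```

In the PR, `compile()` pins hummingbird's post-transform to sigmoid for classifiers so that the same per-class inversion remains valid in the multiclass case.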
94 changes: 50 additions & 44 deletions ludwig/trainers/trainer_lightgbm.py
@@ -351,18 +351,22 @@ def _train_loop(
) -> bool:
self.callback(lambda c: c.on_batch_start(self, progress_tracker, save_path))

booster = None
evals_result = {}
booster = self.train_step(
params, lgb_train, eval_sets, eval_names, booster, self.boosting_rounds_per_checkpoint, evals_result
self.model.lgbm_model = self.train_step(
params,
lgb_train,
eval_sets,
eval_names,
self.model.lgbm_model,
self.boosting_rounds_per_checkpoint,
evals_result,
)

progress_bar.update(self.boosting_rounds_per_checkpoint)
progress_tracker.steps += self.boosting_rounds_per_checkpoint
progress_tracker.last_improvement_steps = booster.best_iteration
progress_tracker.last_improvement_steps = self.model.lgbm_model.best_iteration_

# convert to pytorch for inference
self.model.lgb_booster = booster
self.model.compile()
self.model = self.model.to(self.device)

@@ -463,10 +467,10 @@ def train_step(
lgb_train: lgb.Dataset,
eval_sets: List[lgb.Dataset],
eval_names: List[str],
booster: lgb.Booster,
init_model: lgb.LGBMModel,
boost_rounds_per_train_step: int,
evals_result: Dict,
) -> lgb.Booster:
) -> lgb.LGBMModel:
"""Trains a LightGBM model.

Args:
@@ -478,18 +482,21 @@
Returns:
LightGBM Booster model
"""
gbm = lgb.train(
params,
lgb_train,
init_model=booster,
num_boost_round=boost_rounds_per_train_step,
valid_sets=eval_sets,
valid_names=eval_names,
feature_name=list(self.model.input_features.keys()),
output_feature = next(iter(self.model.output_features.values()))
gbm_sklearn_cls = lgb.LGBMRegressor if output_feature.type() == NUMBER else lgb.LGBMClassifier

gbm = gbm_sklearn_cls(n_estimators=boost_rounds_per_train_step, **params).fit(
X=lgb_train.get_data(),
y=lgb_train.get_label(),
init_model=init_model,
eval_set=[(ds.get_data(), ds.get_label()) for ds in eval_sets],
eval_names=eval_names,
# add early stopping callback to populate best_iteration
callbacks=[lgb.early_stopping(boost_rounds_per_train_step)],
# NOTE: hummingbird does not support categorical features
# categorical_feature=categorical_features,
evals_result=evals_result,
)
evals_result.update(gbm.evals_result_)

return gbm
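
As a standalone illustration of the incremental-fit pattern used in `train_step` (a sketch with made-up data, not part of the PR), the sklearn interface continues boosting from `init_model`, and the early-stopping callback is what populates `best_iteration_` for the progress tracker:

```python
# Sketch only: checkpointed boosting with the LightGBM sklearn API.
import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(500, 8)), rng.integers(0, 2, size=500)
X_val, y_val = rng.normal(size=(100, 8)), rng.integers(0, 2, size=100)

model = None
rounds_per_checkpoint = 10  # mirrors boost_rounds_per_train_step
for checkpoint in range(3):
    model = lgb.LGBMClassifier(n_estimators=rounds_per_checkpoint, objective="binary").fit(
        X,
        y,
        init_model=model,  # None on the first call, then continue from the previous trees
        eval_set=[(X_val, y_val)],
        eval_names=["validation"],
        callbacks=[lgb.early_stopping(rounds_per_checkpoint, verbose=False)],
    )
    # best_iteration_ and evals_result_ are available after each checkpoint
    print(checkpoint, model.best_iteration_, model.evals_result_["validation"]["binary_logloss"][-1])
```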

@@ -514,10 +521,6 @@ def train(

# TODO: construct new datasets by running encoders (for text, image)

# TODO: only single task currently
if len(output_features) > 1:
raise ValueError("Only single task currently supported")

metrics_names = get_metric_names(output_features)

# check if validation_field is valid
@@ -719,17 +722,17 @@ def set_steps_to_1_or_quit(self, signum, frame):
def _construct_lgb_params(self) -> Tuple[dict, dict]:
output_params = {}
feature = next(iter(self.model.output_features.values()))
if feature.type() == CATEGORY:
if feature.type() == BINARY or (hasattr(feature, "num_classes") and feature.num_classes == 2):
Review comment (Collaborator): Why do we need to check the num_classes stuff here when the feature is binary?

Reply (@jppgks, Contributor, Oct 3, 2022): A user can specify a category feature with only two classes; this check catches that case so that LightGBM is explicitly used with the binary objective.

output_params = {
"objective": "binary",
"metric": ["binary_logloss"],
}
elif feature.type() == CATEGORY:
output_params = {
"objective": "multiclass",
"metric": ["multi_logloss"],
"num_class": feature.num_classes,
}
elif feature.type() == BINARY:
output_params = {
"objective": "binary",
"metric": ["binary_logloss"],
}
elif feature.type() == NUMBER:
output_params = {
"objective": "regression",
@@ -799,15 +802,15 @@ def _construct_lgb_datasets(
y_train = training_set.to_df(self.model.output_features.values())

# create dataset for lightgbm
# if you want to re-use data, remember to set free_raw_data=False
lgb_train = lgb.Dataset(X_train, label=y_train)
# keep raw data for continued training https://github.com/microsoft/LightGBM/issues/4965#issuecomment-1019344293
lgb_train = lgb.Dataset(X_train, label=y_train, free_raw_data=False).construct()

eval_sets = [lgb_train]
eval_names = [LightGBMTrainer.TRAIN_KEY]
if validation_set is not None:
X_val = validation_set.to_df(self.model.input_features.values())
y_val = validation_set.to_df(self.model.output_features.values())
lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train)
lgb_val = lgb.Dataset(X_val, label=y_val, reference=lgb_train, free_raw_data=False).construct()
eval_sets.append(lgb_val)
eval_names.append(LightGBMTrainer.VALID_KEY)
else:
Expand All @@ -817,7 +820,7 @@ def _construct_lgb_datasets(
if test_set is not None:
X_test = test_set.to_df(self.model.input_features.values())
y_test = test_set.to_df(self.model.output_features.values())
lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train)
lgb_test = lgb.Dataset(X_test, label=y_test, reference=lgb_train, free_raw_data=False).construct()
eval_sets.append(lgb_test)
eval_names.append(LightGBMTrainer.TEST_KEY)

@@ -898,10 +901,10 @@ def train_step(
lgb_train: "RayDMatrix", # noqa: F821
eval_sets: List["RayDMatrix"], # noqa: F821
eval_names: List[str],
booster: lgb.Booster,
init_model: lgb.LGBMModel,
boost_rounds_per_train_step: int,
evals_result: Dict,
) -> lgb.Booster:
) -> lgb.LGBMModel:
"""Trains a LightGBM model using ray.

Args:
Expand All @@ -913,23 +916,26 @@ def train_step(
Returns:
LightGBM Booster model
"""
from lightgbm_ray import train as lgb_ray_train
from lightgbm_ray import RayLGBMClassifier, RayLGBMRegressor

gbm = lgb_ray_train(
params,
lgb_train,
init_model=booster,
num_boost_round=boost_rounds_per_train_step,
valid_sets=eval_sets,
valid_names=eval_names,
feature_name=list(self.model.input_features.keys()),
evals_result=evals_result,
output_feature = next(iter(self.model.output_features.values()))
gbm_sklearn_cls = RayLGBMRegressor if output_feature.type() == NUMBER else RayLGBMClassifier

gbm = gbm_sklearn_cls(n_estimators=boost_rounds_per_train_step, **params).fit(
X=lgb_train,
y=None,
init_model=init_model,
eval_set=[(s, n) for s, n in zip(eval_sets, eval_names)],
eval_names=eval_names,
# add early stopping callback to populate best_iteration
callbacks=[lgb.early_stopping(boost_rounds_per_train_step)],
ray_params=_map_to_lgb_ray_params(self.trainer_kwargs),
# NOTE: hummingbird does not support categorical features
# categorical_feature=categorical_features,
ray_params=_map_to_lgb_ray_params(self.trainer_kwargs),
)
evals_result.update(gbm.evals_result_)

return gbm.booster_
return gbm.to_local()

def _construct_lgb_datasets(
self,