diff --git a/bugbug/models/component.py b/bugbug/models/component.py index e8fa9f8eb1..fb171e6f73 100644 --- a/bugbug/models/component.py +++ b/bugbug/models/component.py @@ -10,6 +10,8 @@ import dateutil.parser import xgboost from dateutil.relativedelta import relativedelta +from imblearn.over_sampling import SMOTE +from imblearn.pipeline import Pipeline as ImblearnPipeline from sklearn.compose import ColumnTransformer from sklearn.feature_extraction import DictVectorizer from sklearn.pipeline import Pipeline @@ -103,7 +105,7 @@ def __init__(self, lemmatization=False): ] ) - self.clf = Pipeline( + self.clf = ImblearnPipeline( [ ( "union", @@ -119,6 +121,7 @@ def __init__(self, lemmatization=False): ] ), ), + ("sampler", SMOTE(random_state=1, sampling_strategy="all")), ( "estimator", xgboost.XGBClassifier(n_jobs=utils.get_physical_cpu_count()),