@@ -764,6 +764,10 @@ class SMOTEN(SMOTE):
764764
765765 Parameters
766766 ----------
767+ categorical_encoder : estimator, default=None
768+ Ordinal encoder used to encode the categorical features. Once fitted, it must expose a `categories_` attribute. If `None`, a
769+ :class:`~sklearn.preprocessing.OrdinalEncoder` with `dtype=np.int32` is used.
770+
767771 {sampling_strategy}
768772
769773 {random_state}
@@ -791,6 +795,9 @@ class SMOTEN(SMOTE):
791795
792796 Attributes
793797 ----------
798+ categorical_encoder_ : estimator
799+ The fitted encoder used to encode the categorical features: a clone of `categorical_encoder`, or the default ordinal encoder when it is `None`.
800+
794801 sampling_strategy_ : dict
795802 Dictionary containing the information to sample the dataset. The keys
796803 correspond to the class labels from which to sample and the values
@@ -853,6 +860,31 @@ class SMOTEN(SMOTE):
853860 Class counts after resampling Counter({{0: 40, 1: 40}})
854861 """
855862
# Start from a copy of the constraints declared on SMOTE, then register the
# extra `categorical_encoder` parameter: either None or any object providing
# both `fit_transform` and `inverse_transform`.
_parameter_constraints: dict = dict(SMOTE._parameter_constraints)
_parameter_constraints["categorical_encoder"] = [
    HasMethods(["fit_transform", "inverse_transform"]),
    None,
]
870+
def __init__(
    self,
    categorical_encoder=None,
    *,
    sampling_strategy="auto",
    random_state=None,
    k_neighbors=5,
    n_jobs=None,
):
    # Everything except the encoder is forwarded unchanged to the parent
    # constructor.
    parent_params = {
        "sampling_strategy": sampling_strategy,
        "random_state": random_state,
        "k_neighbors": k_neighbors,
        "n_jobs": n_jobs,
    }
    super().__init__(**parent_params)
    # Keep the user-supplied encoder as given; `_fit_resample` clones it (or
    # builds an OrdinalEncoder when it is None) before fitting.
    self.categorical_encoder = categorical_encoder
887+
856888 def _check_X_y (self , X , y ):
857889 """Check should accept strings and not sparse matrices."""
858890 y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
@@ -900,11 +932,14 @@ def _fit_resample(self, X, y):
900932 X_resampled = [X .copy ()]
901933 y_resampled = [y .copy ()]
902934
903- encoder = OrdinalEncoder (dtype = np .int32 )
904- X_encoded = encoder .fit_transform (X )
935+ if self .categorical_encoder is None :
936+ self .categorical_encoder_ = OrdinalEncoder (dtype = np .int32 )
937+ else :
938+ self .categorical_encoder_ = clone (self .categorical_encoder )
939+ X_encoded = self .categorical_encoder_ .fit_transform (X )
905940
906941 vdm = ValueDifferenceMetric (
907- n_categories = [len (cat ) for cat in encoder .categories_ ]
942+ n_categories = [len (cat ) for cat in self . categorical_encoder_ .categories_ ]
908943 ).fit (X_encoded , y )
909944
910945 for class_sample , n_samples in self .sampling_strategy_ .items ():
@@ -922,7 +957,7 @@ def _fit_resample(self, X, y):
922957 X_class , class_sample , y .dtype , nn_indices , n_samples
923958 )
924959
925- X_new = encoder .inverse_transform (X_new )
960+ X_new = self . categorical_encoder_ .inverse_transform (X_new )
926961 X_resampled .append (X_new )
927962 y_resampled .append (y_new )
928963
0 commit comments