diff --git a/setup.cfg b/setup.cfg index f9e6652c..6947a0f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ install_requires = importlib-metadata pandas>=1.4,<2 torch>=1.10.0,<2.0 - scikit-learn>=1.0 + scikit-learn>=1.2 nflows>=0.14 numpy>=1.20, <1.24 lifelines>=0.27,!= 0.27.5 diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 45b13f50..1e6f9fec 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -60,7 +60,7 @@ class TabularEncoder(TransformerMixin, BaseEstimator): categorical_encoder: Union[str, type] = "onehot" continuous_encoder: Union[str, type] = "bayesian_gmm" - cat_encoder_params: dict = dict(handle_unknown="ignore", sparse=False) + cat_encoder_params: dict = dict(handle_unknown="ignore", sparse_output=False) cont_encoder_params: dict = dict(n_components=10) @validate_arguments(config=dict(arbitrary_types_allowed=True))