Summary of this patch (commentary only; not part of any hunk):

  * sdv/single_table/base.py, fit_processed_data: the call to self._fit is
    now guarded by "if not processed_data.empty", so it is skipped when the
    processed DataFrame is empty. The _fitted, _fitted_date and
    _fitted_sdv_version attributes are still assigned in either case.
  * tests/unit/multi_table/test_base.py: adds
    test_fit_processed_data_empty_table, which passes an empty
    pandas.DataFrame and asserts that _fit is not called and that the three
    fitted attributes are truthy.

NOTE(review): the behaviour change is in the single-table synthesizer, but
the new test calls BaseMultiTableSynthesizer.fit_processed_data on a Mock.
The neighbouring multi-table test asserts on _augment_tables/_model_tables,
so "_fit.assert_not_called()" may hold regardless of this change. Please
confirm the test exercises the intended class.

NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of the patch; verify against the target files
before applying.

diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py
index 4bc080a1e..ccc13fd05 100644
--- a/sdv/single_table/base.py
+++ b/sdv/single_table/base.py
@@ -468,7 +468,9 @@ def fit_processed_data(self, processed_data):
             processed_data (pandas.DataFrame):
                 The transformed data used to fit the model to.
         """
-        self._fit(processed_data)
+        if not processed_data.empty:
+            self._fit(processed_data)
+
         self._fitted = True
         self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
         self._fitted_sdv_version = pkg_resources.get_distribution('sdv').version
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index 572832d0d..89e96cf61 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -718,6 +718,21 @@ def test_fit_processed_data(self):
         instance._model_tables.assert_called_once_with(instance._augment_tables.return_value)
         assert instance._fitted
 
+    def test_fit_processed_data_empty_table(self):
+        """Test attributes are properly set when data is empty and that _fit is not called."""
+        # Setup
+        instance = Mock()
+        data = pd.DataFrame()
+
+        # Run
+        BaseMultiTableSynthesizer.fit_processed_data(instance, data)
+
+        # Assert
+        instance._fit.assert_not_called()
+        assert instance._fitted
+        assert instance._fitted_date
+        assert instance._fitted_sdv_version
+
     def test_fit(self):
         """Test that ``fit`` calls ``preprocess`` and then ``fit_processed_data``."""
         # Setup