From 8372027e4b8eaaeeffa4b1ce550a3e71e01e3f2f Mon Sep 17 00:00:00 2001 From: Felipe Date: Tue, 1 Aug 2023 20:40:43 -0700 Subject: [PATCH 1/3] Update preprocess_data --- sdv/single_table/base.py | 4 +++- tests/unit/multi_table/test_base.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 4bc080a1e..ccc13fd05 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -468,7 +468,9 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - self._fit(processed_data) + if not processed_data.empty: + self._fit(processed_data) + self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') self._fitted_sdv_version = pkg_resources.get_distribution('sdv').version diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 572832d0d..8a8db79d7 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -717,6 +717,20 @@ def test_fit_processed_data(self): instance._augment_tables.assert_called_once_with(data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted + + def test_fit_processed_data_empty_table(self): + """Test the fit attributes are properly set when data is empty.""" + # Setup + instance = Mock() + data = pd.DataFrame() + + # Run + BaseMultiTableSynthesizer.fit_processed_data(instance, data) + + # Assert + assert instance._fitted + assert instance._fitted_date + assert instance._fitted_sdv_version def test_fit(self): """Test that ``fit`` calls ``preprocess`` and then ``fit_processed_data``.""" From 9da95a14cdfd42613a46c0d03a6a8ac1a4b18464 Mon Sep 17 00:00:00 2001 From: Felipe Date: Tue, 1 Aug 2023 20:47:47 -0700 Subject: [PATCH 2/3] Fix lint --- tests/unit/multi_table/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 8a8db79d7..63e434106 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -717,7 +717,7 @@ def test_fit_processed_data(self): instance._augment_tables.assert_called_once_with(data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted - + def test_fit_processed_data_empty_table(self): """Test the fit attributes are properly set when data is empty.""" # Setup From c692ae52a399509c606af7f0016dd7c01fae43aa Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 2 Aug 2023 09:08:08 -0700 Subject: [PATCH 3/3] Update test --- tests/unit/multi_table/test_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 63e434106..89e96cf61 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -719,7 +719,7 @@ def test_fit_processed_data(self): assert instance._fitted def test_fit_processed_data_empty_table(self): - """Test the fit attributes are properly set when data is empty.""" + """Test attributes are properly set when data is empty and that _fit is not called.""" # Setup instance = Mock() data = pd.DataFrame() @@ -728,6 +728,7 @@ def test_fit_processed_data_empty_table(self): BaseMultiTableSynthesizer.fit_processed_data(instance, data) # Assert + instance._fit.assert_not_called() assert instance._fitted assert instance._fitted_date assert instance._fitted_sdv_version