diff --git a/sdv/constraints/base.py b/sdv/constraints/base.py
index 82e50927d..5c122e177 100644
--- a/sdv/constraints/base.py
+++ b/sdv/constraints/base.py
@@ -14,6 +14,7 @@ from sdv.constraints.errors import (
     AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
 from sdv.errors import ConstraintsNotMetError
+from sdv.utils import groupby_list
 
 LOGGER = logging.getLogger(__name__)
 
@@ -497,7 +498,7 @@ def sample(self, table_data):
                 Table data with additional ``constraint_columns``.
         """
         condition_columns = [c for c in self.constraint_columns if c in table_data.columns]
-        grouped_conditions = table_data[condition_columns].groupby(condition_columns)
+        grouped_conditions = table_data[condition_columns].groupby(groupby_list(condition_columns))
         all_sampled_rows = []
         for group, dataframe in grouped_conditions:
             if not isinstance(group, tuple):
diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py
index f3cbd52da..c85886beb 100644
--- a/sdv/sequential/par.py
+++ b/sdv/sequential/par.py
@@ -15,7 +15,7 @@ from sdv.metadata.single_table import SingleTableMetadata
 from sdv.single_table import GaussianCopulaSynthesizer
 from sdv.single_table.base import BaseSynthesizer
-from sdv.utils import cast_to_iterable
+from sdv.utils import cast_to_iterable, groupby_list
 
 LOGGER = logging.getLogger(__name__)
 
@@ -134,8 +134,7 @@ def add_constraints(self, constraints):
     def _validate_context_columns(self, data):
         errors = []
         if self.context_columns:
-            errors = []
-            for sequence_key_value, data_values in data.groupby(self._sequence_key):
+            for sequence_key_value, data_values in data.groupby(groupby_list(self._sequence_key)):
                 for context_column in self.context_columns:
                     if len(data_values[context_column].unique()) > 1:
                         errors.append((
diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py
index 5c5ec1963..23f1bd418 100644
--- a/sdv/single_table/base.py
+++ b/sdv/single_table/base.py
@@ -23,7 +23,7 @@ from sdv.single_table.errors import InvalidDataError
 from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path
 from sdv.utils import (
-    is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
+    groupby_list, is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
 
 LOGGER = logging.getLogger(__name__)
 COND_IDX = str(uuid.uuid4())
@@ -893,7 +893,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
         condition_columns = list(conditions.columns)
         conditions.index.name = COND_IDX
         conditions = conditions.reset_index()
-        grouped_conditions = conditions.groupby(condition_columns)
+        grouped_conditions = conditions.groupby(groupby_list(condition_columns))
 
         # sample
         all_sampled_rows = []
@@ -935,7 +935,9 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
-                transformed_groups = transformed_conditions.groupby(transformed_columns)
+                transformed_groups = transformed_conditions.groupby(
+                    groupby_list(transformed_columns)
+                )
                 for transformed_group, transformed_dataframe in transformed_groups:
                     if not isinstance(transformed_group, tuple):
                         transformed_group = [transformed_group]
diff --git a/sdv/utils.py b/sdv/utils.py
index 9d90808a7..4e14c4722 100644
--- a/sdv/utils.py
+++ b/sdv/utils.py
@@ -164,3 +164,8 @@ def load_data_from_csv(filepath, pandas_kwargs=None):
     pandas_kwargs = pandas_kwargs or {}
     data = pd.read_csv(filepath, **pandas_kwargs)
     return data
+
+
+def groupby_list(list_to_check):
+    """Return the first element of the list if the length is 1 else the entire list."""
+    return list_to_check[0] if len(list_to_check) == 1 else list_to_check
diff --git a/tests/unit/constraints/test_base.py b/tests/unit/constraints/test_base.py
index a8b4025ae..98bfdecf4 100644
--- a/tests/unit/constraints/test_base.py
+++ b/tests/unit/constraints/test_base.py
@@ -1006,3 +1006,12 @@ def test_sample(self):
         })
         instance._reject_sample.assert_any_call(num_rows=5, conditions={'b': 1})
         pd.testing.assert_frame_equal(transformed_data, expected_result)
+
+    @patch('warnings.warn')
+    def test_groupby_removed_warning(self, mock_warn):
+        """Test that the pandas warning is no longer raised."""
+        # Run
+        self.test_sample()
+
+        # Assert
+        mock_warn.assert_not_called()
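A minimal standalone sketch of the behavior the new helper works around, assuming pandas 1.5.x (which emits a FutureWarning when iterating a groupby whose grouper is a length-1 list, since group keys will become length-1 tuples). The helper body is copied from the sdv/utils.py hunk above; the example DataFrame and script are illustrative only and not part of the change:

    import warnings

    import pandas as pd


    def groupby_list(list_to_check):
        """Return the first element of the list if the length is 1 else the entire list."""
        return list_to_check[0] if len(list_to_check) == 1 else list_to_check


    df = pd.DataFrame({'a': [1, 1, 2], 'b': ['x', 'y', 'y']})

    # Iterating a groupby whose grouper is a length-1 list emits a FutureWarning
    # on pandas 1.5.x (group keys will become length-1 tuples in later versions).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        list(df.groupby(['a']))
    print([str(w.message) for w in caught])

    # groupby_list unwraps a single-column grouper to a scalar, so iteration
    # yields plain keys (1 and 2) and the warning is not raised.
    for key, group in df.groupby(groupby_list(['a'])):
        print(key, len(group))

    # Multi-column groupers pass through unchanged and keep tuple keys, e.g. (1, 'x').
    for key, group in df.groupby(groupby_list(['a', 'b'])):
        print(key, len(group))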