Pandas FutureWarning: Length 1 tuple will be returned (#1374)
* warning fix

* add groupby_list method

* lint

* formatting

* add unit test
R-Palazzo authored Apr 20, 2023
1 parent 197769a commit f73fd8d
Showing 5 changed files with 23 additions and 7 deletions.
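
For context, the warning being silenced here was introduced in pandas 1.5, which warns that grouping by a single-element list will eventually yield length-1 tuples as iteration keys (newer pandas versions already do). A minimal sketch of the behavior and of the workaround this commit takes; the DataFrame and column names below are invented for illustration:

import pandas as pd

df = pd.DataFrame({'a': [1, 1, 2], 'b': [10, 20, 30]})

# Grouping by a length-1 list: pandas 1.5.x emits a FutureWarning saying a
# length-1 tuple will be returned when iterating over a groupby with a list
# grouper of length 1; later versions yield (1,) instead of 1 as the key.
for key, _group in df.groupby(['a']):
    print(key)

# The workaround adopted in this commit: unwrap single-element lists so the
# grouper is a plain scalar, keeping the key a scalar and avoiding the warning.
for key, _group in df.groupby('a'):
    print(key)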
sdv/constraints/base.py: 2 additions & 1 deletion
@@ -14,6 +14,7 @@
 from sdv.constraints.errors import (
     AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
 from sdv.errors import ConstraintsNotMetError
+from sdv.utils import groupby_list
 
 LOGGER = logging.getLogger(__name__)
 
@@ -497,7 +498,7 @@ def sample(self, table_data):
                 Table data with additional ``constraint_columns``.
         """
         condition_columns = [c for c in self.constraint_columns if c in table_data.columns]
-        grouped_conditions = table_data[condition_columns].groupby(condition_columns)
+        grouped_conditions = table_data[condition_columns].groupby(groupby_list(condition_columns))
         all_sampled_rows = []
         for group, dataframe in grouped_conditions:
             if not isinstance(group, tuple):
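Note that the pre-existing ``if not isinstance(group, tuple)`` guard in ``sample`` still matters after this change: with ``groupby_list``, a single condition column is passed as a scalar grouper, so each iteration key is a scalar and has to be wrapped before it can be treated as a sequence of values. A rough sketch, with invented data and with ``dict(zip(...))`` only approximating the downstream use of the key:

import pandas as pd

table_data = pd.DataFrame({'city': ['NY', 'NY', 'LA'], 'age': [30, 30, 40]})

# One condition column -> scalar grouper -> scalar keys that get wrapped.
for group, dataframe in table_data[['city']].groupby('city'):
    if not isinstance(group, tuple):
        group = [group]
    print(dict(zip(['city'], group)))            # {'city': 'LA'}, {'city': 'NY'}

# Two condition columns -> the list grouper is left alone -> tuple keys.
for group, dataframe in table_data[['city', 'age']].groupby(['city', 'age']):
    print(dict(zip(['city', 'age'], group)))     # {'city': 'LA', 'age': 40}, ...
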
sdv/sequential/par.py: 2 additions & 3 deletions
@@ -15,7 +15,7 @@
 from sdv.metadata.single_table import SingleTableMetadata
 from sdv.single_table import GaussianCopulaSynthesizer
 from sdv.single_table.base import BaseSynthesizer
-from sdv.utils import cast_to_iterable
+from sdv.utils import cast_to_iterable, groupby_list
 
 LOGGER = logging.getLogger(__name__)
 
@@ -134,8 +134,7 @@ def add_constraints(self, constraints):
     def _validate_context_columns(self, data):
         errors = []
         if self.context_columns:
-            errors = []
-            for sequence_key_value, data_values in data.groupby(self._sequence_key):
+            for sequence_key_value, data_values in data.groupby(groupby_list(self._sequence_key)):
                 for context_column in self.context_columns:
                     if len(data_values[context_column].unique()) > 1:
                         errors.append((
sdv/single_table/base.py: 5 additions & 3 deletions
@@ -23,7 +23,7 @@
 from sdv.single_table.errors import InvalidDataError
 from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path
 from sdv.utils import (
-    is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
+    groupby_list, is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
 
 LOGGER = logging.getLogger(__name__)
 COND_IDX = str(uuid.uuid4())
@@ -893,7 +893,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
         condition_columns = list(conditions.columns)
         conditions.index.name = COND_IDX
         conditions = conditions.reset_index()
-        grouped_conditions = conditions.groupby(condition_columns)
+        grouped_conditions = conditions.groupby(groupby_list(condition_columns))
 
         # sample
         all_sampled_rows = []
@@ -935,7 +935,9 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
-                transformed_groups = transformed_conditions.groupby(transformed_columns)
+                transformed_groups = transformed_conditions.groupby(
+                    groupby_list(transformed_columns)
+                )
                 for transformed_group, transformed_dataframe in transformed_groups:
                     if not isinstance(transformed_group, tuple):
                         transformed_group = [transformed_group]
sdv/utils.py: 5 additions & 0 deletions
@@ -164,3 +164,8 @@ def load_data_from_csv(filepath, pandas_kwargs=None):
     pandas_kwargs = pandas_kwargs or {}
     data = pd.read_csv(filepath, **pandas_kwargs)
     return data
+
+
+def groupby_list(list_to_check):
+    """Return the first element of the list if the length is 1 else the entire list."""
+    return list_to_check[0] if len(list_to_check) == 1 else list_to_check
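
In short, the new helper leaves multi-column groupers untouched and unwraps single-column ones, roughly like this (the argument values are arbitrary):

groupby_list(['user_id'])            # -> 'user_id'   (scalar grouper, no FutureWarning)
groupby_list(['user_id', 'month'])   # -> ['user_id', 'month']   (unchanged)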
tests/unit/constraints/test_base.py: 9 additions & 0 deletions
@@ -1006,3 +1006,12 @@ def test_sample(self):
         })
         instance._reject_sample.assert_any_call(num_rows=5, conditions={'b': 1})
         pd.testing.assert_frame_equal(transformed_data, expected_result)
+
+    @patch('warnings.warn')
+    def test_groupby_removed_warning(self, mock_warn):
+        """Test that the pandas warning is no longer raised."""
+        # Run
+        self.test_sample()
+
+        # Assert
+        mock_warn.assert_not_called()
