Pandas FutureWarning: Length 1 tuple will be returned #1374

Merged: 5 commits, Apr 20, 2023
Changes from all commits
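For context, a minimal sketch of the warning being fixed (not part of the PR; it assumes pandas 1.5.x, where grouping by a length-1 list warns that iteration keys will become length-1 tuples in a future release):

import pandas as pd

df = pd.DataFrame({'a': [1, 1, 2], 'b': [10, 20, 30]})

# Length-1 list grouper: emits the FutureWarning on pandas 1.5.x and yields
# length-1 tuple keys on newer pandas versions.
for key, frame in df.groupby(['a']):
    print(key)

# Scalar grouper: no warning, and the keys stay scalar. This is the form the
# new groupby_list() helper produces for single-column groupers.
for key, frame in df.groupby('a'):
    print(key)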
sdv/constraints/base.py (3 changes: 2 additions, 1 deletion)

@@ -14,6 +14,7 @@
 from sdv.constraints.errors import (
     AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
 from sdv.errors import ConstraintsNotMetError
+from sdv.utils import groupby_list

 LOGGER = logging.getLogger(__name__)

@@ -497,7 +498,7 @@ def sample(self, table_data):
                 Table data with additional ``constraint_columns``.
         """
         condition_columns = [c for c in self.constraint_columns if c in table_data.columns]
-        grouped_conditions = table_data[condition_columns].groupby(condition_columns)
+        grouped_conditions = table_data[condition_columns].groupby(groupby_list(condition_columns))
         all_sampled_rows = []
         for group, dataframe in grouped_conditions:
             if not isinstance(group, tuple):
sdv/sequential/par.py (5 changes: 2 additions, 3 deletions)

@@ -15,7 +15,7 @@
 from sdv.metadata.single_table import SingleTableMetadata
 from sdv.single_table import GaussianCopulaSynthesizer
 from sdv.single_table.base import BaseSynthesizer
-from sdv.utils import cast_to_iterable
+from sdv.utils import cast_to_iterable, groupby_list

 LOGGER = logging.getLogger(__name__)

@@ -134,8 +134,7 @@ def add_constraints(self, constraints):
     def _validate_context_columns(self, data):
         errors = []
         if self.context_columns:
-            errors = []
-            for sequence_key_value, data_values in data.groupby(self._sequence_key):
+            for sequence_key_value, data_values in data.groupby(groupby_list(self._sequence_key)):
                 for context_column in self.context_columns:
                     if len(data_values[context_column].unique()) > 1:
                         errors.append((
sdv/single_table/base.py (8 changes: 5 additions, 3 deletions)

@@ -23,7 +23,7 @@
 from sdv.single_table.errors import InvalidDataError
 from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path
 from sdv.utils import (
-    is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
+    groupby_list, is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)

 LOGGER = logging.getLogger(__name__)
 COND_IDX = str(uuid.uuid4())

@@ -893,7 +893,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
         condition_columns = list(conditions.columns)
         conditions.index.name = COND_IDX
         conditions = conditions.reset_index()
-        grouped_conditions = conditions.groupby(condition_columns)
+        grouped_conditions = conditions.groupby(groupby_list(condition_columns))

         # sample
         all_sampled_rows = []

@@ -935,7 +935,9 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
                 )
                 all_sampled_rows.append(sampled_rows)
             else:
-                transformed_groups = transformed_conditions.groupby(transformed_columns)
+                transformed_groups = transformed_conditions.groupby(
+                    groupby_list(transformed_columns)
+                )
                 for transformed_group, transformed_dataframe in transformed_groups:
                     if not isinstance(transformed_group, tuple):
                         transformed_group = [transformed_group]
sdv/utils.py (5 changes: 5 additions, 0 deletions)

@@ -164,3 +164,8 @@ def load_data_from_csv(filepath, pandas_kwargs=None):
     pandas_kwargs = pandas_kwargs or {}
     data = pd.read_csv(filepath, **pandas_kwargs)
     return data
+
+
+def groupby_list(list_to_check):
+    """Return the first element of the list if the length is 1 else the entire list."""
+    return list_to_check[0] if len(list_to_check) == 1 else list_to_check
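A minimal usage sketch of the new helper (illustration only; the toy frame and prints are not from the PR, while groupby_list and the tuple normalization mirror the diff above):

import pandas as pd

from sdv.utils import groupby_list

data = pd.DataFrame({'a': [1, 1, 2], 'b': [10, 20, 30]})

groupby_list(['a'])        # -> 'a': scalar grouper for a single column
groupby_list(['a', 'b'])   # -> ['a', 'b']: multi-column lists pass through

# With a scalar grouper the iteration keys stay scalar, so the callers'
# existing normalization keeps working unchanged:
for group, frame in data.groupby(groupby_list(['a'])):
    if not isinstance(group, tuple):
        group = [group]
    print(group, len(frame))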
tests/unit/constraints/test_base.py (9 changes: 9 additions, 0 deletions)

@@ -1006,3 +1006,12 @@ def test_sample(self):
         })
         instance._reject_sample.assert_any_call(num_rows=5, conditions={'b': 1})
         pd.testing.assert_frame_equal(transformed_data, expected_result)
+
+    @patch('warnings.warn')
+    def test_groupby_removed_warning(self, mock_warn):
+        """Test that the pandas warning is no longer raised."""
+        # Run
+        self.test_sample()
+
+        # Assert
+        mock_warn.assert_not_called()
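The new test patches warnings.warn, re-runs test_sample (which exercises the groupby path), and asserts that no warning was raised. An equivalent standalone check, shown here only as a sketch with illustrative names rather than the PR's approach, would be to promote FutureWarning to an error around the grouping:

import warnings

import pandas as pd

def assert_no_groupby_future_warning(frame, columns):
    """Raise if grouping ``frame`` by ``columns`` emits a FutureWarning."""
    grouper = columns[0] if len(columns) == 1 else columns
    with warnings.catch_warnings():
        warnings.simplefilter('error', FutureWarning)
        for _key, _group in frame.groupby(grouper):
            pass

assert_no_groupby_future_warning(pd.DataFrame({'a': [1, 2]}), ['a'])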