Skip to content

Commit

Permalink
Revert "Pandas FutureWarning: Length 1 tuple will be returned (#1374)"
Browse files Browse the repository at this point in the history
This reverts commit f73fd8d.
  • Loading branch information
R-Palazzo authored Apr 20, 2023
1 parent f73fd8d commit e47c696
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 23 deletions.
3 changes: 1 addition & 2 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from sdv.constraints.errors import (
AggregateConstraintsError, ConstraintMetadataError, MissingConstraintColumnError)
from sdv.errors import ConstraintsNotMetError
from sdv.utils import groupby_list

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -498,7 +497,7 @@ def sample(self, table_data):
Table data with additional ``constraint_columns``.
"""
condition_columns = [c for c in self.constraint_columns if c in table_data.columns]
grouped_conditions = table_data[condition_columns].groupby(groupby_list(condition_columns))
grouped_conditions = table_data[condition_columns].groupby(condition_columns)
all_sampled_rows = []
for group, dataframe in grouped_conditions:
if not isinstance(group, tuple):
Expand Down
5 changes: 3 additions & 2 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table.base import BaseSynthesizer
from sdv.utils import cast_to_iterable, groupby_list
from sdv.utils import cast_to_iterable

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,7 +134,8 @@ def add_constraints(self, constraints):
def _validate_context_columns(self, data):
errors = []
if self.context_columns:
for sequence_key_value, data_values in data.groupby(groupby_list(self._sequence_key)):
errors = []
for sequence_key_value, data_values in data.groupby(self._sequence_key):
for context_column in self.context_columns:
if len(data_values[context_column].unique()) > 1:
errors.append((
Expand Down
8 changes: 3 additions & 5 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sdv.single_table.errors import InvalidDataError
from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path
from sdv.utils import (
groupby_list, is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)
is_boolean_type, is_datetime_type, is_numerical_type, validate_datetime_format)

LOGGER = logging.getLogger(__name__)
COND_IDX = str(uuid.uuid4())
Expand Down Expand Up @@ -893,7 +893,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
condition_columns = list(conditions.columns)
conditions.index.name = COND_IDX
conditions = conditions.reset_index()
grouped_conditions = conditions.groupby(groupby_list(condition_columns))
grouped_conditions = conditions.groupby(condition_columns)

# sample
all_sampled_rows = []
Expand Down Expand Up @@ -935,9 +935,7 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
)
all_sampled_rows.append(sampled_rows)
else:
transformed_groups = transformed_conditions.groupby(
groupby_list(transformed_columns)
)
transformed_groups = transformed_conditions.groupby(transformed_columns)
for transformed_group, transformed_dataframe in transformed_groups:
if not isinstance(transformed_group, tuple):
transformed_group = [transformed_group]
Expand Down
5 changes: 0 additions & 5 deletions sdv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,3 @@ def load_data_from_csv(filepath, pandas_kwargs=None):
pandas_kwargs = pandas_kwargs or {}
data = pd.read_csv(filepath, **pandas_kwargs)
return data


def groupby_list(list_to_check):
    """Return a grouping key that is never a single-element list.

    The call sites pass the result straight to ``DataFrame.groupby``. Handing
    ``groupby`` a scalar key instead of a one-item list keeps the group labels
    produced during iteration as scalars rather than length-1 tuples.

    Args:
        list_to_check (list):
            The grouping keys (for example column names).

    Returns:
        The only element of ``list_to_check`` when it contains exactly one
        item; otherwise ``list_to_check`` itself, unchanged. An empty list is
        therefore returned as is.
    """
    # Exactly one key: unwrap it so the caller groups by a scalar.
    if len(list_to_check) == 1:
        return list_to_check[0]

    # Zero or several keys: hand back the very same list object.
    return list_to_check
9 changes: 0 additions & 9 deletions tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,12 +1006,3 @@ def test_sample(self):
})
instance._reject_sample.assert_any_call(num_rows=5, conditions={'b': 1})
pd.testing.assert_frame_equal(transformed_data, expected_result)

@patch('warnings.warn')
def test_groupby_removed_warning(self, mock_warn):
"""Test that the pandas warning is no longer raised.

``warnings.warn`` is replaced with a mock for the duration of the test,
``test_sample`` is re-run under that patch, and the test passes only if
the mock was never called.

NOTE(review): the mock catches every ``warnings.warn`` call, not only the
pandas groupby ``FutureWarning`` this test was presumably written for.
"""
# Run
# Re-use the existing sampling scenario while ``warnings.warn`` is mocked.
self.test_sample()

# Assert
# Any call to ``warnings.warn`` made during the run fails the test.
mock_warn.assert_not_called()

0 comments on commit e47c696

Please sign in to comment.