Skip to content

Commit

Permalink
Expand UniqueCombinations constraint to handle non-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Jun 25, 2021
1 parent 1dbbfda commit 65e0408
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 6 deletions.
40 changes: 36 additions & 4 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
on the other columns of the table.
"""

import uuid

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -74,6 +76,8 @@ class UniqueCombinations(Constraint):

_separator = None
_joint_column = None
_combination_map = None
_unique_value_map = None

def __init__(self, columns, handling_strategy='transform', fit_columns_model=True):
self._columns = columns
Expand All @@ -98,7 +102,8 @@ def _valid_separator(self, table_data):
Whether the separator is valid for this data or not.
"""
for column in self._columns:
if table_data[column].str.contains(self._separator).any():
if isinstance(table_data[column], str) and \
table_data[column].str.contains(self._separator).any():
return False

if self._separator.join(self._columns) in table_data:
Expand All @@ -111,7 +116,7 @@ def _fit(self, table_data):
The fit process consists on:
- Finding a separtor that works for the
- Finding a separator that works for the
current data by iteratively adding `#` to it.
- Generating the joint column name by concatenating
the names of ``self._columns`` with the separator.
Expand Down Expand Up @@ -163,8 +168,25 @@ def _transform(self, table_data):
Transformed data.
"""
lists_series = pd.Series(table_data[self._columns].values.tolist())

non_string_cols = [dt for x, dt in table_data.dtypes[
self._columns].items() if dt != object]
if len(non_string_cols) > 0:
u_lists_series = []
self._combination_map = {}
self._unique_value_map = {}

for ls in lists_series:
u = str(uuid.uuid4())
self._combination_map[tuple(ls)] = u
self._unique_value_map[u] = ls
u_lists_series.append(u)

lists_series = pd.Series(u_lists_series)

table_data = table_data.drop(self._columns, axis=1)
table_data[self._joint_column] = lists_series.str.join(self._separator)
table_data[self._joint_column] = lists_series.str.join(
self._separator) if self._combination_map is None else lists_series

return table_data

Expand All @@ -185,7 +207,17 @@ def reverse_transform(self, table_data):
Transformed data.
"""
table_data = table_data.copy()
columns = table_data.pop(self._joint_column).str.split(self._separator)

if self._combination_map is None:
columns = table_data.pop(self._joint_column).str.split(self._separator)
else:
uuids = table_data.pop(self._joint_column)
combinations = []
for u in uuids:
combinations.append(self._unique_value_map[u])

columns = pd.Series(combinations, name=self._joint_column)

for index, column in enumerate(self._columns):
table_data[column] = columns.str[index]

Expand Down
80 changes: 78 additions & 2 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the sdv.constraints.tabular module."""

import uuid

import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -281,6 +283,40 @@ def test_transform(self):
})
pd.testing.assert_frame_equal(expected_out, out)

def test_transform_non_string(self):
"""Test the ``UniqueCombinations.transform`` method with non strings.
It is expected to return a Table data with the columns concatenated by the separator.
Input:
- Table data (pandas.DataFrame)
Output:
- Table data transformed, with the columns as UUIDs.
Side effects:
- Since the ``transform`` method needs ``self._joint_column``, method ``fit``
must be called as well.
"""
# Setup
table_data = pd.DataFrame({
'a': ['a', 'b', 'c'],
'b': [1, 2, 3],
'c': ['g', 'h', 'i']
})
columns = ['b', 'c']
instance = UniqueCombinations(columns=columns)
instance.fit(table_data)

# Run
out = instance.transform(table_data)

# Assert
expected_out_a = pd.Series(['a', 'b', 'c'], name='a')
pd.testing.assert_series_equal(expected_out_a, out['a'])
try:
[uuid.UUID(u) for c, u in out['b#c'].items()]
except ValueError:
assert False

def test_transform_not_all_columns_provided(self):
"""Test the ``UniqueCombinations.transform`` method.
Expand All @@ -306,7 +342,7 @@ def test_transform_not_all_columns_provided(self):
with pytest.raises(MissingConstraintColumnError):
instance.transform(pd.DataFrame({'a': ['a', 'b', 'c']}))

def reverse_transform(self):
def test_reverse_transform(self):
"""Test the ``UniqueCombinations.reverse_transform`` method.
It is expected to return the original data separating the concatenated columns.
Expand All @@ -320,13 +356,18 @@ def reverse_transform(self):
must be called as well.
"""
# Setup
table_data = pd.DataFrame({
'a': ['a', 'b', 'c'],
'b': ['d', 'e', 'f'],
'c': ['g', 'h', 'i']
})
transformed_data = pd.DataFrame({
'a': ['a', 'b', 'c'],
'b#c': ['d#g', 'e#h', 'f#i']
})
columns = ['b', 'c']
instance = UniqueCombinations(columns=columns)
instance.fit(transformed_data)
instance.fit(table_data)

# Run
out = instance.reverse_transform(transformed_data)
Expand All @@ -339,6 +380,41 @@ def reverse_transform(self):
})
pd.testing.assert_frame_equal(expected_out, out)

def test_reverse_transform_non_string(self):
"""Test the ``UniqueCombinations.reverse_transform`` method with a non string column.
It is expected to return the original data separating the concatenated columns.
Input:
- Table data transformed (pandas.DataFrame)
Output:
- Original table data, with the concatenated columns separated (pandas.DataFrame)
Side effects:
- Since the ``transform`` method needs ``self._joint_column``, method ``fit``
must be called as well.
"""
# Setup
table_data = pd.DataFrame({
'a': ['a', 'b', 'c'],
'b': [1, 2, 3],
'c': ['g', 'h', 'i']
})
columns = ['b', 'c']
instance = UniqueCombinations(columns=columns)
instance.fit(table_data)

# Run
transformed_data = instance.transform(table_data)
out = instance.reverse_transform(transformed_data)

# Assert
expected_out = pd.DataFrame({
'a': ['a', 'b', 'c'],
'b': [1, 2, 3],
'c': ['g', 'h', 'i']
})
pd.testing.assert_frame_equal(expected_out, out)


class TestGreaterThan():

Expand Down

0 comments on commit 65e0408

Please sign in to comment.