Separate primary key detection functionality #2132

Merged 2 commits on Jul 17, 2024

76 changes: 59 additions & 17 deletions sdv/metadata/single_table.py
@@ -538,20 +538,72 @@ def _determine_sdtype_for_objects(self, data):

return sdtype

def _detect_primary_key(self, data):
"""Detect the table's primary key.

This method will loop through the columns and select the first column that was detected as
an id. If there are none of those, it will pick the first unique pii column. If there is
still no primary key, it will return None. All other id columns after the first will be
reassigned to the 'unknown' sdtype.

Args:
data (pandas.DataFrame):
The data to be analyzed.

Returns:
str:
The column name of the selected primary key.

Raises:
RuntimeError:
If the sdtypes for all columns haven't been detected or set yet.
"""
original_columns = data.columns
stringified_columns = data.columns.astype(str)
data.columns = stringified_columns
Comment on lines +561 to +563

Contributor: Do we need to do this here and in _detect_columns?

Contributor (author): I think so, because the column names that get added to self.columns seem to be the stringified ones, so if we don't do it here we might get a KeyError.

Contributor (author): Actually, now that I think about it, this could be a problem for every other function that takes data and looks at the columns. Maybe we should change the dict to store the original column names instead.

Contributor: But we modified the data, so the column names are the stringified ones, no? Then at the end we convert them back to the original column names.

Contributor: We also do it at the beginning of _detect_columns and then restore them at the end.

Contributor (author): Yeah, but the idea is that eventually people will call this method on their own.
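
A minimal sketch of the stringify-and-restore pattern discussed above (illustrative only, not part of this diff), assuming a DataFrame whose columns include a non-string name:

import pandas as pd

# A DataFrame with an integer column name alongside a string one.
data = pd.DataFrame({0: [1, 2, 3], 'name': ['a', 'b', 'c']})

original_columns = data.columns            # Index([0, 'name'], dtype='object')
data.columns = data.columns.astype(str)    # Index(['0', 'name'], ...) matches the keys in self.columns

# ... column lookups against the stringified names happen here ...

data.columns = original_columns            # restore the caller's original column names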

for column in data.columns:
if not self.columns.get(column, {}).get('sdtype'):
raise RuntimeError(
'All columns must have sdtypes detected or set manually to detect the primary '
'key.'
)

candidates = []
first_pii_field = None
for column, column_meta in self.columns.items():
sdtype = column_meta['sdtype']
column_data = data[column]
has_nan = column_data.isna().any()
valid_potential_primary_key = column_data.is_unique and not has_nan
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
if sdtype == 'id':
candidates.append(column)
if len(candidates) > 1:
self.columns[column]['sdtype'] = 'unknown'
self.columns[column]['pii'] = True

elif sdtype_in_reference and first_pii_field is None and valid_potential_primary_key:
first_pii_field = column

data.columns = original_columns
if candidates:
return candidates[0]
if first_pii_field:
return first_pii_field

return None

def _detect_columns(self, data):
"""Detect the columns' sdtype and the primary key from the data.
"""Detect the columns' sdtypes from the data.

Args:
data (pandas.DataFrame):
The data to be analyzed.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
first_pii_field = None
for field in data:
column_data = data[field]
has_nan = column_data.isna().any()
valid_potential_primary_key = column_data.is_unique and not has_nan
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

@@ -571,30 +623,19 @@ def _detect_columns(self, data):
"The valid data types are: 'object', 'int', 'float', 'datetime', 'bool'."
)

# Set the first ID column we detect to be the primary key
if sdtype == 'id':
if self.primary_key is None and valid_potential_primary_key:
self.primary_key = field
else:
sdtype = 'unknown'

column_dict = {'sdtype': sdtype}
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()

if sdtype_in_reference or sdtype == 'unknown':
column_dict['pii'] = True
if sdtype_in_reference and first_pii_field is None and not has_nan:
first_pii_field = field

if sdtype == 'datetime' and dtype == 'O':
datetime_format = _get_datetime_format(column_data.iloc[:100])
column_dict['datetime_format'] = datetime_format

self.columns[field] = deepcopy(column_dict)

# When no primary key column was set, choose the first pii field
if self.primary_key is None and first_pii_field and valid_potential_primary_key:
self.primary_key = first_pii_field

self.primary_key = self._detect_primary_key(data)
self._updated = True
data.columns = old_columns

@@ -1265,6 +1306,7 @@ def load_from_dict(cls, metadata_dict):
}
setattr(instance, f'{key}', value)

instance._primary_key_candidates = None
return instance

@classmethod
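
For context, a minimal sketch of how the refactored detection could be exercised end to end. The data and column names are illustrative, and whether a column is flagged as 'id' still depends on SDV's sdtype heuristics:

import pandas as pd
from sdv.metadata import SingleTableMetadata

# Illustrative data: one unique, never-null id-like column and one numerical column.
data = pd.DataFrame({
    'guest_id': ['ID_000', 'ID_001', 'ID_002'],
    'amount': [10.0, 12.5, 7.25],
})

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Per this PR, sdtype detection runs first and the primary key is then chosen by
# _detect_primary_key: the first column detected as 'id' wins, otherwise the first
# unique, non-null PII column, otherwise None.
print(metadata.primary_key)  # 'guest_id', assuming the heuristics detect it as an id
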
4 changes: 2 additions & 2 deletions tests/unit/data_processing/test_data_processor.py
@@ -620,8 +620,8 @@ def test__validate_constraint_dict_columns_in_relationships(self):
}

metadata = SingleTableMetadata()
metadata.add_column('country_column', sdtype='country_code')
metadata.add_column('city_column', sdtype='city')
metadata.columns['country_column'] = {'sdtype': 'country_code', 'pii': True}
metadata.columns['city_column'] = {'sdtype': 'city', 'pii': True}
custom_constraint = Mock()

dp = DataProcessor(metadata)
19 changes: 18 additions & 1 deletion tests/unit/metadata/test_single_table.py
@@ -794,7 +794,7 @@ def test_update_columns_sdtype_in_kwargs_error(self):
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.update_columns(['col_1', 'col_2'], sdtype='numerical', pii=True)

def test_update_columns_multiple_erros(self):
def test_update_columns_multiple_errors(self):
"""Test the ``update_columns`` method.

Test that ``update_columns`` with multiple errors.
@@ -1249,6 +1249,23 @@ def test__detect_columns_with_error(self, mock__get_datetime_format):
instance._determine_sdtype_for_objects.assert_called_once()
mock__get_datetime_format.assert_called_once()

def test__detect_primary_key_missing_sdtypes(self):
"""The method should raise an error if not all sdtypes were detected."""
# Setup
data = pd.DataFrame({
'string_id': ['1', '2', '3', '4', '5', '6'],
'num_id': [1, 2, 3, 4, 5, 6],
})
metadata = SingleTableMetadata()
metadata.columns = {'string_id': {'sdtype': 'id'}}

# Run and Assert
message = (
'All columns must have sdtypes detected or set manually to detect the primary key.'
)
with pytest.raises(RuntimeError, match=message):
metadata._detect_primary_key(data)

def test_detect_from_dataframe_raises_error(self):
"""Test the ``detect_from_dataframe`` method.
