Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For single-table use cases, make it frictionless to update Metadata #2228

Merged
merged 9 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,18 @@ def _convert_to_single_table(self):

return next(iter(self.tables.values()), SingleTableMetadata())

def set_sequence_index(self, table_name, column_name):
def _handle_table_name(self, table_name):
if table_name is None:
if len(self.tables) == 1:
table_name = next(iter(self.tables))
else:
raise ValueError(
'Metadata contains more than one table, please specify the `table_name`.'
)

return table_name

def set_sequence_index(self, column_name, table_name=None):
"""Set the sequence index of a table.

Args:
Expand All @@ -170,18 +181,21 @@ def set_sequence_index(self, table_name, column_name):
column_name (str):
Name of the sequence index column.
"""
table_name = self._handle_table_name(table_name)
self._validate_table_exists(table_name)
self.tables[table_name].set_sequence_index(column_name)

def set_sequence_key(self, table_name, column_name):
def set_sequence_key(self, column_name, table_name=None):
"""Set the sequence key of a table.

Args:
table_name (str):
Name of the table to set the sequence key.
column_name (str, tulple[str]):
Name (or tuple of names) of the sequence key column(s).
table_name (str):
Name of the table to set the sequence key.
Defaults to None.
"""
table_name = self._handle_table_name(table_name)
self._validate_table_exists(table_name)
self.tables[table_name].set_sequence_key(column_name)

Expand All @@ -206,3 +220,53 @@ def validate_table(self, data, table_name=None):
)

return self.validate_data({table_name: data})

def get_column_names(self, table_name=None, **kwargs):
"""Return a list of column names that match the given metadata keyword arguments."""
table_name = self._handle_table_name(table_name)
return super().get_column_names(table_name, **kwargs)

def update_column(self, column_name, table_name=None, **kwargs):
"""Update an existing column for a table in the ``Metadata``."""
table_name = self._handle_table_name(table_name)
super().update_column(table_name, column_name, **kwargs)

def update_columns(self, column_names, table_name=None, **kwargs):
"""Update the metadata of multiple columns."""
table_name = self._handle_table_name(table_name)
super().update_columns(table_name, column_names, **kwargs)

def update_columns_metadata(self, column_metadata, table_name=None):
"""Update the metadata of multiple columns."""
table_name = self._handle_table_name(table_name)
super().update_columns_metadata(table_name, column_metadata)

def add_column(self, column_name, table_name=None, **kwargs):
"""Add a column to the metadata."""
table_name = self._handle_table_name(table_name)
super().add_column(table_name, column_name, **kwargs)

def add_column_relationship(
self,
relationship_type,
column_names,
table_name=None,
):
"""Add a column relationship to the metadata."""
table_name = self._handle_table_name(table_name)
super().add_column_relationship(table_name, relationship_type, column_names)

def set_primary_key(self, column_name, table_name=None):
"""Set the primary key of a table."""
table_name = self._handle_table_name(table_name)
super().set_primary_key(table_name, column_name)

def remove_primary_key(self, table_name=None):
"""Remove the primary key of a table."""
table_name = self._handle_table_name(table_name)
super().remove_primary_key(table_name)

def add_alternate_keys(self, column_names, table_name=None):
"""Add alternate keys to a table."""
table_name = self._handle_table_name(table_name)
super().add_alternate_keys(table_name, column_names)
8 changes: 6 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,18 @@ def _detect_relationships(self, data=None):
try:
original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype']
if original_foreign_key_sdtype != 'id':
self.update_column(child_candidate, primary_key, sdtype='id')
self.update_column(
table_name=child_candidate, column_name=primary_key, sdtype='id'
)

self.add_relationship(
parent_candidate, child_candidate, primary_key, primary_key
)
except InvalidMetadataError:
self.update_column(
child_candidate, primary_key, sdtype=original_foreign_key_sdtype
table_name=child_candidate,
column_name=primary_key,
sdtype=original_foreign_key_sdtype,
)
continue

Expand Down
24 changes: 12 additions & 12 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_with_anonymized_columns(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Instance ``DataProcessor``
dp = DataProcessor(metadata._convert_to_single_table())
Expand Down Expand Up @@ -101,11 +101,11 @@ def test_with_anonymized_columns_and_primary_key(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Add primary key field
metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('adult', 'id')
metadata.add_column('id', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -159,8 +159,8 @@ def test_with_primary_key_numerical(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -198,13 +198,13 @@ def test_with_alternate_keys(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

adult_metadata.add_column('adult', 'secondary_id', sdtype='id')
adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')
adult_metadata.add_column('secondary_id', 'adult', sdtype='id')
adult_metadata.update_column('fnlwgt', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')

adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt'])
adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'], 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_localized_anonymized_columns(self):
"""Test data processor uses the default locale for anonymized columns."""
# Setup
data, metadata = download_demo('single_table', 'adult')
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA'])

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_evaluation():
data = pd.DataFrame({'col': [1, 2, 3]})
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('table', 'col', sdtype='numerical')
metadata.add_column('col', 'table', sdtype='numerical')
synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm')

# Run and Assert
Expand Down
80 changes: 80 additions & 0 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from copy import deepcopy

import pytest

Expand Down Expand Up @@ -382,3 +383,82 @@ def test_multi_table_compatibility(tmp_path):
assert expected_metadata == synthesizer_2.metadata.to_dict()
for table in metadata_sample:
assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list()


params = [
('update_column', ['column_name'], {'column_name': 'has_rewards', 'sdtype': 'categorical'}),
(
'update_columns',
['column_names'],
{'column_names': ['has_rewards', 'billing_address'], 'sdtype': 'categorical'},
),
(
'update_columns_metadata',
['column_metadata'],
{'column_metadata': {'has_rewards': {'sdtype': 'categorical'}}},
),
('add_column', ['column_name'], {'column_name': 'has_rewards_2', 'sdtype': 'categorical'}),
('set_primary_key', ['column_name'], {'column_name': 'billing_address'}),
('remove_primary_key', [], {}),
(
'add_column_relationship',
['relationship_type', 'column_names'],
{'column_names': ['billing_address'], 'relationship_type': 'address'},
),
('add_alternate_keys', ['column_names'], {'column_names': ['billing_address']}),
('set_sequence_key', ['column_name'], {'column_name': 'billing_address'}),
('get_column_names', [], {'sdtype': 'datetime'}),
]


@pytest.mark.parametrize('method, args, kwargs', params)
def test_any_metadata_update_single_table(method, args, kwargs):
"""Test that any method that updates metadata works for single-table case."""
# Setup
_, metadata = download_demo('single_table', 'fake_hotel_guests')
metadata.update_column(
table_name='fake_hotel_guests', column_name='billing_address', sdtype='street_address'
)
parameter = [kwargs[arg] for arg in args]
remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args}
metadata_before = deepcopy(metadata).to_dict()

# Run
result = getattr(metadata, method)(*parameter, **remaining_kwargs)

# Assert
expected_dict = metadata.to_dict()
if method != 'get_column_names':
assert expected_dict != metadata_before
else:
assert result == ['checkin_date', 'checkout_date']


@pytest.mark.parametrize('method, args, kwargs', params)
def test_any_metadata_update_multi_table(method, args, kwargs):
"""Test that any method that updates metadata works for multi-table case."""
# Setup
_, metadata = download_demo('multi_table', 'fake_hotels')
metadata.update_column(
table_name='guests', column_name='billing_address', sdtype='street_address'
)
parameter = [kwargs[arg] for arg in args]
remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args}
metadata_before = deepcopy(metadata).to_dict()
expected_error = re.escape(
'Metadata contains more than one table, please specify the `table_name`.'
)

# Run
with pytest.raises(ValueError, match=expected_error):
getattr(metadata, method)(*parameter, **remaining_kwargs)

parameter.append('guests')
result = getattr(metadata, method)(*parameter, **remaining_kwargs)

# Assert
expected_dict = metadata.to_dict()
if method != 'get_column_names':
assert expected_dict != metadata_before
else:
assert result == ['checkin_date', 'checkout_date']
12 changes: 6 additions & 6 deletions tests/integration/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes):

mock_rdt_transformers.address.RandomLocationGenerator = RandomLocationGeneratorMock
_, instance = download_demo('multi_table', 'fake_hotels')
instance.update_column('hotels', 'city', sdtype='city')
instance.update_column('hotels', 'state', sdtype='state')
instance.update_column('city', 'hotels', sdtype='city')
instance.update_column('state', 'hotels', sdtype='state')

# Run
instance.add_column_relationship('hotels', 'address', ['city', 'state'])
instance.add_column_relationship('address', ['city', 'state'], 'hotels')

# Assert
instance.validate()
Expand Down Expand Up @@ -303,9 +303,9 @@ def test_get_table_metadata():
"""Test the ``get_table_metadata`` method."""
# Setup
metadata = get_multi_table_metadata()
metadata.add_column('nesreca', 'latitude', sdtype='latitude')
metadata.add_column('nesreca', 'longitude', sdtype='longitude')
metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude'])
metadata.add_column('latitude', 'nesreca', sdtype='latitude')
metadata.add_column('longitude', 'nesreca', sdtype='longitude')
metadata.add_column_relationship('gps', ['latitude', 'longitude'], 'nesreca')

# Run
table_metadata = metadata.get_table_metadata('nesreca')
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_visualize_graph_for_multi_table():
data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
tables = {'1': data1, '2': data2}
metadata = Metadata.detect_from_dataframes(tables)
metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '1', sdtype='id')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '2', sdtype='id')
metadata.set_primary_key('\\|=/bla@#$324%^,"&*()><...', '1')
metadata.add_relationship(
'1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...'
)
Expand Down
Loading
Loading