Skip to content

Commit

Permalink
For single-table use cases, make it frictionless to update Metadata (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Sep 26, 2024
1 parent b647d13 commit a100edc
Show file tree
Hide file tree
Showing 25 changed files with 439 additions and 277 deletions.
72 changes: 68 additions & 4 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,18 @@ def _convert_to_single_table(self):

return next(iter(self.tables.values()), SingleTableMetadata())

def set_sequence_index(self, table_name, column_name):
def _handle_table_name(self, table_name):
if table_name is None:
if len(self.tables) == 1:
table_name = next(iter(self.tables))
else:
raise ValueError(
'Metadata contains more than one table, please specify the `table_name`.'
)

return table_name

def set_sequence_index(self, column_name, table_name=None):
"""Set the sequence index of a table.
Args:
Expand All @@ -170,18 +181,21 @@ def set_sequence_index(self, table_name, column_name):
column_name (str):
Name of the sequence index column.
"""
table_name = self._handle_table_name(table_name)
self._validate_table_exists(table_name)
self.tables[table_name].set_sequence_index(column_name)

def set_sequence_key(self, table_name, column_name):
def set_sequence_key(self, column_name, table_name=None):
    """Set the sequence key of a table.

    Args:
        column_name (str, tuple[str]):
            Name (or tuple of names) of the sequence key column(s).
        table_name (str):
            Name of the table to set the sequence key.
            Defaults to None, in which case the name is resolved by
            ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    # Validate before indexing ``self.tables`` so an unknown table is reported
    # by ``_validate_table_exists`` rather than by the dictionary lookup.
    self._validate_table_exists(resolved_name)
    table_metadata = self.tables[resolved_name]
    table_metadata.set_sequence_key(column_name)

Expand All @@ -206,3 +220,53 @@ def validate_table(self, data, table_name=None):
)

return self.validate_data({table_name: data})

def get_column_names(self, table_name=None, **kwargs):
    """Return the column names that match the given metadata keyword arguments.

    Args:
        table_name (str):
            Name of the table to search. Defaults to None, in which case the name
            is resolved by ``_handle_table_name``.
        **kwargs:
            Metadata keyword arguments forwarded unchanged to the parent
            ``get_column_names``.

    Returns:
        The value returned by the parent ``get_column_names`` for the resolved table.
    """
    resolved_name = self._handle_table_name(table_name)
    matching_columns = super().get_column_names(resolved_name, **kwargs)
    return matching_columns

def update_column(self, column_name, table_name=None, **kwargs):
    """Update an existing column for a table in the ``Metadata``.

    Args:
        column_name (str):
            Name of the column to update.
        table_name (str):
            Name of the table the column belongs to. Defaults to None, in which
            case the name is resolved by ``_handle_table_name``.
        **kwargs:
            Keyword arguments forwarded unchanged to the parent ``update_column``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column name.
    super().update_column(resolved_name, column_name, **kwargs)

def update_columns(self, column_names, table_name=None, **kwargs):
    """Update the metadata of multiple columns.

    Args:
        column_names:
            Names of the columns to update.
        table_name (str):
            Name of the table the columns belong to. Defaults to None, in which
            case the name is resolved by ``_handle_table_name``.
        **kwargs:
            Keyword arguments forwarded unchanged to the parent ``update_columns``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column names.
    super().update_columns(resolved_name, column_names, **kwargs)

def update_columns_metadata(self, column_metadata, table_name=None):
    """Update the metadata of multiple columns.

    Args:
        column_metadata:
            Metadata for the columns to update, forwarded unchanged to the parent
            ``update_columns_metadata``.
        table_name (str):
            Name of the table the columns belong to. Defaults to None, in which
            case the name is resolved by ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column metadata.
    super().update_columns_metadata(resolved_name, column_metadata)

def add_column(self, column_name, table_name=None, **kwargs):
    """Add a column to the metadata.

    Args:
        column_name (str):
            Name of the column to add.
        table_name (str):
            Name of the table to add the column to. Defaults to None, in which
            case the name is resolved by ``_handle_table_name``.
        **kwargs:
            Keyword arguments forwarded unchanged to the parent ``add_column``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column name.
    super().add_column(resolved_name, column_name, **kwargs)

def add_column_relationship(self, relationship_type, column_names, table_name=None):
    """Add a column relationship to the metadata.

    Args:
        relationship_type (str):
            Type of the column relationship.
        column_names:
            Names of the columns that take part in the relationship.
        table_name (str):
            Name of the table the columns belong to. Defaults to None, in which
            case the name is resolved by ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first.
    super().add_column_relationship(resolved_name, relationship_type, column_names)

def set_primary_key(self, column_name, table_name=None):
    """Set the primary key of a table.

    Args:
        column_name:
            Name of the primary key column, forwarded unchanged to the parent
            ``set_primary_key``.
        table_name (str):
            Name of the table to set the primary key on. Defaults to None, in
            which case the name is resolved by ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column name.
    super().set_primary_key(resolved_name, column_name)

def remove_primary_key(self, table_name=None):
    """Remove the primary key of a table.

    Args:
        table_name (str):
            Name of the table whose primary key is removed. Defaults to None, in
            which case the name is resolved by ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    super().remove_primary_key(resolved_name)

def add_alternate_keys(self, column_names, table_name=None):
    """Add alternate keys to a table.

    Args:
        column_names:
            Names of the alternate key columns, forwarded unchanged to the parent
            ``add_alternate_keys``.
        table_name (str):
            Name of the table to add the alternate keys to. Defaults to None, in
            which case the name is resolved by ``_handle_table_name``.
    """
    resolved_name = self._handle_table_name(table_name)
    # The parent signature takes the table name first, then the column names.
    super().add_alternate_keys(resolved_name, column_names)
8 changes: 6 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,18 @@ def _detect_relationships(self, data=None):
try:
original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype']
if original_foreign_key_sdtype != 'id':
self.update_column(child_candidate, primary_key, sdtype='id')
self.update_column(
table_name=child_candidate, column_name=primary_key, sdtype='id'
)

self.add_relationship(
parent_candidate, child_candidate, primary_key, primary_key
)
except InvalidMetadataError:
self.update_column(
child_candidate, primary_key, sdtype=original_foreign_key_sdtype
table_name=child_candidate,
column_name=primary_key,
sdtype=original_foreign_key_sdtype,
)
continue

Expand Down
3 changes: 2 additions & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def _set_temp_numpy_seed(self):
def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
synthesizer_parameters = {'locales': self.locales}
synthesizer_parameters.update(self._table_parameters.get(table_name, {}))
metadata_dict = {'tables': {table_name: table_metadata.to_dict()}}
metadata = Metadata.load_from_dict(metadata_dict)
self._table_synthesizers[table_name] = self._synthesizer(
Expand Down
24 changes: 12 additions & 12 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_with_anonymized_columns(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Instance ``DataProcessor``
dp = DataProcessor(metadata._convert_to_single_table())
Expand Down Expand Up @@ -101,11 +101,11 @@ def test_with_anonymized_columns_and_primary_key(self):
data, metadata = download_demo('single_table', 'adult')

# Add anonymized field
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

# Add primary key field
metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('adult', 'id')
metadata.add_column('id', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -159,8 +159,8 @@ def test_with_primary_key_numerical(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -198,13 +198,13 @@ def test_with_alternate_keys(self):
adult_metadata = Metadata.detect_from_dataframes({'adult': data})

# Add primary key field
adult_metadata.add_column('adult', 'id', sdtype='id')
adult_metadata.set_primary_key('adult', 'id')
adult_metadata.add_column('id', 'adult', sdtype='id')
adult_metadata.set_primary_key('id', 'adult')

adult_metadata.add_column('adult', 'secondary_id', sdtype='id')
adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')
adult_metadata.add_column('secondary_id', 'adult', sdtype='id')
adult_metadata.update_column('fnlwgt', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]')

adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt'])
adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'], 'adult')

# Add id
size = len(data)
Expand Down Expand Up @@ -345,7 +345,7 @@ def test_localized_anonymized_columns(self):
"""Test data processor uses the default locale for anonymized columns."""
# Setup
data, metadata = download_demo('single_table', 'adult')
metadata.update_column('adult', 'occupation', sdtype='job', pii=True)
metadata.update_column('occupation', 'adult', sdtype='job', pii=True)

dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA'])

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_evaluation():
data = pd.DataFrame({'col': [1, 2, 3]})
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('table', 'col', sdtype='numerical')
metadata.add_column('col', 'table', sdtype='numerical')
synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm')

# Run and Assert
Expand Down
130 changes: 80 additions & 50 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from copy import deepcopy

import pytest

Expand Down Expand Up @@ -241,56 +242,6 @@ def test_detect_from_csvs(tmp_path):
assert metadata.to_dict() == expected_metadata


def test_detect_table_from_csv(tmp_path):
    """Test the ``detect_table_from_csv`` method."""
    # Setup
    tables, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')

    metadata = Metadata()

    # Dump every demo table to its own CSV file inside the temporary directory.
    for name, table in tables.items():
        table.to_csv(tmp_path / f'{name}.csv', index=False)

    # Run
    metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv')

    # Assert
    # These three columns are forced to ``categorical`` before comparing.
    for column in ('city', 'state', 'classification'):
        metadata.update_column(
            table_name='hotels',
            column_name=column,
            sdtype='categorical',
        )

    hotels_columns = {
        'hotel_id': {'sdtype': 'id'},
        'city': {'sdtype': 'categorical'},
        'state': {'sdtype': 'categorical'},
        'rating': {'sdtype': 'numerical'},
        'classification': {'sdtype': 'categorical'},
    }
    expected_metadata = {
        'tables': {
            'hotels': {
                'columns': hotels_columns,
                'primary_key': 'hotel_id',
            }
        },
        'relationships': [],
        'METADATA_SPEC_VERSION': 'V1',
    }

    assert metadata.to_dict() == expected_metadata

def test_single_table_compatibility(tmp_path):
"""Test if SingleTableMetadata still has compatibility with single table synthesizers."""
# Setup
Expand Down Expand Up @@ -432,3 +383,82 @@ def test_multi_table_compatibility(tmp_path):
assert expected_metadata == synthesizer_2.metadata.to_dict()
for table in metadata_sample:
assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list()


# Shared parametrization for the single-table and multi-table update tests below.
# Each entry is a ``(method, args, kwargs)`` tuple:
#   - ``method``: name of the ``Metadata`` method to call.
#   - ``args``: keys of ``kwargs`` whose values are passed positionally, in this order.
#   - ``kwargs``: all argument values; keys not listed in ``args`` are passed as keywords.
params = [
('update_column', ['column_name'], {'column_name': 'has_rewards', 'sdtype': 'categorical'}),
(
'update_columns',
['column_names'],
{'column_names': ['has_rewards', 'billing_address'], 'sdtype': 'categorical'},
),
(
'update_columns_metadata',
['column_metadata'],
{'column_metadata': {'has_rewards': {'sdtype': 'categorical'}}},
),
('add_column', ['column_name'], {'column_name': 'has_rewards_2', 'sdtype': 'categorical'}),
('set_primary_key', ['column_name'], {'column_name': 'billing_address'}),
('remove_primary_key', [], {}),
(
'add_column_relationship',
['relationship_type', 'column_names'],
{'column_names': ['billing_address'], 'relationship_type': 'address'},
),
('add_alternate_keys', ['column_names'], {'column_names': ['billing_address']}),
('set_sequence_key', ['column_name'], {'column_name': 'billing_address'}),
('get_column_names', [], {'sdtype': 'datetime'}),
]


@pytest.mark.parametrize('method, args, kwargs', params)
def test_any_metadata_update_single_table(method, args, kwargs):
    """Test that every metadata update method works without ``table_name`` for one table."""
    # Setup
    _, metadata = download_demo('single_table', 'fake_hotel_guests')
    metadata.update_column(
        table_name='fake_hotel_guests', column_name='billing_address', sdtype='street_address'
    )
    # Split the argument values into positional ones (in ``args`` order) and keywords.
    positional = [kwargs[name] for name in args]
    keywords = {name: value for name, value in kwargs.items() if name not in args}
    snapshot = deepcopy(metadata).to_dict()

    # Run
    method_under_test = getattr(metadata, method)
    result = method_under_test(*positional, **keywords)

    # Assert
    current = metadata.to_dict()
    if method == 'get_column_names':
        # The only read-only method: check its return value instead of a state change.
        assert result == ['checkin_date', 'checkout_date']
    else:
        assert current != snapshot


@pytest.mark.parametrize('method, args, kwargs', params)
def test_any_metadata_update_multi_table(method, args, kwargs):
    """Test that every metadata update method requires ``table_name`` for multiple tables."""
    # Setup
    _, metadata = download_demo('multi_table', 'fake_hotels')
    metadata.update_column(
        table_name='guests', column_name='billing_address', sdtype='street_address'
    )
    # Split the argument values into positional ones (in ``args`` order) and keywords.
    positional = [kwargs[name] for name in args]
    keywords = {name: value for name, value in kwargs.items() if name not in args}
    snapshot = deepcopy(metadata).to_dict()
    error_pattern = re.escape(
        'Metadata contains more than one table, please specify the `table_name`.'
    )

    # Run
    # Without a table name the call must be rejected.
    with pytest.raises(ValueError, match=error_pattern):
        getattr(metadata, method)(*positional, **keywords)

    # ``table_name`` is the last positional parameter of every method under test.
    positional.append('guests')
    result = getattr(metadata, method)(*positional, **keywords)

    # Assert
    current = metadata.to_dict()
    if method == 'get_column_names':
        # The only read-only method: check its return value instead of a state change.
        assert result == ['checkin_date', 'checkout_date']
    else:
        assert current != snapshot
12 changes: 6 additions & 6 deletions tests/integration/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes):

mock_rdt_transformers.address.RandomLocationGenerator = RandomLocationGeneratorMock
_, instance = download_demo('multi_table', 'fake_hotels')
instance.update_column('hotels', 'city', sdtype='city')
instance.update_column('hotels', 'state', sdtype='state')
instance.update_column('city', 'hotels', sdtype='city')
instance.update_column('state', 'hotels', sdtype='state')

# Run
instance.add_column_relationship('hotels', 'address', ['city', 'state'])
instance.add_column_relationship('address', ['city', 'state'], 'hotels')

# Assert
instance.validate()
Expand Down Expand Up @@ -303,9 +303,9 @@ def test_get_table_metadata():
"""Test the ``get_table_metadata`` method."""
# Setup
metadata = get_multi_table_metadata()
metadata.add_column('nesreca', 'latitude', sdtype='latitude')
metadata.add_column('nesreca', 'longitude', sdtype='longitude')
metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude'])
metadata.add_column('latitude', 'nesreca', sdtype='latitude')
metadata.add_column('longitude', 'nesreca', sdtype='longitude')
metadata.add_column_relationship('gps', ['latitude', 'longitude'], 'nesreca')

# Run
table_metadata = metadata.get_table_metadata('nesreca')
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/metadata/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_visualize_graph_for_multi_table():
data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']})
tables = {'1': data1, '2': data2}
metadata = Metadata.detect_from_dataframes(tables)
metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id')
metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '1', sdtype='id')
metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '2', sdtype='id')
metadata.set_primary_key('\\|=/bla@#$324%^,"&*()><...', '1')
metadata.add_relationship(
'1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...'
)
Expand Down
Loading

0 comments on commit a100edc

Please sign in to comment.