Skip to content

Commit

Permalink
Add warnings that were suggested from metadata bughunt (#2203)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored and amontanez24 committed Sep 27, 2024
1 parent 7bc31f7 commit 61ba90c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
4 changes: 4 additions & 0 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def _set_metadata_dict(self, metadata, single_table_name=None):
else:
if single_table_name is None:
single_table_name = self.DEFAULT_SINGLE_TABLE_NAME
warnings.warn(
'No table name was provided to metadata containing only one table. '
f'Assigning name: {single_table_name}'
)
self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)

def _get_single_table_name(self):
Expand Down
11 changes: 8 additions & 3 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ def update_columns(self, table_name, column_names, **kwargs):
**kwargs:
Any key word arguments that describe metadata for the columns.
"""
if not isinstance(column_names, list):
raise InvalidMetadataError('Please pass in a list to column_names arg.')
self._validate_table_exists(table_name)
table = self.tables.get(table_name)
table.update_columns(column_names, **kwargs)
Expand Down Expand Up @@ -832,8 +834,8 @@ def validate_data(self, data):
* all foreign keys belong to a primay key
Args:
data (pd.DataFrame):
The data to validate.
data (dict):
A dictionary of table names to pd.DataFrames.
Raises:
InvalidDataError:
Expand All @@ -843,6 +845,9 @@ def validate_data(self, data):
A warning is being raised if ``datetime_format`` is missing from a column represented
as ``object`` in the dataframe and its sdtype is ``datetime``.
"""
if not isinstance(data, dict):
raise InvalidMetadataError('Please pass in a dictionary mapping tables to dataframes.')

errors = []
errors += self._validate_missing_tables(data)
errors += self._validate_all_tables(data)
Expand Down Expand Up @@ -880,7 +885,7 @@ def get_column_names(self, table_name, **kwargs):
Args:
table_name (str):
The name of the table to get column names for.s
The name of the table to get column names for.
**kwargs:
Metadata keywords to filter on, for example sdtype='id' or pii=True.
Expand Down
7 changes: 4 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def _set_temp_numpy_seed(self):
def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = {'locales': self.locales}
synthesizer_parameters.update(self._table_parameters.get(table_name, {}))
synthesizer_parameters = self._table_parameters.get(table_name, {})
metadata_dict = {'tables': {table_name: table_metadata.to_dict()}}
metadata = Metadata.load_from_dict(metadata_dict)
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, **synthesizer_parameters
metadata=metadata, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name

Expand Down
39 changes: 39 additions & 0 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3160,3 +3160,42 @@ def test_anonymize(self, mock_load):
'parent_primary_key': 'col1',
'child_foreign_key': 'col2',
}

def test_update_columns_no_list_error(self):
"""Test that ``update_columns`` only takes in list and that an error is thrown."""
# Setup
metadata = MultiTableMetadata()
metadata.add_table('table')
metadata.add_column('table', 'col1', sdtype='numerical')

error_msg = re.escape('Please pass in a list to column_names arg.')
# Run and Assert
with pytest.raises(InvalidMetadataError, match=error_msg):
metadata.update_columns('table', 'col1', sdtype='categorical')

def test_validate_data_without_dict(self):
"""Test that ``validate_data`` only takes in dict and that an error is thrown otherwise."""
# Setup
metadata = MultiTableMetadata.load_from_dict({
'tables': {
'table_1': {
'columns': {
'col_1': {'sdtype': 'numerical'},
'col_2': {'sdtype': 'categorical'},
'latitude': {'sdtype': 'latitude'},
'longitude': {'sdtype': 'longitude'},
}
}
}
})
data = pd.DataFrame({
'col_1': [1, 2, 3],
'col_2': ['a', 'b', 'c'],
'latitude': [1, 2, 3],
'longitude': [1, 2, 3],
})
error_msg = re.escape('Please pass in a dictionary mapping tables to dataframes.')

# Run and Assert
with pytest.raises(InvalidMetadataError, match=error_msg):
metadata.validate_data(data)

0 comments on commit 61ba90c

Please sign in to comment.