-
Notifications
You must be signed in to change notification settings - Fork 321
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
739cc86
commit b8cdb1d
Showing
4 changed files
with
838 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Metadata.""" | ||
|
||
from pathlib import Path | ||
|
||
from sdv.metadata.multi_table import MultiTableMetadata | ||
from sdv.metadata.single_table import SingleTableMetadata | ||
from sdv.metadata.utils import read_json | ||
|
||
|
||
class Metadata(MultiTableMetadata): | ||
"""Metadata class that handles all metadata.""" | ||
|
||
METADATA_SPEC_VERSION = 'V1' | ||
|
||
@classmethod | ||
def load_from_json(cls, filepath): | ||
"""Create a ``Metadata`` instance from a ``json`` file. | ||
Args: | ||
filepath (str): | ||
String that represents the ``path`` to the ``json`` file. | ||
Raises: | ||
- An ``Error`` if the path does not exist. | ||
- An ``Error`` if the ``json`` file does not contain the ``METADATA_SPEC_VERSION``. | ||
Returns: | ||
A ``Metadata`` instance. | ||
""" | ||
filename = Path(filepath).stem | ||
metadata = read_json(filepath) | ||
return cls.load_from_dict(metadata, filename) | ||
|
||
@classmethod | ||
def load_from_dict(cls, metadata_dict, single_table_name=None): | ||
"""Create a ``Metadata`` instance from a python ``dict``. | ||
Args: | ||
metadata_dict (dict): | ||
Python dictionary representing a ``MultiTableMetadata`` | ||
or ``SingleTableMetadata`` object. | ||
single_table_name (string): | ||
If the python dictionary represents a ``SingleTableMetadata`` then | ||
this arg is used for the name of the table. | ||
Returns: | ||
Instance of ``Metadata``. | ||
""" | ||
instance = cls() | ||
instance._set_metadata_dict(metadata_dict, single_table_name) | ||
return instance | ||
|
||
def _set_metadata_dict(self, metadata, single_table_name=None): | ||
"""Set a ``metadata`` dictionary to the current instance. | ||
Checks to see if the metadata is in the ``SingleTableMetadata`` or | ||
``MultiTableMetadata`` format and converts it to a standard | ||
``MultiTableMetadata`` format if necessary. | ||
Args: | ||
metadata (dict): | ||
Python dictionary representing a ``MultiTableMetadata`` or | ||
``SingleTableMetadata`` object. | ||
""" | ||
is_multi_table = 'tables' in metadata | ||
|
||
if is_multi_table: | ||
super()._set_metadata_dict(metadata) | ||
else: | ||
if single_table_name is None: | ||
single_table_name = 'default_table_name' | ||
|
||
self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
from sdv.datasets.demo import download_demo | ||
from sdv.metadata.metadata import Metadata | ||
|
||
|
||
def test_metadata(): | ||
"""Test ``MultiTableMetadata``.""" | ||
# Create an instance | ||
instance = Metadata() | ||
|
||
# To dict | ||
result = instance.to_dict() | ||
|
||
# Assert | ||
assert result == {'tables': {}, 'relationships': [], 'METADATA_SPEC_VERSION': 'V1'} | ||
assert instance.tables == {} | ||
assert instance.relationships == [] | ||
|
||
|
||
def test_detect_from_dataframes_multi_table(): | ||
"""Test the ``detect_from_dataframes`` method works with multi-table.""" | ||
# Setup | ||
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') | ||
|
||
metadata = Metadata() | ||
|
||
# Run | ||
metadata.detect_from_dataframes(real_data) | ||
|
||
# Assert | ||
metadata.update_column( | ||
table_name='hotels', | ||
column_name='classification', | ||
sdtype='categorical', | ||
) | ||
|
||
expected_metadata = { | ||
'tables': { | ||
'hotels': { | ||
'columns': { | ||
'hotel_id': {'sdtype': 'id'}, | ||
'city': {'sdtype': 'city', 'pii': True}, | ||
'state': {'sdtype': 'administrative_unit', 'pii': True}, | ||
'rating': {'sdtype': 'numerical'}, | ||
'classification': {'sdtype': 'categorical'}, | ||
}, | ||
'primary_key': 'hotel_id', | ||
}, | ||
'guests': { | ||
'columns': { | ||
'guest_email': {'sdtype': 'email', 'pii': True}, | ||
'hotel_id': {'sdtype': 'id'}, | ||
'has_rewards': {'sdtype': 'categorical'}, | ||
'room_type': {'sdtype': 'categorical'}, | ||
'amenities_fee': {'sdtype': 'numerical'}, | ||
'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, | ||
'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, | ||
'room_rate': {'sdtype': 'numerical'}, | ||
'billing_address': {'sdtype': 'unknown', 'pii': True}, | ||
'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, | ||
}, | ||
'primary_key': 'guest_email', | ||
}, | ||
}, | ||
'relationships': [ | ||
{ | ||
'parent_table_name': 'hotels', | ||
'child_table_name': 'guests', | ||
'parent_primary_key': 'hotel_id', | ||
'child_foreign_key': 'hotel_id', | ||
} | ||
], | ||
'METADATA_SPEC_VERSION': 'V1', | ||
} | ||
assert metadata.to_dict() == expected_metadata | ||
|
||
|
||
def test_detect_from_data_frames_single_table(): | ||
"""Test the ``detect_from_dataframes`` method works with a single table.""" | ||
# Setup | ||
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') | ||
|
||
metadata = Metadata() | ||
metadata.detect_from_dataframes({'table_1': data['hotels']}) | ||
|
||
# Run | ||
metadata.validate() | ||
|
||
# Assert | ||
expected_metadata = { | ||
'METADATA_SPEC_VERSION': 'V1', | ||
'tables': { | ||
'table_1': { | ||
'columns': { | ||
'hotel_id': {'sdtype': 'id'}, | ||
'city': {'sdtype': 'city', 'pii': True}, | ||
'state': {'sdtype': 'administrative_unit', 'pii': True}, | ||
'rating': {'sdtype': 'numerical'}, | ||
'classification': {'sdtype': 'unknown', 'pii': True}, | ||
}, | ||
'primary_key': 'hotel_id', | ||
} | ||
}, | ||
'relationships': [], | ||
} | ||
assert metadata.to_dict() == expected_metadata | ||
|
||
|
||
def test_detect_from_csvs(tmp_path): | ||
"""Test the ``detect_from_csvs`` method.""" | ||
# Setup | ||
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') | ||
|
||
metadata = Metadata() | ||
|
||
for table_name, dataframe in real_data.items(): | ||
csv_path = tmp_path / f'{table_name}.csv' | ||
dataframe.to_csv(csv_path, index=False) | ||
|
||
# Run | ||
metadata.detect_from_csvs(folder_name=tmp_path) | ||
|
||
# Assert | ||
metadata.update_column( | ||
table_name='hotels', | ||
column_name='classification', | ||
sdtype='categorical', | ||
) | ||
|
||
expected_metadata = { | ||
'tables': { | ||
'hotels': { | ||
'columns': { | ||
'hotel_id': {'sdtype': 'id'}, | ||
'city': {'sdtype': 'city', 'pii': True}, | ||
'state': {'sdtype': 'administrative_unit', 'pii': True}, | ||
'rating': {'sdtype': 'numerical'}, | ||
'classification': {'sdtype': 'categorical'}, | ||
}, | ||
'primary_key': 'hotel_id', | ||
}, | ||
'guests': { | ||
'columns': { | ||
'guest_email': {'sdtype': 'email', 'pii': True}, | ||
'hotel_id': {'sdtype': 'id'}, | ||
'has_rewards': {'sdtype': 'categorical'}, | ||
'room_type': {'sdtype': 'categorical'}, | ||
'amenities_fee': {'sdtype': 'numerical'}, | ||
'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, | ||
'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, | ||
'room_rate': {'sdtype': 'numerical'}, | ||
'billing_address': {'sdtype': 'unknown', 'pii': True}, | ||
'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, | ||
}, | ||
'primary_key': 'guest_email', | ||
}, | ||
}, | ||
'relationships': [ | ||
{ | ||
'parent_table_name': 'hotels', | ||
'child_table_name': 'guests', | ||
'parent_primary_key': 'hotel_id', | ||
'child_foreign_key': 'hotel_id', | ||
} | ||
], | ||
'METADATA_SPEC_VERSION': 'V1', | ||
} | ||
|
||
assert metadata.to_dict() == expected_metadata | ||
|
||
|
||
def test_detect_table_from_csv(tmp_path): | ||
"""Test the ``detect_table_from_csv`` method.""" | ||
# Setup | ||
real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') | ||
|
||
metadata = Metadata() | ||
|
||
for table_name, dataframe in real_data.items(): | ||
csv_path = tmp_path / f'{table_name}.csv' | ||
dataframe.to_csv(csv_path, index=False) | ||
|
||
# Run | ||
metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv') | ||
|
||
# Assert | ||
metadata.update_column( | ||
table_name='hotels', | ||
column_name='city', | ||
sdtype='categorical', | ||
) | ||
metadata.update_column( | ||
table_name='hotels', | ||
column_name='state', | ||
sdtype='categorical', | ||
) | ||
metadata.update_column( | ||
table_name='hotels', | ||
column_name='classification', | ||
sdtype='categorical', | ||
) | ||
expected_metadata = { | ||
'tables': { | ||
'hotels': { | ||
'columns': { | ||
'hotel_id': {'sdtype': 'id'}, | ||
'city': {'sdtype': 'categorical'}, | ||
'state': {'sdtype': 'categorical'}, | ||
'rating': {'sdtype': 'numerical'}, | ||
'classification': {'sdtype': 'categorical'}, | ||
}, | ||
'primary_key': 'hotel_id', | ||
} | ||
}, | ||
'relationships': [], | ||
'METADATA_SPEC_VERSION': 'V1', | ||
} | ||
|
||
assert metadata.to_dict() == expected_metadata |
Oops, something went wrong.