From b444f47fc276a12ebfc4ae1271d8d1acb90cf7d3 Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Thu, 12 Sep 2024 10:03:22 -0500 Subject: [PATCH] SDV - Add a Metadata.detect_from_dataframe function (#2222) --- sdv/metadata/metadata.py | 24 ++++++++++++++++ tests/integration/metadata/test_metadata.py | 32 ++++++++++++++++++++- tests/unit/metadata/test_metadata.py | 31 ++++++++++++++++++-- 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 537d8bba2..8473dd3a0 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -2,6 +2,8 @@ import warnings +import pandas as pd + from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata @@ -51,6 +53,28 @@ def load_from_dict(cls, metadata_dict, single_table_name=None): instance._set_metadata_dict(metadata_dict, single_table_name) return instance + @classmethod + def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): + """Detect the metadata for a DataFrame. + + This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``. + All data column names are converted to strings. + + Args: + data (pandas.DataFrame): + The data to detect the metadata from. + + Returns: + Metadata: + A new metadata object with the sdtypes detected from the data. + """ + if not isinstance(data, pd.DataFrame): + raise ValueError('The provided data must be a pandas DataFrame object.') + + metadata = Metadata() + metadata.detect_table_from_dataframe(table_name, data) + return metadata + def _set_metadata_dict(self, metadata, single_table_name=None): """Set a ``metadata`` dictionary to the current instance. 
diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 7efe8421d..d83544896 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -85,7 +85,7 @@ def test_detect_from_dataframes_multi_table(): assert metadata.to_dict() == expected_metadata -def test_detect_from_data_frames_single_table(): +def test_detect_from_dataframes_single_table(): """Test the ``detect_from_dataframes`` method works with a single table.""" # Setup data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') @@ -116,6 +116,36 @@ def test_detect_from_data_frames_single_table(): assert metadata.to_dict() == expected_metadata +def test_detect_from_dataframe(): + """Test that a single table can be detected as a DataFrame.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + metadata = Metadata.detect_from_dataframe(data['hotels']) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + DEFAULT_TABLE_NAME: { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + 'primary_key': 'hotel_id', + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + def test_detect_from_csvs(tmp_path): """Test the ``detect_from_csvs`` method.""" # Setup diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index da2046596..64c55999d 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -1,9 +1,10 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch +import pandas as pd import pytest from sdv.metadata.metadata import Metadata -from tests.utils import get_multi_table_data, 
get_multi_table_metadata +from tests.utils import DataFrameMatcher, get_multi_table_data, get_multi_table_metadata class TestMetadataClass: @@ -537,3 +538,29 @@ def test_validate_data_no_relationships(self): # Run and Assert metadata.validate_data(data) assert metadata.METADATA_SPEC_VERSION == 'V1' + + @patch('sdv.metadata.metadata.Metadata') + def test_detect_from_dataframe(self, mock_metadata): + """Test that the method calls the detection method and returns the metadata. + + Expected to call ``detect_table_from_dataframe`` for the dataframe. + """ + # Setup + mock_metadata.detect_table_from_dataframe = Mock() + data = pd.DataFrame() + + # Run + metadata = Metadata.detect_from_dataframe(data) + + # Assert + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data) + ) + assert metadata == mock_metadata.return_value + + def test_detect_from_dataframe_raises_error_if_not_dataframe(self): + """Test that the method raises an error if data isn't a DataFrame.""" + # Run and assert + expected_message = 'The provided data must be a pandas DataFrame object.' + with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframe(Mock())