Skip to content

Commit

Permalink
SDV - Add a Metadata.detect_from_dataframe function (#2222)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored and R-Palazzo committed Sep 23, 2024
1 parent d66e750 commit b444f47
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
24 changes: 24 additions & 0 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import warnings

import pandas as pd

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
Expand Down Expand Up @@ -51,6 +53,28 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
instance._set_metadata_dict(metadata_dict, single_table_name)
return instance

@classmethod
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
    """Detect the metadata for a single ``pandas.DataFrame``.

    A new ``Metadata`` object is created and the ``sdtypes`` of the given data are
    detected into it under ``table_name``. All data column names are converted to strings.

    Args:
        data (pandas.DataFrame):
            The data to detect the metadata from.
        table_name (str):
            Name to register the detected table under.
            Defaults to ``DEFAULT_SINGLE_TABLE_NAME``.

    Returns:
        Metadata:
            A new metadata object with the sdtypes detected from the data.

    Raises:
        ValueError:
            If ``data`` is not a ``pandas.DataFrame``.
    """
    if isinstance(data, pd.DataFrame):
        # ``Metadata`` is referenced by its module-level name (not ``cls``) on purpose:
        # the unit tests patch ``sdv.metadata.metadata.Metadata`` and inspect the result.
        detected = Metadata()
        detected.detect_table_from_dataframe(table_name, data)
        return detected

    raise ValueError('The provided data must be a pandas DataFrame object.')

def _set_metadata_dict(self, metadata, single_table_name=None):
"""Set a ``metadata`` dictionary to the current instance.
Expand Down
32 changes: 31 additions & 1 deletion tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_detect_from_dataframes_multi_table():
assert metadata.to_dict() == expected_metadata


def test_detect_from_data_frames_single_table():
def test_detect_from_dataframes_single_table():
"""Test the ``detect_from_dataframes`` method works with a single table."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
Expand Down Expand Up @@ -116,6 +116,36 @@ def test_detect_from_data_frames_single_table():
assert metadata.to_dict() == expected_metadata


def test_detect_from_dataframe():
    """Test that ``detect_from_dataframe`` detects valid metadata from one DataFrame."""
    # Setup
    tables, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
    hotels_columns = {
        'hotel_id': {'sdtype': 'id'},
        'city': {'sdtype': 'city', 'pii': True},
        'state': {'sdtype': 'administrative_unit', 'pii': True},
        'rating': {'sdtype': 'numerical'},
        'classification': {'sdtype': 'unknown', 'pii': True},
    }
    expected_metadata = {
        'METADATA_SPEC_VERSION': 'V1',
        'tables': {
            DEFAULT_TABLE_NAME: {
                'columns': hotels_columns,
                'primary_key': 'hotel_id',
            }
        },
        'relationships': [],
    }

    # Run
    metadata = Metadata.detect_from_dataframe(tables['hotels'])
    metadata.validate()

    # Assert
    assert metadata.to_dict() == expected_metadata


def test_detect_from_csvs(tmp_path):
"""Test the ``detect_from_csvs`` method."""
# Setup
Expand Down
31 changes: 29 additions & 2 deletions tests/unit/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from unittest.mock import patch
from unittest.mock import Mock, patch

import pandas as pd
import pytest

from sdv.metadata.metadata import Metadata
from tests.utils import get_multi_table_data, get_multi_table_metadata
from tests.utils import DataFrameMatcher, get_multi_table_data, get_multi_table_metadata


class TestMetadataClass:
Expand Down Expand Up @@ -537,3 +538,29 @@ def test_validate_data_no_relationships(self):
# Run and Assert
metadata.validate_data(data)
assert metadata.METADATA_SPEC_VERSION == 'V1'

@patch('sdv.metadata.metadata.Metadata')
def test_detect_from_dataframe(self, mock_metadata):
    """Test that the method calls the detection method and returns the metadata.

    Expected to call ``detect_table_from_dataframe`` exactly once for the dataframe,
    using the default single table name, and to return the created metadata instance.
    """
    # Setup
    # ``detect_from_dataframe`` instantiates the patched ``Metadata`` class, so the
    # detection call lands on ``mock_metadata.return_value``. Nothing needs to be
    # configured on the class mock itself.
    data = pd.DataFrame()

    # Run
    metadata = Metadata.detect_from_dataframe(data)

    # Assert
    mock_metadata.return_value.detect_table_from_dataframe.assert_called_once_with(
        Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data)
    )
    assert metadata == mock_metadata.return_value

def test_detect_from_dataframe_raises_error_if_not_dataframe(self):
    """Test that a ``ValueError`` is raised when the input is not a DataFrame."""
    # Setup
    not_a_dataframe = Mock()
    error_msg = 'The provided data must be a pandas DataFrame object.'

    # Run and Assert
    with pytest.raises(ValueError, match=error_msg):
        Metadata.detect_from_dataframe(not_a_dataframe)

0 comments on commit b444f47

Please sign in to comment.