Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metadata Feature Branch #2186

Merged
merged 16 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import warnings
from collections import defaultdict
from pathlib import Path
from zipfile import ZipFile
Expand All @@ -15,8 +16,7 @@
from botocore.client import Config
from botocore.exceptions import ClientError

from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.metadata import Metadata

LOGGER = logging.getLogger(__name__)
BUCKET = 'sdv-demo-datasets'
Expand Down Expand Up @@ -104,15 +104,20 @@ def _get_data(modality, output_folder_name, in_memory_directory):
return data


def _get_metadata(modality, output_folder_name, in_memory_directory):
metadata = MultiTableMetadata() if modality == 'multi_table' else SingleTableMetadata()
def _get_metadata(output_folder_name, in_memory_directory, dataset_name):
metadata = Metadata()
if output_folder_name:
metadata_path = os.path.join(output_folder_name, METADATA_FILENAME)
metadata = metadata.load_from_json(metadata_path)
metadata = metadata.load_from_json(metadata_path, dataset_name)

else:
metadict = json.loads(in_memory_directory['metadata_v1.json'])
metadata = metadata.load_from_dict(metadict)
metadata_path = 'metadata_v2.json'
if metadata_path not in in_memory_directory:
warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.')
metadata_path = 'metadata_v1.json'

metadict = json.loads(in_memory_directory[metadata_path])
metadata = metadata.load_from_dict(metadict, dataset_name)

return metadata

Expand All @@ -134,22 +139,20 @@ def download_demo(modality, dataset_name, output_folder_name=None):
tuple (data, metadata):
If ``data`` is single table or sequential, it is a DataFrame.
If ``data`` is multi table, it is a dictionary mapping table name to DataFrame.
If ``metadata`` is single table or sequential, it is a ``SingleTableMetadata`` object.
If ``metadata`` is multi table, it is a ``MultiTableMetadata`` object.
``metadata`` is of class ``Metadata`` which can represent single table or multi table.

Raises:
Error:
* If the ``dataset_name`` exists in the bucket but under a different modality.
* If the ``dataset_name`` doesn't exist in the bucket.
* If there is already a folder named ``output_folder_name``.
* If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
"""
_validate_modalities(modality)
_validate_output_folder(output_folder_name)
bytes_io = _download(modality, dataset_name)
in_memory_directory = _extract_data(bytes_io, output_folder_name)
data = _get_data(modality, output_folder_name, in_memory_directory)
metadata = _get_metadata(modality, output_folder_name, in_memory_directory)
metadata = _get_metadata(output_folder_name, in_memory_directory, dataset_name)

return data, metadata

Expand Down
10 changes: 5 additions & 5 deletions sdv/evaluation/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
Dictionary containing the real table data.
synthetic_column (dict):
Dictionary containing the synthetic table data.
metadata (MultiTableMetadata):
metadata (Metadata):
The metadata object describing the real/synthetic data.
verbose (bool):
Whether or not to print report summary and progress.
Expand All @@ -38,7 +38,7 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True):
Dictionary containing the real table data.
synthetic_column (dict):
Dictionary containing the synthetic table data.
metadata (MultiTableMetadata):
metadata (Metadata):
The metadata object describing the real/synthetic data.
verbose (bool):
Whether or not to print report summary and progress.
Expand All @@ -61,7 +61,7 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name
Dictionary containing the real table data.
synthetic_column (dict):
Dictionary containing the synthetic table data.
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata describing the data.
table_name (str):
The name of the table.
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_column_pair_plot(
Dictionary containing the real table data.
synthetic_column (dict):
Dictionary containing the synthetic table data.
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata describing the data.
table_name (str):
The name of the table.
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_cardinality_plot(
The name of the parent table.
child_foreign_key (string):
The name of the foreign key column in the child table.
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata describing the data.
plot_type (str):
The plot type to use to plot the cardinality. Must be either 'bar' or 'distplot'.
Expand Down
24 changes: 20 additions & 4 deletions sdv/evaluation/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sdmetrics.reports.single_table.quality_report import QualityReport

from sdv.errors import VisualizationUnavailableError
from sdv.metadata.metadata import Metadata


def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
Expand All @@ -16,7 +17,7 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
The table containing the real data.
synthetic_data (pd.DataFrame):
The table containing the synthetic data.
metadata (SingleTableMetadata):
metadata (Metadata):
The metadata object describing the real/synthetic data.
verbose (bool):
Whether or not to print report summary and progress.
Expand All @@ -26,7 +27,13 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
QualityReport:
Single table quality report object.
"""
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

quality_report = QualityReport()
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

quality_report.generate(real_data, synthetic_data, metadata.to_dict(), verbose)
return quality_report

Expand All @@ -39,7 +46,7 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True):
The table containing the real data.
synthetic_data (pd.DataFrame):
The table containing the synthetic data.
metadata (SingleTableMetadata):
metadata (Metadata):
The metadata object describing the real/synthetic data.
verbose (bool):
Whether or not to print report summary and progress.
Expand All @@ -50,6 +57,9 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True):
Single table diagnostic report object.
"""
diagnostic_report = DiagnosticReport()
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

diagnostic_report.generate(real_data, synthetic_data, metadata.to_dict(), verbose)
return diagnostic_report

Expand All @@ -62,7 +72,7 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
The real table data.
synthetic_data (pandas.DataFrame):
The synthetic table data.
metadata (SingleTableMetadata):
metadata (Metadata):
The table metadata.
column_name (str):
The name of the column.
Expand All @@ -76,6 +86,9 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type=
plotly.graph_objects._figure.Figure:
1D marginal distribution plot (i.e. a histogram) of the columns.
"""
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

sdtype = metadata.columns.get(column_name)['sdtype']
if plot_type is None:
if sdtype in ['datetime', 'numerical']:
Expand Down Expand Up @@ -114,7 +127,7 @@ def get_column_pair_plot(
The real table data.
synthetic_column (pandas.Dataframe):
The synthetic table data.
metadata (SingleTableMetadata):
metadata (Metadata):
The table metadata.
column_names (list[string]):
The names of the two columns to plot.
Expand All @@ -131,6 +144,9 @@ def get_column_pair_plot(
plotly.graph_objects._figure.Figure:
2D bivariate distribution plot (i.e. a scatterplot) of the columns.
"""
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()

real_data = real_data.copy()
synthetic_data = synthetic_data.copy()
if plot_type is None:
Expand Down
9 changes: 4 additions & 5 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.metadata.metadata import Metadata


class BaseLocalHandler:
Expand All @@ -25,12 +25,11 @@ def create_metadata(self, data):
Dictionary of table names to dataframes.

Returns:
MultiTableMetadata:
An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
Metadata:
An ``sdv.metadata.Metadata`` object with the detected metadata
properties from the data.
"""
metadata = MultiTableMetadata()
metadata.detect_from_dataframes(data)
metadata = Metadata.detect_from_dataframes(data)
return metadata

def read(self):
Expand Down
17 changes: 14 additions & 3 deletions sdv/lite/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cloudpickle

from sdv.metadata.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer

LOGGER = logging.getLogger(__name__)
Expand All @@ -20,13 +21,19 @@
"functionality, please use the 'GaussianCopulaSynthesizer'."
)

META_DEPRECATION_MSG = (
"The 'SingleTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
)


class SingleTablePreset:
"""Class for all single table synthesizer presets.

Args:
metadata (sdv.metadata.SingleTableMetadata):
``SingleTableMetadata`` instance.
metadata (sdv.metadata.Metadata):
``Metadata`` instance.
* sdv.metadata.SingleTableMetadata can be used but will be deprecated.
name (str):
The preset to use.
locales (list or str):
Expand All @@ -49,6 +56,10 @@ def __init__(self, metadata, name, locales=['en_US']):
raise ValueError(f"'name' must be one of {PRESETS}.")

self.name = name
if isinstance(metadata, Metadata):
metadata = metadata._convert_to_single_table()
else:
warnings.warn(META_DEPRECATION_MSG, FutureWarning)
if name == FAST_ML_PRESET:
self._setup_fast_preset(metadata, self.locales)

Expand All @@ -65,7 +76,7 @@ def add_constraints(self, constraints):
self._synthesizer.add_constraints(constraints)

def get_metadata(self):
"""Return the ``SingleTableMetadata`` for this synthesizer."""
"""Return the ``Metadata`` for this synthesizer."""
warnings.warn(DEPRECATION_MSG, FutureWarning)
return self._synthesizer.get_metadata()

Expand Down
2 changes: 2 additions & 0 deletions sdv/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from sdv.metadata.errors import InvalidMetadataError, MetadataNotFittedError
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.metadata import Metadata

__all__ = (
'InvalidMetadataError',
'Metadata',
'MetadataNotFittedError',
'MultiTableMetadata',
'SingleTableMetadata',
Expand Down
Loading
Loading