Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RuntimeWarning: invalid value encountered in cast. #1380

Merged
merged 3 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from sdv.constraints.errors import (
AggregateConstraintsError, ConstraintMetadataError, FunctionError, InvalidFunctionError)
from sdv.constraints.utils import cast_to_datetime64, logit, matches_datetime_format, sigmoid
from sdv.utils import is_datetime_type
from sdv.utils import convert_to_timedelta, is_datetime_type

INEQUALITY_TO_OPERATION = {
'>': np.greater,
Expand Down Expand Up @@ -512,8 +512,8 @@ def _reverse_transform(self, table_data):

The transformation is reversed by computing an exponential of the difference value,
subtracting 1 and converting it to the original dtype. Finally, the obtained column
is added to the ``low_column_name`` column to get back the original
``high_column_name`` value.
is added to the ``low_column_name`` column to get back the original ``high_column_name``
value.

Args:
table_data (pandas.DataFrame):
Expand All @@ -528,7 +528,7 @@ def _reverse_transform(self, table_data):
diff_column = diff_column.round()

if self._is_datetime:
diff_column = pd.to_timedelta(diff_column)
diff_column = convert_to_timedelta(diff_column)

low = table_data[self._low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
Expand Down
59 changes: 21 additions & 38 deletions sdv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,6 @@ def cast_to_iterable(value):
return [value]


def display_tables(tables, max_rows=10, datetime_fmt='%Y-%m-%d %H:%M:%S', row=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this being deleted? Is it just not used anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not used anywhere

"""Display multiple tables side by side on a Jupyter Notebook.

Args:
tables (dict[str, DataFrame]):
``dict`` containing table names and pandas DataFrames.
max_rows (int):
Max rows to show per table. Defaults to 10.
datetime_fmt (str):
Format with which to display datetime columns.
"""
# Import here to avoid making IPython a hard dependency
from IPython.core.display import HTML

names = []
data = []
for name, table in tables.items():
table = table.copy()
for column in table.columns:
column_data = table[column]
if column_data.dtype.kind == 'M':
table[column] = column_data.dt.strftime(datetime_fmt)

names.append(f'<td style="text-align:left"><b>{name}</b></td>')
data.append(f'<td>{table.head(max_rows).to_html(index=False)}</td>')

if row:
html = f"<table><tr>{''.join(names)}</tr><tr>{''.join(data)}</tr></table>"
else:
rows = [
f'<tr>{name}</tr><tr>{table}</tr>'
for name, table in zip(names, data)
]
html = f"<table>{''.join(rows)}</table>"

return HTML(html)


def get_datetime_format(value):
"""Get the ``strftime`` format for a given ``value``.

Expand Down Expand Up @@ -150,6 +112,27 @@ def validate_datetime_format(value, datetime_format):
return True


def convert_to_timedelta(column):
    """Convert a ``pandas.Series`` to one with dtype ``timedelta``.

    Missing values are not passed to ``pd.to_timedelta``: they are masked, temporarily
    replaced with ``0`` for the conversion and then reinserted as ``pd.NaT``.

    The input ``column`` is left untouched; the zero-filling is done on a copy.

    Args:
        column (pandas.Series):
            Column to convert.

    Returns:
        pandas.Series:
            The column converted to timedeltas, with ``pd.NaT`` wherever the input was null.
    """
    nan_mask = pd.isna(column)

    # Work on a copy so the caller's data keeps its original nans.
    filled = column.copy()
    filled[nan_mask] = 0

    converted = pd.to_timedelta(filled)
    converted[nan_mask] = pd.NaT
    return converted


def load_data_from_csv(filepath, pandas_kwargs=None):
"""Load DataFrame from a filepath.

Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from unittest.mock import patch

import numpy as np
import pandas as pd

from sdv.utils import convert_to_timedelta
from tests.utils import SeriesMatcher


@patch('sdv.utils.pd.to_timedelta')
def test_convert_to_timedelta(to_timedelta_mock):
    """Test that nans are zero-filled for the conversion and come back as ``NaT``."""
    # Setup
    one_hour = pd.Timedelta(hours=1)
    two_hours = pd.Timedelta(hours=2)
    # ``pd.to_timedelta`` is mocked, so these values only need to be recognizable in the output.
    to_timedelta_mock.return_value = pd.Series(
        [one_hour, two_hours, pd.Timedelta(hours=0)],
        dtype='timedelta64[ns]',
    )
    column = pd.Series([7200, 3600, np.nan])

    # Run
    converted_column = convert_to_timedelta(column)

    # Assert
    expected_argument = pd.Series([7200, 3600, 0.0])
    to_timedelta_mock.assert_called_with(SeriesMatcher(expected_argument))

    expected_column = pd.Series([one_hour, two_hours, pd.NaT], dtype='timedelta64[ns]')
    pd.testing.assert_series_equal(converted_column, expected_column)
11 changes: 11 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ def __eq__(self, other):
return True


class SeriesMatcher:
    """Stand-in for an expected ``pandas.Series`` when asserting on mock call arguments.

    Comparing an instance with ``==`` runs ``pd.testing.assert_series_equal`` against the
    stored series, so a mismatch raises an ``AssertionError`` instead of evaluating to
    ``False``.
    """

    def __init__(self, series):
        # The series that the compared value is expected to equal.
        self.series = series

    def __eq__(self, other):
        """Return ``True`` if ``other`` equals the stored series, raise otherwise."""
        expected = self.series
        pd.testing.assert_series_equal(expected, other)
        return True


def get_multi_table_metadata():
"""Return a ``MultiTableMetadata`` object to be used with tests."""
dict_metadata = {
Expand Down