Skip to content

Commit

Permalink
RuntimeWarning: invalid value encountered in cast. (#1380)
Browse files Browse the repository at this point in the history
* fixing error by masking nans before converting to timedeltas

* fix tests

* trying to fix test again
  • Loading branch information
amontanez24 authored Apr 21, 2023
1 parent 46fc824 commit 23125a0
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 42 deletions.
8 changes: 4 additions & 4 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from sdv.constraints.errors import (
AggregateConstraintsError, ConstraintMetadataError, FunctionError, InvalidFunctionError)
from sdv.constraints.utils import cast_to_datetime64, logit, matches_datetime_format, sigmoid
from sdv.utils import is_datetime_type
from sdv.utils import convert_to_timedelta, is_datetime_type

INEQUALITY_TO_OPERATION = {
'>': np.greater,
Expand Down Expand Up @@ -512,8 +512,8 @@ def _reverse_transform(self, table_data):
The transformation is reversed by computing an exponential of the difference value,
subtracting 1 and converting it to the original dtype. Finally, the obtained column
is added to the ``low_column_name`` column to get back the original
``high_column_name`` value.
is added to the ``low_column_name`` column to get back the original ``high_column_name``
value.
Args:
table_data (pandas.DataFrame):
Expand All @@ -528,7 +528,7 @@ def _reverse_transform(self, table_data):
diff_column = diff_column.round()

if self._is_datetime:
diff_column = pd.to_timedelta(diff_column)
diff_column = convert_to_timedelta(diff_column)

low = table_data[self._low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
Expand Down
59 changes: 21 additions & 38 deletions sdv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,6 @@ def cast_to_iterable(value):
return [value]


def display_tables(tables, max_rows=10, datetime_fmt='%Y-%m-%d %H:%M:%S', row=True):
    """Render several tables together as one HTML object for a Jupyter Notebook.

    Each table is copied, its datetime columns are formatted as strings and its
    first ``max_rows`` rows are rendered to HTML without the index.

    Args:
        tables (dict[str, DataFrame]):
            ``dict`` containing table names and pandas DataFrames.
        max_rows (int):
            Max rows to show per table. Defaults to 10.
        datetime_fmt (str):
            Format with which to display datetime columns.
        row (bool):
            If ``True``, all the table names go in one row and all the tables in the
            row below it. Otherwise each name row is followed by its own table row.
            Defaults to ``True``.

    Returns:
        IPython.core.display.HTML:
            The HTML object wrapping the generated ``<table>`` markup.
    """
    # Import here to avoid making IPython a hard dependency
    from IPython.core.display import HTML

    header_cells = []
    body_cells = []
    for table_name, original in tables.items():
        # Work on a copy so the caller's DataFrame keeps its datetime columns.
        rendered = original.copy()
        for column_name in rendered.columns:
            values = rendered[column_name]
            if values.dtype.kind != 'M':
                continue

            rendered[column_name] = values.dt.strftime(datetime_fmt)

        header_cells.append(f'<td style="text-align:left"><b>{table_name}</b></td>')
        body_cells.append(f'<td>{rendered.head(max_rows).to_html(index=False)}</td>')

    if not row:
        stacked = ''.join(
            f'<tr>{header}</tr><tr>{body}</tr>'
            for header, body in zip(header_cells, body_cells)
        )
        return HTML(f'<table>{stacked}</table>')

    header_html = ''.join(header_cells)
    body_html = ''.join(body_cells)
    return HTML(f'<table><tr>{header_html}</tr><tr>{body_html}</tr></table>')


def get_datetime_format(value):
"""Get the ``strftime`` format for a given ``value``.
Expand Down Expand Up @@ -150,6 +112,27 @@ def validate_datetime_format(value, datetime_format):
return True


def convert_to_timedelta(column):
    """Convert a ``pandas.Series`` to one with dtype ``timedelta``.

    The nans are masked with ``0`` before calling ``pd.to_timedelta`` and are
    reinserted as ``pd.NaT`` afterwards, so the conversion itself never sees a
    missing value.

    The input ``column`` is not modified: the masking is applied to a copy.

    Args:
        column (pandas.Series):
            Column to convert.

    Returns:
        pandas.Series:
            The column converted to timedeltas, with ``pd.NaT`` wherever the
            input held a missing value.
    """
    nan_mask = pd.isna(column)

    # Copy before masking; assigning into ``column`` directly would overwrite
    # the caller's nans with zeros as a side effect.
    masked = column.copy()
    masked[nan_mask] = 0

    converted = pd.to_timedelta(masked)
    converted[nan_mask] = pd.NaT
    return converted


def load_data_from_csv(filepath, pandas_kwargs=None):
"""Load DataFrame from a filepath.
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from unittest.mock import patch

import numpy as np
import pandas as pd

from sdv.utils import convert_to_timedelta
from tests.utils import SeriesMatcher


@patch('sdv.utils.pd.to_timedelta')
def test_convert_to_timedelta(to_timedelta_mock):
    """Test that nans and values are properly converted to timedeltas.

    ``pd.to_timedelta`` is mocked, so this checks that it receives the column with
    its nan replaced by ``0`` and that the nan position of its output becomes ``NaT``.
    """
    # Setup
    input_column = pd.Series([7200, 3600, np.nan])
    mocked_values = [pd.Timedelta(hours=1), pd.Timedelta(hours=2), pd.Timedelta(hours=0)]
    to_timedelta_mock.return_value = pd.Series(mocked_values, dtype='timedelta64[ns]')

    # Run
    result = convert_to_timedelta(input_column)

    # Assert
    masked_input = pd.Series([7200, 3600, 0.0])
    to_timedelta_mock.assert_called_with(SeriesMatcher(masked_input))

    expected = pd.Series(
        [pd.Timedelta(hours=1), pd.Timedelta(hours=2), pd.NaT],
        dtype='timedelta64[ns]'
    )
    pd.testing.assert_series_equal(result, expected)
11 changes: 11 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ def __eq__(self, other):
return True


class SeriesMatcher:
    """Stand-in for an expected ``pandas.Series`` when asserting on mock calls.

    Comparing an instance with ``==`` runs ``pd.testing.assert_series_equal``
    against the stored series, so a mismatch raises an ``AssertionError``
    instead of evaluating to ``False``.
    """

    def __init__(self, series):
        # Series that the argument received by the mock is expected to equal.
        self.series = series

    def __eq__(self, other):
        expected = self.series
        pd.testing.assert_series_equal(expected, other)
        return True


def get_multi_table_metadata():
"""Return a ``MultiTableMetadata`` object to be used with tests."""
dict_metadata = {
Expand Down

0 comments on commit 23125a0

Please sign in to comment.