Skip to content

Commit

Permalink
change rtol comparsition from lt to leq (#771)
Browse files Browse the repository at this point in the history
  • Loading branch information
tssbas authored Apr 28, 2022
1 parent 00238f1 commit 9808f8a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _filter_conditions(sampled, conditions, float_rtol):
column_values = sampled[column]
if column_values.dtype.kind == 'f':
distance = value * float_rtol
sampled = sampled[np.abs(column_values - value) < distance]
sampled = sampled[np.abs(column_values - value) <= distance]
sampled[column] = value
else:
sampled = sampled[column_values == value]
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,31 @@ def test__sample_conditions_graceful_reject_sampling(model):
assert len(output) == 2, 'Only expected 2 valid rows.'


@pytest.mark.parametrize('model', MODELS)
def test__sample_conditions_with_value_zero(model):
data = pd.DataFrame({
'column1': list(range(100)),
'column2': list(range(100)),
'column3': list(range(100))
})
data = data.astype(float)

conditions = [
Condition(
{'column1': 0},
num_rows=1,
),
Condition(
{'column1': 0.0},
num_rows=1,
)
]

model.fit(data)
output = model._sample_conditions(conditions, 100, None, True, None)
assert len(output) == 2, 'Expected 2 valid rows.'


def test__sample_rows_previous_rows_appended_correctly():
"""Test the ``BaseTabularModel._sample_rows`` method.
Expand Down

0 comments on commit 9808f8a

Please sign in to comment.