Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 10, 2022
1 parent a674f17 commit be49b8e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
6 changes: 3 additions & 3 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import numpy as np
import pandas as pd

from sdv.errors import ConstraintsNotMetError
from sdv.metadata import Table
from sdv.sampling import Condition

LOGGER = logging.getLogger(__name__)
COND_IDX = str(uuid.uuid4())
Expand Down Expand Up @@ -464,7 +464,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
for column in conditions.columns:
if column not in self._metadata.get_fields():
raise ValueError(f'Error: Unexpected column name `{column}`. '
f'Use a column name that was present in the original data.')
f'Use a column name that was present in the original data.')

try:
transformed_conditions = self._metadata.transform(conditions, on_missing_column='drop')
Expand Down Expand Up @@ -524,7 +524,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
return all_sampled_rows

def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
randomize_samples=True):
randomize_samples=True):
"""Sample rows from this table with the given conditions.
Args:
Expand Down
27 changes: 12 additions & 15 deletions tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sdv.metadata.table import Table
from sdv.sampling import Condition
from sdv.tabular.base import BaseTabularModel, COND_IDX
from sdv.tabular.base import COND_IDX, BaseTabularModel
from sdv.tabular.copulagan import CopulaGAN
from sdv.tabular.copulas import GaussianCopula
from sdv.tabular.ctgan import CTGAN, TVAE
Expand Down Expand Up @@ -219,8 +219,8 @@ def test_sample_conditions_with_multiple_conditions(self):
})

gaussian_copula._make_conditions_df.return_value = [
pd.DataFrame([condition_values1]*2),
pd.DataFrame([condition_values2]*3),
pd.DataFrame([condition_values1] * 2),
pd.DataFrame([condition_values2] * 3),
]
gaussian_copula._sample_with_conditions.side_effect = [
sampled1,
Expand All @@ -232,8 +232,8 @@ def test_sample_conditions_with_multiple_conditions(self):

# Asserts
gaussian_copula._sample_with_conditions.assert_has_calls([
call(DataFrameMatcher(pd.DataFrame([condition_values1]*2)), 100, None),
call(DataFrameMatcher(pd.DataFrame([condition_values2]*3)), 100, None),
call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, None),
call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, None),
])
pd.testing.assert_frame_equal(out, expected)

Expand Down Expand Up @@ -287,10 +287,10 @@ def test_conditional_sampling_graceful_reject_sampling(model):
})

conditions = [
Condition({
'column1': "this is not used"
},
num_rows=5)
Condition(
{'column1': "this is not used"},
num_rows=5,
)
]

model._sample_batch = Mock()
Expand Down Expand Up @@ -429,9 +429,6 @@ def test_sample_batches_transform_conditions_correctly():
})

condition_values = [25, 25, 25, 30, 30]
conditions = {
'column1': condition_values,
}
conditions_series = pd.Series([25, 25, 25, 30, 30], name='column1')
model._sample_batch = Mock()
expected_outputs = [
Expand Down Expand Up @@ -585,7 +582,7 @@ def test__make_conditions_df_with_multiple_conditions_same_column(model):
Condition(column_values=column_values1, num_rows=2),
Condition(column_values=column_values2, num_rows=3),
]
expected_conditions = pd.DataFrame([column_values1]*2 + [column_values2]*3)
expected_conditions = pd.DataFrame([column_values1] * 2 + [column_values2] * 3)

# Run
result_conditions_list = model._make_conditions_df(conditions=conditions)
Expand Down Expand Up @@ -618,8 +615,8 @@ def test__make_conditions_df_with_multiple_conditions_different_columns(model):
Condition(column_values=column_values1, num_rows=2),
Condition(column_values=column_values2, num_rows=3),
]
expected_conditions1 = pd.DataFrame([column_values1]*2)
expected_conditions2 = pd.DataFrame([column_values2]*3)
expected_conditions1 = pd.DataFrame([column_values1] * 2)
expected_conditions2 = pd.DataFrame([column_values2] * 3)

# Run
result_conditions_list = model._make_conditions_df(conditions=conditions)
Expand Down

0 comments on commit be49b8e

Please sign in to comment.