Skip to content

Commit

Permalink
fix integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 22, 2022
1 parent 3a9fdf9 commit 17d1e8e
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 471 deletions.
20 changes: 7 additions & 13 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
will generate as many rows as there were in the
data passed to the ``fit`` method.
max_tries (int):
Number of times to try sampling rows.
Number of times to try sampling discarded rows.
Defaults to 100.
batch_size_per_try (int):
The batch size to use per attempt at sampling. Defaults to 10 times
Expand Down Expand Up @@ -336,21 +336,15 @@ def _make_condition_dfs(self, conditions):
for condition in conditions:
column_values = condition.get_column_values()
condition_dataframes[tuple(column_values.keys())].append(
pd.DataFrame(
condition.get_column_values(),
index=range(condition.get_num_rows())
)
)
pd.DataFrame(column_values, index=range(condition.get_num_rows())))

return [
pd.concat(
condition_list,
ignore_index=True,
) for condition_list in condition_dataframes.values()
pd.concat(condition_list, ignore_index=True)
for condition_list in condition_dataframes.values()
]

def _conditionally_sample_rows(self, dataframe, condition, transformed_condition,
max_tries=None, batch_size_per_try=None, float_rtol=None,
max_tries=None, batch_size_per_try=None, float_rtol=0.01,
graceful_reject_sampling=True):
num_rows = len(dataframe)
sampled_rows = self._sample_batch(
Expand All @@ -374,8 +368,8 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
raise ValueError(error)

else:
warn(f'Warning: Only able to sample {len(sampled_rows)} rows for the given '
f'conditions. To sample more rows, try increasing `max_retrues '
warn(f'Only able to sample {len(sampled_rows)} rows for the given '
f'conditions. To sample more rows, try increasing `max_tries '
f'(currently: {max_tries}) or increasing `batch_size_per_try` '
f'(currently: {batch_size_per_try}. Note that increasing these values '
f'will also increase the sampling time.')
Expand Down
Loading

0 comments on commit 17d1e8e

Please sign in to comment.