Enable batch sampling #709

Merged · 6 commits · Feb 22, 2022
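In brief, this PR adds batch_size and output_file_path parameters to sample(), sample_conditions(), and sample_remaining_columns(), plus a tqdm progress bar over the requested rows. A hedged usage sketch of the headline API, based on the signatures in this diff (the model choice, training data, and file name are hypothetical, not taken from the PR):

from sdv.tabular import GaussianCopula

model = GaussianCopula()
model.fit(real_data)  # real_data: a pandas.DataFrame of training data (hypothetical)

# Sample 10,000 rows in batches of 1,000, streaming each batch of valid
# rows to disk as it is accepted; a progress bar tracks overall progress.
synthetic = model.sample(
    num_rows=10_000,
    batch_size=1_000,
    output_file_path='synthetic.csv',  # must not already exist (see _validate_file_path below)
)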
139 changes: 120 additions & 19 deletions sdv/tabular/base.py
@@ -1,13 +1,17 @@
"""Base Class for tabular models."""

import functools
import logging
import math
import os
import pickle
import uuid
from collections import defaultdict
from warnings import warn

import numpy as np
import pandas as pd
from tqdm import tqdm

from sdv.errors import ConstraintsNotMetError
from sdv.metadata import Table
@@ -278,7 +282,8 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
return sampled, num_rows

def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
conditions=None, transformed_conditions=None, float_rtol=0.01):
conditions=None, transformed_conditions=None, float_rtol=0.01,
progress_bar=None, output_file_path=None):
"""Sample a batch of rows with the given conditions.

This will enter a reject-sampling loop in which rows will be sampled until
@@ -314,6 +319,12 @@ def _sample_batch(self, num_rows=None, max_tries=100, batch_size_per_try=None,
The dictionary of conditioning values transformed to the model format.
float_rtol (float):
Maximum tolerance when considering a float match.
progress_bar (tqdm.tqdm or None):
The progress bar to update when sampling. If None, a new tqdm progress
bar will be created.
output_file_path (str or None):
The file to periodically write sampled rows to. If None, does not write
rows anywhere.

Returns:
pandas.DataFrame:
@@ -322,21 +333,42 @@
if not batch_size_per_try:
batch_size_per_try = num_rows * 10

sampled, num_valid = self._sample_rows(
num_rows, conditions, transformed_conditions, float_rtol)
if not progress_bar:
progress_bar = tqdm(total=num_rows)

counter = 0
num_valid = 0
prev_num_valid = None
remaining = num_rows
sampled = pd.DataFrame()

counter = 1
while num_valid < num_rows:
if counter >= max_tries:
break

remaining = num_rows - num_valid

LOGGER.info(f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')
prev_num_valid = num_valid
sampled, num_valid = self._sample_rows(
batch_size_per_try, conditions, transformed_conditions, float_rtol, sampled,
)

num_increase = min(num_valid - prev_num_valid, remaining)
[Review comment — Contributor]
In what case would remaining be less than num_valid - prev_num_valid?

[Reply — @katxiao (Contributor, Author), Feb 17, 2022]
Since we are sampling up to batch_size_per_try rows, we could sample more than the number of remaining rows. For example, if we wish to sample 100 rows and we've already sampled 90 valid rows, but our batch size is 50, then we could sample something like 38 rows in the _sample_rows call above.
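To make the clamp concrete, a standalone sketch using the numbers from the reply above (illustration only; not part of the PR's code):

num_rows = 100          # total rows requested
prev_num_valid = 90     # valid rows accumulated before this attempt
remaining = num_rows - prev_num_valid   # 10 rows still needed
num_valid = prev_num_valid + 38         # a batch of 50 yielded 38 more valid rows

# Without the clamp, the progress bar and CSV writer would overshoot by 28.
num_increase = min(num_valid - prev_num_valid, remaining)
assert num_increase == 10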

if num_increase > 0:
if output_file_path:
append_kwargs = {'mode': 'a', 'header': False} if os.path.exists(
output_file_path) else {}
sampled.head(min(len(sampled), num_rows)).tail(num_increase).to_csv(
output_file_path,
index=False,
**append_kwargs,
)

progress_bar.update(num_increase)

remaining = num_rows - num_valid
if remaining > 0:
LOGGER.info(
f'{remaining} valid rows remaining. Resampling {batch_size_per_try} rows')

counter += 1

return sampled.head(min(len(sampled), num_rows))
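The write-out in the loop above appends only the newly valid rows on each pass, emitting the CSV header only when the file is first created. A minimal standalone sketch of that append pattern (function name hypothetical):

import os

import pandas as pd

def append_rows(new_rows, output_file_path):
    # Header only when creating the file; append without a header afterwards.
    append_kwargs = {'mode': 'a', 'header': False} if os.path.exists(output_file_path) else {}
    new_rows.to_csv(output_file_path, index=False, **append_kwargs)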
@@ -367,7 +399,8 @@ def _make_condition_dfs(self, conditions):

def _conditionally_sample_rows(self, dataframe, condition, transformed_condition,
max_tries=None, batch_size_per_try=None, float_rtol=0.01,
graceful_reject_sampling=True):
graceful_reject_sampling=True, progress_bar=None,
output_file_path=None):
num_rows = len(dataframe)
sampled_rows = self._sample_batch(
num_rows,
@@ -376,6 +409,8 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition
condition,
transformed_condition,
float_rtol,
progress_bar,
output_file_path,
)
num_sampled_rows = len(sampled_rows)

@@ -401,8 +436,17 @@ def _conditionally_sample_rows(self, dataframe, condition, transformed_condition

return sampled_rows

def _validate_file_path(self, output_file_path):
output_path = None
if output_file_path:
output_path = os.path.abspath(output_file_path)
if os.path.exists(output_path):
raise AssertionError(f'{output_path} already exists.')

return output_path

@validate_sample_args
def sample(self, num_rows, randomize_samples=True):
def sample(self, num_rows, randomize_samples=True, batch_size=None, output_file_path=None):
"""Sample rows from this table.

Args:
@@ -411,6 +455,11 @@ def sample(self, num_rows, randomize_samples=True):
randomize_samples (bool):
Whether or not to use a fixed seed when sampling. Defaults
to True.
batch_size (int or None):
The batch size to sample. Defaults to `num_rows`, if None.
output_file_path (str or None):
The file to periodically write sampled rows to. If None, does not
write rows anywhere.

Returns:
pandas.DataFrame:
@@ -419,9 +468,30 @@
if num_rows is None:
raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')

return self._sample_batch(num_rows)
if num_rows == 0:
return pd.DataFrame()

output_file_path = self._validate_file_path(output_file_path)

def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
batch_size = min(batch_size, num_rows) if batch_size else num_rows

sampled = []
with tqdm(total=num_rows) as progress_bar:
progress_bar.set_description(
f'Sampling {num_rows} rows of data in batches of size {batch_size}')
for step in range(math.ceil(num_rows / batch_size)):
sampled_rows = self._sample_batch(
batch_size,
batch_size_per_try=batch_size,
progress_bar=progress_bar,
output_file_path=output_file_path,
)
sampled.append(sampled_rows)

return pd.concat(sampled, ignore_index=True)
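One note on the batching above: the loop issues math.ceil(num_rows / batch_size) calls to _sample_batch, each requesting up to batch_size rows. A quick sketch of the batch count (numbers illustrative):

import math

num_rows, batch_size = 1050, 100
assert math.ceil(num_rows / batch_size) == 11  # 11 batches of up to 100 rows each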

def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try,
progress_bar=None, output_file_path=None):
"""Sample rows with conditions.

Args:
@@ -432,6 +502,11 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
batch_size_per_try (int):
The batch size to use per attempt at sampling. Defaults to 10 times
the number of rows.
progress_bar (tqdm.tqdm or None):
The progress bar to update.
output_file_path (str or None):
The file to periodically write sampled rows to. Defaults to
a temporary file, if None.

Returns:
pandas.DataFrame:
@@ -480,6 +555,8 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
None,
max_tries,
batch_size_per_try,
progress_bar=progress_bar,
output_file_path=output_file_path,
)
all_sampled_rows.append(sampled_rows)
else:
@@ -496,6 +573,8 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
transformed_condition,
max_tries,
batch_size_per_try,
progress_bar=progress_bar,
output_file_path=output_file_path,
)
all_sampled_rows.append(sampled_rows)

@@ -508,7 +587,7 @@ def _sample_with_conditions(self, conditions, max_tries, batch_size_per_try):
return all_sampled_rows

def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
randomize_samples=True):
randomize_samples=True, output_file_path=None):
"""Sample rows from this table with the given conditions.

Args:
@@ -524,6 +603,9 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
randomize_samples (bool):
Whether or not to use a fixed seed when sampling. Defaults
to True.
output_file_path (str or None):
The file to periodically write sampled rows to. Defaults to
a temporary file, if None.

Returns:
pandas.DataFrame:
@@ -537,18 +619,28 @@ def sample_conditions(self, conditions, max_tries=100, batch_size_per_try=None,
* any of the conditions' columns are not valid.
* no rows could be generated.
"""
output_file_path = self._validate_file_path(output_file_path)

num_rows = functools.reduce(
lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)
conditions = self._make_condition_dfs(conditions)

sampled = pd.DataFrame()
for condition_dataframe in conditions:
sampled_for_condition = self._sample_with_conditions(
condition_dataframe, max_tries, batch_size_per_try)
sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)
with tqdm(total=num_rows) as progress_bar:
sampled = pd.DataFrame()
for condition_dataframe in conditions:
sampled_for_condition = self._sample_with_conditions(
condition_dataframe,
max_tries,
batch_size_per_try,
progress_bar,
output_file_path,
)
sampled = pd.concat([sampled, sampled_for_condition], ignore_index=True)

return sampled

def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_try=None,
randomize_samples=True):
randomize_samples=True, output_file_path=None):
"""Sample rows from this table.

Args:
@@ -564,6 +656,9 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
randomize_samples (bool):
Whether or not to use a fixed seed when sampling. Defaults
to True.
output_file_path (str or None):
The file to periodically write sampled rows to. Defaults to
a temporary file, if None.

Returns:
pandas.DataFrame:
@@ -577,7 +672,13 @@ def sample_remaining_columns(self, known_columns, max_tries=100, batch_size_per_
* any of the conditions' columns are not valid.
* no rows could be generated.
"""
return self._sample_with_conditions(known_columns, max_tries, batch_size_per_try)
output_file_path = self._validate_file_path(output_file_path)

with tqdm(total=len(known_columns)) as progress_bar:
sampled = self._sample_with_conditions(
known_columns, max_tries, batch_size_per_try, progress_bar, output_file_path)

return sampled
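A hedged usage sketch of sample_remaining_columns with the new parameter, reusing the fitted model from the sketch near the top (column names and file path hypothetical); per the diff, the progress bar totals len(known_columns) rows:

import pandas as pd

# Fix a subset of columns and let the model fill in the rest.
known = pd.DataFrame({'age': [30, 40, 50]})  # hypothetical known column
completed = model.sample_remaining_columns(known, output_file_path='remaining.csv')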

def _get_parameters(self):
raise NonParametricError()
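And the equivalent for sample_conditions, assuming the Condition objects from sdv.sampling that this method already accepts (column name and file path hypothetical):

from sdv.sampling import Condition

condition = Condition(column_values={'gender': 'F'}, num_rows=200)  # hypothetical column

# One progress bar spans all conditions; valid rows stream to disk as
# they are accepted. Raises AssertionError if the file already exists.
sampled = model.sample_conditions(
    conditions=[condition],
    output_file_path='conditional.csv',
)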
1 change: 0 additions & 1 deletion tests/integration/tabular/test_base.py
@@ -265,7 +265,6 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock):
})
sample_calls = model._model.sample.mock_calls
assert len(sample_calls) == 2
model._model.sample.assert_any_call(5, conditions=expected_transformed_conditions)
model._model.sample.assert_any_call(50, conditions=expected_transformed_conditions)
pd.testing.assert_frame_equal(sampled_data, expected_data)
