Refactor sample API (1/3) (#705)
* Update sample method args

* Add unit tests

* remove conditioning logic for now

* Add error handling

* Make integration tests pass

* code review comments

* Add sample conditions method

* add back unit tests with conditions

* Update logic for handling multiple conditions

* fix lint

* fix integration tests

* Add method to sample remaining columns (3/3) (#708)

* Add method to sample remaining columns

* update integration tests

* add unit tests

* update tutorials and docs

* Enable batch sampling (#709)

* Add batch sampling and progress bar

* Make sure to close progress bar

* Periodically write to file

* add unit tests

* cr comments

* fix test
katxiao authored Feb 23, 2022
1 parent f680e9d commit 0dae889
Showing 19 changed files with 969 additions and 467 deletions.
2 changes: 1 addition & 1 deletion docs/user_guides/evaluation/evaluation_framework.rst
@@ -27,7 +27,7 @@ it using the ``GaussianCopula`` model.
     model = GaussianCopula()
     model.fit(real_data)
-    synthetic_data = model.sample()
+    synthetic_data = model.sample(len(real_data))

 After the previous steps we will have two tables:

37 changes: 22 additions & 15 deletions docs/user_guides/single_table/copulagan.rst
@@ -688,19 +688,23 @@ Conditional Sampling

 As the name implies, conditional sampling allows us to sample from a conditional
 distribution using the ``CopulaGAN`` model, which means we can generate only values that
-satisfy certain conditions. These conditional values can be passed to the ``conditions``
-parameter in the ``sample`` method either as a dataframe or a dictionary.
+satisfy certain conditions. These conditional values can be passed to the ``sample_conditions``
+method as a list of ``sdv.sampling.Condition`` objects or to the ``sample_remaining_columns`` method
+as a dataframe.

-In case a dictionary is passed, the model will generate as many rows as requested,
-all of which will satisfy the specified conditions, such as ``gender = M``.
+When specifying a ``sdv.sampling.Condition`` object, we can pass in the desired conditions
+as a dictionary, as well as specify the number of desired rows for that condition.

 .. ipython:: python
     :okwarning:

-    conditions = {
+    from sdv.sampling import Condition
+    condition = Condition({
         'gender': 'M'
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 It's also possible to condition on multiple columns, such as
@@ -709,14 +713,16 @@ It's also possible to condition on multiple columns, such as
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'gender': 'M',
         'experience_years': 0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

-The ``conditions`` can also be passed as a dataframe. In that case, the model
+In the ``sample_remaining_columns`` method, ``conditions`` is
+passed as a dataframe. In that case, the model
 will generate one sample for each row of the dataframe, sorted in the same
 order. Since the model already knows how many samples to generate, passing
 it as a parameter is unnecessary. For example, if we want to generate three
@@ -731,7 +737,7 @@ following:
     conditions = pd.DataFrame({
         'gender': ['M', 'M', 'M', 'F', 'F', 'F'],
     })
-    model.sample(conditions=conditions)
+    model.sample_remaining_columns(conditions)

 ``CopulaGAN`` also supports conditioning on continuous values, as long as the values
@@ -741,10 +747,11 @@ dataset are within 0 and 1, ``CopulaGAN`` will not be able to set this value to
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'degree_perc': 70.0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 .. note::

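Reviewer's sketch (not part of the commit): the documentation change in this file replaces a single ``conditions`` dictionary with a list of ``Condition`` objects, each carrying its own row count. The snippet below is a minimal pure-Python illustration of that calling pattern. ``ToyCondition``, ``draw_row`` and ``toy_sample_conditions`` are hypothetical stand-ins, and the rejection loop is an assumption made for illustration; the diff does not show how SDV satisfies conditions internally.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyCondition:
    """Hypothetical stand-in: fixed column values plus a requested row count."""
    column_values: dict
    num_rows: int = 1


def draw_row(rng):
    """Hypothetical unconditional sampler; a fitted model would play this role."""
    return {
        'gender': rng.choice(['M', 'F']),
        'experience_years': rng.randint(0, 3),
    }


def toy_sample_conditions(conditions, seed=0, max_tries=10_000):
    """Return rows satisfying each condition, in the order the conditions are given."""
    rng = random.Random(seed)
    rows = []
    for condition in conditions:
        accepted = 0
        for _ in range(max_tries):
            if accepted == condition.num_rows:
                break
            row = draw_row(rng)
            # Keep the row only if every conditioned column matches.
            if all(row[col] == val for col, val in condition.column_values.items()):
                rows.append(row)
                accepted += 1
        if accepted < condition.num_rows:
            raise ValueError(f'Could not satisfy {condition.column_values}')
    return rows


rows = toy_sample_conditions([
    ToyCondition({'gender': 'M'}, num_rows=5),
    ToyCondition({'gender': 'F', 'experience_years': 0}, num_rows=2),
])
```

Passing a list lets one call request different row counts for different conditions, which a single dictionary plus a single ``num_rows`` argument could not express.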
37 changes: 22 additions & 15 deletions docs/user_guides/single_table/ctgan.rst
@@ -499,19 +499,23 @@ Conditional Sampling

 As the name implies, conditional sampling allows us to sample from a conditional
 distribution using the ``CTGAN`` model, which means we can generate only values that
-satisfy certain conditions. These conditional values can be passed to the ``conditions``
-parameter in the ``sample`` method either as a dataframe or a dictionary.
+satisfy certain conditions. These conditional values can be passed to the ``sample_conditions``
+method as a list of ``sdv.sampling.Condition`` objects or to the ``sample_remaining_columns``
+method as a dataframe.

-In case a dictionary is passed, the model will generate as many rows as requested,
-all of which will satisfy the specified conditions, such as ``gender = M``.
+When specifying a ``sdv.sampling.Condition`` object, we can pass in the desired conditions
+as a dictionary, as well as specify the number of desired rows for that condition.

 .. ipython:: python
     :okwarning:

-    conditions = {
+    from sdv.sampling import Condition
+    condition = Condition({
         'gender': 'M'
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 It's also possible to condition on multiple columns, such as
@@ -520,14 +524,16 @@ It's also possible to condition on multiple columns, such as
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'gender': 'M',
         'experience_years': 0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

-The ``conditions`` can also be passed as a dataframe. In that case, the model
+In the ``sample_remaining_columns`` method, ``conditions`` is
+passed as a dataframe. In that case, the model
 will generate one sample for each row of the dataframe, sorted in the same
 order. Since the model already knows how many samples to generate, passing
 it as a parameter is unnecessary. For example, if we want to generate three
@@ -542,7 +548,7 @@ following:
     conditions = pd.DataFrame({
         'gender': ['M', 'M', 'M', 'F', 'F', 'F'],
     })
-    model.sample(conditions=conditions)
+    model.sample_remaining_columns(conditions)

 ``CTGAN`` also supports conditioning on continuous values, as long as the values
@@ -552,10 +558,11 @@ dataset are within 0 and 1, ``CTGAN`` will not be able to set this value to 1000
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'degree_perc': 70.0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 .. note::

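Reviewer's sketch (not part of the commit): the revised text says ``sample_remaining_columns`` produces one sample per row of the conditions dataframe, in the same order, so no row count is passed. The snippet below illustrates only that contract, using plain dictionaries instead of a dataframe. ``toy_sample_remaining_columns`` is a hypothetical stand-in, and it fills the missing column at random, whereas a fitted model would presumably draw it conditioned on the known values.

```python
import random


def toy_sample_remaining_columns(known_rows, seed=0):
    """Fill in the missing column for each known row, preserving row order.

    Hypothetical stand-in: ``experience_years`` is drawn at random here.
    """
    rng = random.Random(seed)
    completed = []
    for known in known_rows:
        row = dict(known)  # copy so the caller's rows are not mutated
        row['experience_years'] = rng.randint(0, 3)
        completed.append(row)
    return completed


# Three 'M' rows followed by three 'F' rows, mirroring the dataframe in the docs.
known_rows = [{'gender': g} for g in ['M', 'M', 'M', 'F', 'F', 'F']]
completed = toy_sample_remaining_columns(known_rows)
```

The output has exactly as many rows as the input, which is why the number of samples is implied by the conditions themselves.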
37 changes: 22 additions & 15 deletions docs/user_guides/single_table/gaussian_copula.rst
@@ -648,19 +648,23 @@ Conditional Sampling

 As the name implies, conditional sampling allows us to sample from a conditional
 distribution using the ``GaussianCopula`` model, which means we can generate only values that
-satisfy certain conditions. These conditional values can be passed to the ``conditions``
-parameter in the ``sample`` method either as a dataframe or a dictionary.
+satisfy certain conditions. These conditional values can be passed to the ``sample_conditions``
+method as a list of ``sdv.sampling.Condition`` objects or to the ``sample_remaining_columns``
+method as a dataframe.

-In case a dictionary is passed, the model will generate as many rows as requested,
-all of which will satisfy the specified conditions, such as ``gender = M``.
+When specifying a ``sdv.sampling.Condition`` object, we can pass in the desired conditions
+as a dictionary, as well as specify the number of desired rows for that condition.

 .. ipython:: python
     :okwarning:

-    conditions = {
+    from sdv.sampling import Condition
+    condition = Condition({
         'gender': 'M'
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 It's also possible to condition on multiple columns, such as
@@ -669,14 +673,16 @@ It's also possible to condition on multiple columns, such as
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'gender': 'M',
         'experience_years': 0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

-The ``conditions`` can also be passed as a dataframe. In that case, the model
+In the ``sample_remaining_columns`` method, ``conditions`` is
+passed as a dataframe. In that case, the model
 will generate one sample for each row of the dataframe, sorted in the same
 order. Since the model already knows how many samples to generate, passing
 it as a parameter is unnecessary. For example, if we want to generate three
@@ -691,7 +697,7 @@ following:
     conditions = pd.DataFrame({
         'gender': ['M', 'M', 'M', 'F', 'F', 'F'],
     })
-    model.sample(conditions=conditions)
+    model.sample_remaining_columns(conditions)

 ``GaussianCopula`` also supports conditioning on continuous values, as long as the values
@@ -701,10 +707,11 @@ dataset are within 0 and 1, ``GaussianCopula`` will not be able to set this valu
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'degree_perc': 70.0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 .. note::

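Reviewer's sketch (not part of the commit): the surrounding documentation says continuous conditions work only when the requested value lies within the range seen in the training data, for example a column bounded by 0 and 1 cannot be conditioned to 1000. A caller could pre-check a value before building a condition along the following lines. ``is_within_observed_range`` and the sample values are hypothetical; the diff does not show whether or how SDV performs such a check itself.

```python
def is_within_observed_range(column_values, requested):
    """Return True if ``requested`` lies between the column's observed min and max."""
    return min(column_values) <= requested <= max(column_values)


# Hypothetical training values for a percentage column.
degree_perc = [52.0, 64.5, 70.0, 77.9, 91.0]

in_range = is_within_observed_range(degree_perc, 70.0)       # inside the observed range
out_of_range = is_within_observed_range(degree_perc, 1000.0)  # far outside it
```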
36 changes: 21 additions & 15 deletions docs/user_guides/single_table/tvae.rst
@@ -484,19 +484,22 @@ Conditional Sampling

 As the name implies, conditional sampling allows us to sample from a conditional
 distribution using the ``TVAE`` model, which means we can generate only values that
-satisfy certain conditions. These conditional values can be passed to the ``conditions``
-parameter in the ``sample`` method either as a dataframe or a dictionary.
+satisfy certain conditions. These conditional values can be passed to the ``sample_conditions``
+method as a list of ``sdv.sampling.Condition`` objects or to the ``sample_remaining_columns``
+method as a dataframe.

-In case a dictionary is passed, the model will generate as many rows as requested,
-all of which will satisfy the specified conditions, such as ``gender = M``.
+When specifying a ``sdv.sampling.Condition`` object, we can pass in the desired conditions as a dictionary, as well as specify the number of desired rows for that condition.

 .. ipython:: python
     :okwarning:

-    conditions = {
+    from sdv.sampling import Condition
+    condition = Condition({
         'gender': 'M'
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 It's also possible to condition on multiple columns, such as
@@ -505,14 +508,16 @@ It's also possible to condition on multiple columns, such as
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'gender': 'M',
         'experience_years': 0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

-The ``conditions`` can also be passed as a dataframe. In that case, the model
+In the ``sample_remaining_columns`` method, ``conditions`` is
+passed as a dataframe. In that case, the model
 will generate one sample for each row of the dataframe, sorted in the same
 order. Since the model already knows how many samples to generate, passing
 it as a parameter is unnecessary. For example, if we want to generate three
@@ -527,7 +532,7 @@ following:
     conditions = pd.DataFrame({
         'gender': ['M', 'M', 'M', 'F', 'F', 'F'],
     })
-    model.sample(conditions=conditions)
+    model.sample_remaining_columns(conditions)

 ``TVAE`` also supports conditioning on continuous values, as long as the values
@@ -537,10 +542,11 @@ dataset are within 0 and 1, ``TVAE`` will not be able to set this value to 1000.
 .. ipython:: python
     :okwarning:

-    conditions = {
+    condition = Condition({
         'degree_perc': 70.0
-    }
-    model.sample(5, conditions=conditions)
+    }, num_rows=5)
+    model.sample_conditions(conditions=[condition])

 .. note::

1 change: 1 addition & 0 deletions sdv/relational/hma.py
@@ -361,6 +361,7 @@ def _sample_rows(self, model, table_name, num_rows=None):
             pandas.DataFrame:
                 Sampled rows, shape (, num_rows)
         """
+        num_rows = num_rows or model._num_rows
         sampled = model.sample(num_rows)

         primary_key_name = self.metadata.get_primary_key(table_name)

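Reviewer's sketch (not part of the commit): the added line uses Python's ``or`` to fall back to a default row count. Because ``or`` tests truthiness, a ``num_rows`` of ``0`` also triggers the fallback, not only ``None``. The two hypothetical helpers below contrast that behaviour with an explicit ``None`` check; ``default_rows`` stands in for ``model._num_rows``. Whether ``0`` is ever a meaningful argument to ``_sample_rows`` is not shown in this diff, so this is a point to confirm, not a reported bug.

```python
def resolve_num_rows(num_rows, default_rows):
    """Mirror ``num_rows or default_rows``: any falsy ``num_rows`` falls back."""
    return num_rows or default_rows


def resolve_num_rows_strict(num_rows, default_rows):
    """Alternative that falls back only when ``num_rows`` is None."""
    return default_rows if num_rows is None else num_rows


fallback_on_none = resolve_num_rows(None, 10)    # uses the default
fallback_on_zero = resolve_num_rows(0, 10)       # also uses the default
strict_on_zero = resolve_num_rows_strict(0, 10)  # keeps the explicit zero
```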