Skip to content

Commit

Permalink
Fix batch summary data (#84)
Browse files Browse the repository at this point in the history
* Create a summary data class for synthetics generation.

Co-authored-by: Temesghen Kahsai <teme@gretel.ai>
  • Loading branch information
lememta and Temesghen Kahsai authored Jan 26, 2021
1 parent bf7aa64 commit e119a26
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
30 changes: 20 additions & 10 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@
PATH_HOLDER = "___path_holder___"


@dataclass
class GenerationSummary:
"""A class to capture the summary data after synthetic data is generated.
"""

valid_lines: int = 0
invalid_lines: int = 0
is_valid: bool = False


@dataclass
class Batch:
"""A representation of a synthetic data workflow. It should not be used
Expand Down Expand Up @@ -514,7 +524,7 @@ def generate_batch_lines(
num_lines: int = None,
seed_fields: Union[dict, List[dict]] = None,
parallelism: int = 0,
) -> dict:
) -> GenerationSummary:
"""Generate lines for a single batch. Lines generated are added
to the underlying ``Batch`` object for each batch. The lines
can be accessed after generation and re-assembled into a DataFrame.
Expand Down Expand Up @@ -566,7 +576,7 @@ def generate_batch_lines(
t = tqdm(total=num_lines, desc="Valid record count ")
t2 = tqdm(total=max_invalid, desc="Invalid record count ")
line: GenText
n_valid, n_invalid = 0, 0
summary = GenerationSummary()
try:
for line in generate_text(
batch.config,
Expand All @@ -579,20 +589,20 @@ def generate_batch_lines(
if line.valid is None or line.valid is True:
batch.add_valid_data(line)
t.update(1)
n_valid += 1
summary.valid_lines += 1
else:
t2.update(1)
batch.gen_data_invalid.append(line)
n_invalid += 1
summary.invalid_lines += 1
except TooManyInvalidError:
if raise_on_exceed_invalid:
raise
else:
return False
return summary
t.close()
t2.close()
is_valid = batch.gen_data_count >= num_lines
return {'valid_lines': n_valid, 'invalid_lines': n_invalid, 'is_valid': is_valid}
summary.is_valid = batch.gen_data_count >= num_lines
return summary

def generate_all_batch_lines(
self,
Expand All @@ -601,7 +611,7 @@ def generate_all_batch_lines(
num_lines: int = None,
seed_fields: Union[dict, List[dict]] = None,
parallelism: int = 0,
) -> dict:
) -> Dict[int, GenerationSummary]:
"""Generate synthetic lines for all batches. Lines for each batch
are added to the individual ``Batch`` objects. Once generateion is
done, you may re-assemble the dataset into a DataFrame.
Expand Down Expand Up @@ -641,8 +651,8 @@ def generate_all_batch_lines(
that shows if each batch was able to generate the full number of requested lines::
{
0: {'valid_lines' : 1000, 'invalid_lines': 10, 'is_valid': True},
1: {'valid_lines' : 500, 'invalid_lines': 5, 'is_valid': True}
0: GenerationSummary(valid_lines=1000, invalid_lines=10, is_valid=True),
1: GenerationSummary(valid_lines=500, invalid_lines=5, is_valid=True)
}
"""
batch_status = {}
Expand Down
14 changes: 8 additions & 6 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest
import pandas as pd

from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID, READ, ORIG_HEADERS
from gretel_synthetics.batch import DataFrameBatch, MAX_INVALID, READ, ORIG_HEADERS, GenerationSummary
from gretel_synthetics.generate import GenText
from gretel_synthetics.errors import TooManyInvalidError

Expand Down Expand Up @@ -158,20 +158,20 @@ def bad():
with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
summary = batches.generate_batch_lines(5, max_invalid=1)
assert summary.get('is_valid')
assert summary.is_valid
check_call = mock_gen.mock_calls[0]
_, _, kwargs = check_call
assert kwargs["max_invalid"] == 1

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
summary = batches.generate_batch_lines(5)
assert summary.get('is_valid')
assert summary.is_valid

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
summary = batches.generate_batch_lines(5)
assert not summary.get('is_valid')
assert not summary.is_valid

with patch.object(batches, "generate_batch_lines") as mock_gen:
batches.generate_all_batch_lines(max_invalid=15)
Expand Down Expand Up @@ -214,12 +214,14 @@ def test_generate_batch_lines_raise_on_exceed(test_data):

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.side_effect = TooManyInvalidError()
assert not batches.generate_batch_lines(0)
summary = batches.generate_batch_lines(0)
assert not summary.is_valid

with patch("gretel_synthetics.batch.generate_text") as mock_gen:
mock_gen.side_effect = TooManyInvalidError()
with pytest.raises(TooManyInvalidError):
assert not batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
summary = batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
assert not summary.is_valid


def test_generate_batch_lines_always_raise_other_exceptions(test_data):
Expand Down

0 comments on commit e119a26

Please sign in to comment.