From c75f7db5bb4868397eab58d54f3ff4543310fc3c Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 19 Jun 2020 15:57:36 -0700 Subject: [PATCH 1/4] fix: round up to keep batches under `batch_size` --- src/gretel_synthetics/batch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gretel_synthetics/batch.py b/src/gretel_synthetics/batch.py index f846af6d..324c6710 100644 --- a/src/gretel_synthetics/batch.py +++ b/src/gretel_synthetics/batch.py @@ -10,6 +10,7 @@ from dataclasses import dataclass, field from pathlib import Path import gzip +from math import ceil from typing import List, Type, Callable, Dict import pickle from copy import deepcopy @@ -31,6 +32,7 @@ MAX_INVALID = 1000 +BATCH_SIZE = 15 FIELD_DELIM = "field_delimiter" GEN_LINES = "gen_lines" @@ -171,7 +173,7 @@ class DataFrameBatch: Args: df: The input, source DataFrame batch_size: If ``batch_headers`` is not provided we automatically break up - the number of colums in the source DataFrame into batches of N columns. + the number of columns in the source DataFrame into batches of N columns. batch_headers: A list of lists of strings can be provided which will control the number of batches. The number of inner lists is the number of batches, and each inner list represents the columns that belong to that batch @@ -193,7 +195,7 @@ def __init__( self, *, df: pd.DataFrame, - batch_size: int = 15, + batch_size: int = BATCH_SIZE, batch_headers: List[List[str]] = None, config: dict = None ): @@ -228,7 +230,7 @@ def __init__( ) def _create_header_batches(self): - num_batches = len(self._source_df.columns) // self.batch_size + num_batches = ceil(len(self._source_df.columns) / self.batch_size) tmp = np.array_split(list(self._source_df.columns), num_batches) return [list(row) for row in tmp] From 75ddef8aa49e74c3fb9e3099512a7d33da952b28 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 19 Jun 2020 16:18:20 -0700 Subject: [PATCH 2/4] add test cases for batch size --- tests/test_batch.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_batch.py b/tests/test_batch.py index fafae1b9..ece03a2d 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -45,10 +45,23 @@ def test_data(): shutil.rmtree(checkpoint_dir) pass + def simple_validator(line: str): return len(line.split(",")) == 5 +def test_batch_size(test_data): + test_data = test_data.iloc[:, :60] + batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15) + assert batches.batch_size == 15 + assert [len(x) for x in batches.batch_headers] == [15, 15, 15, 15] + + test_data = test_data.iloc[:, :59] + batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15) + assert batches.batch_size == 15 + assert [len(x) for x in batches.batch_headers] == [15, 15, 15, 14] + + def test_missing_config(test_data): with pytest.raises(ValueError): DataFrameBatch(df=test_data) @@ -74,7 +87,7 @@ def test_init(test_data): # should create the dir structure based on auto # batch sizing - batches = DataFrameBatch(df=test_data, config=config_template) + batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15) first_row = [ "ID_code", "target", @@ -91,10 +104,9 @@ def test_init(test_data): "var_10", "var_11", "var_12", - "var_13", ] assert batches.batches[0].headers == first_row - assert len(batches.batches.keys()) == 13 + assert len(batches.batches.keys()) == 14 for i, batch in batches.batches.items(): assert Path(batch.checkpoint_dir).is_dir() assert Path(batch.checkpoint_dir).name == f"batch_{i}" From 2510a5618a608e7528ca962ef1ae0e748d47923f Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 19 Jun 2020 16:19:41 -0700 Subject: [PATCH 3/4] update version to 0.10.3 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 42624f31..a3f5a8ed 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.10.2 \ No newline at end of file +0.10.3 From d2cbc143aed3e9f68fe9dce478e0056c89f3f0a3 Mon Sep 17 00:00:00 2001 From: Alexander Watson Date: Fri, 19 Jun 2020 16:24:01 -0700 Subject: [PATCH 4/4] docstring update --- src/gretel_synthetics/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gretel_synthetics/generate.py b/src/gretel_synthetics/generate.py index 6a38e784..8c1a21dd 100644 --- a/src/gretel_synthetics/generate.py +++ b/src/gretel_synthetics/generate.py @@ -110,7 +110,7 @@ def generate_text( """A generator that will load a model and start creating records. Args: - store: A configuration object, which you must have created previously + config: A configuration object, which you must have created previously start_string: A prefix string that is used to seed the record generation. By default we use a newline, but you may substitue any initial value here which will influence how the generator predicts what to generate.