Skip to content

Commit

Permalink
Merge pull request #35 from gretelai/#34-batch-size-aw
Browse files Browse the repository at this point in the history
#34 correct batch size
  • Loading branch information
zredlined authored Jun 19, 2020
2 parents 4ab732c + d2cbc14 commit 57e6222
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.2
0.10.3
8 changes: 5 additions & 3 deletions src/gretel_synthetics/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dataclasses import dataclass, field
from pathlib import Path
import gzip
from math import ceil
from typing import List, Type, Callable, Dict
import pickle
from copy import deepcopy
Expand All @@ -31,6 +32,7 @@


# Numeric limit of 1000; presumably the cap on invalid generated lines
# tolerated before giving up -- its usage is not visible here (TODO confirm).
MAX_INVALID = 1000
# Default number of columns per batch; used as the default value of the
# ``batch_size`` argument of ``DataFrameBatch.__init__``.
BATCH_SIZE = 15
# String constants; presumably names of configuration keys -- verify
# against the code that reads them.
FIELD_DELIM = "field_delimiter"
GEN_LINES = "gen_lines"

Expand Down Expand Up @@ -171,7 +173,7 @@ class DataFrameBatch:
Args:
df: The input, source DataFrame
batch_size: If ``batch_headers`` is not provided we automatically break up
the number of colums in the source DataFrame into batches of N columns.
the number of columns in the source DataFrame into batches of N columns.
batch_headers: A list of lists of strings can be provided which will control
the number of batches. The number of inner lists is the number of batches, and each
inner list represents the columns that belong to that batch
Expand All @@ -193,7 +195,7 @@ def __init__(
self,
*,
df: pd.DataFrame,
batch_size: int = 15,
batch_size: int = BATCH_SIZE,
batch_headers: List[List[str]] = None,
config: dict = None
):
Expand Down Expand Up @@ -228,7 +230,7 @@ def __init__(
)

def _create_header_batches(self):
    """Split the source DataFrame's column names into batches of headers.

    The number of batches is ``ceil(n_columns / batch_size)``, so a trailing
    partial batch is kept and no batch holds more than ``self.batch_size``
    columns. ``np.array_split`` then spreads the columns as evenly as
    possible over that many batches while preserving column order.

    Returns:
        A list of lists of column names, one inner list per batch.

    Raises:
        ValueError: If no batch can be formed, i.e. the source DataFrame has
            no columns or ``self.batch_size`` is negative.
    """
    columns = list(self._source_df.columns)
    # Round up, never down: floor division would drop the remainder batch and
    # would yield zero sections whenever there are fewer columns than
    # ``batch_size``.
    num_batches = ceil(len(columns) / self.batch_size)
    if num_batches < 1:
        # np.array_split rejects a non-positive section count with a
        # ValueError of its own; raise the same type with an actionable message.
        raise ValueError(
            "Cannot create header batches: the source DataFrame needs at "
            "least one column and batch_size must be positive"
        )
    return [list(chunk) for chunk in np.array_split(columns, num_batches)]

Expand Down
2 changes: 1 addition & 1 deletion src/gretel_synthetics/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def generate_text(
"""A generator that will load a model and start creating records.
Args:
store: A configuration object, which you must have created previously
config: A configuration object, which you must have created previously
start_string: A prefix string that is used to seed the record generation.
By default we use a newline, but you may substitute any initial value here
which will influence how the generator predicts what to generate.
Expand Down
18 changes: 15 additions & 3 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,23 @@ def test_data():
shutil.rmtree(checkpoint_dir)
pass


def simple_validator(line: str):
    """Accept a line only when it holds exactly five comma-separated fields."""
    expected_fields = 5
    # N comma-separated fields are delimited by exactly N - 1 commas.
    return line.count(",") == expected_fields - 1


def test_batch_size(test_data):
    """Check header batching for an even split (60 columns) and a ragged one (59 columns)."""
    # (number of leading columns to keep, expected size of each header batch)
    scenarios = (
        (60, [15, 15, 15, 15]),
        (59, [15, 15, 15, 14]),
    )
    for n_cols, expected_sizes in scenarios:
        subset = test_data.iloc[:, :n_cols]
        batcher = DataFrameBatch(df=subset, config=config_template, batch_size=15)
        assert batcher.batch_size == 15
        assert [len(headers) for headers in batcher.batch_headers] == expected_sizes


def test_missing_config(test_data):
with pytest.raises(ValueError):
DataFrameBatch(df=test_data)
Expand All @@ -74,7 +87,7 @@ def test_init(test_data):

# should create the dir structure based on auto
# batch sizing
batches = DataFrameBatch(df=test_data, config=config_template)
batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
first_row = [
"ID_code",
"target",
Expand All @@ -91,10 +104,9 @@ def test_init(test_data):
"var_10",
"var_11",
"var_12",
"var_13",
]
assert batches.batches[0].headers == first_row
assert len(batches.batches.keys()) == 13
assert len(batches.batches.keys()) == 14
for i, batch in batches.batches.items():
assert Path(batch.checkpoint_dir).is_dir()
assert Path(batch.checkpoint_dir).name == f"batch_{i}"
Expand Down

0 comments on commit 57e6222

Please sign in to comment.