Skip to content

Commit

Permalink
Added backwards compatibility test
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored and justinxzhao committed Jun 21, 2022
1 parent d3f0758 commit 599b6db
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
5 changes: 5 additions & 0 deletions ludwig/data/dataframe/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def reduce_objects(self, series, reduce_fn):
return series.reduction(reduce_fn, aggregate=reduce_fn, meta=("data", "object")).compute()[0]

def split(self, df, probabilities):
# Split the DataFrame proportionately along partitions. This is an inexact solution designed
# to speed up the split process, as splitting within partitions would be significantly
# more expensive.
# TODO(travis): revisit in the future to make this more precise

# First ensure that every split receives at least one partition.
# If not, we need to increase the number of partitions to satisfy this constraint.
min_prob = min(probabilities)
Expand Down
12 changes: 4 additions & 8 deletions ludwig/utils/backward_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ def _upgrade_trainer(trainer: Dict[str, Any]):
def _upgrade_preprocessing(preprocessing: Dict[str, Any]):
split_params = {}

force_split = preprocessing.get("force_split")
split_probabilities = preprocessing.get("split_probabilities")
stratify = preprocessing.get("stratify")
force_split = preprocessing.pop("force_split", None)
split_probabilities = preprocessing.pop("split_probabilities", None)
stratify = preprocessing.pop("stratify", None)

if split_probabilities is not None:
split_params["probabilities"] = split_probabilities
Expand All @@ -168,7 +168,6 @@ def _upgrade_preprocessing(preprocessing: Dict[str, Any]):
"will be flagged as error in v0.7",
DeprecationWarning,
)
del preprocessing["split_probabilities"]

if stratify is not None:
split_params["type"] = "stratify"
Expand All @@ -179,7 +178,6 @@ def _upgrade_preprocessing(preprocessing: Dict[str, Any]):
"will be flagged as error in v0.7",
DeprecationWarning,
)
del preprocessing["stratify"]

if force_split is not None:
warnings.warn(
Expand All @@ -189,9 +187,7 @@ def _upgrade_preprocessing(preprocessing: Dict[str, Any]):
)

if "type" not in split_params:
split_params["type"] = "random"

del preprocessing["force_split"]
split_params["type"] = "random" if force_split else "fixed"

if split_params:
preprocessing["split"] = split_params
Expand Down
41 changes: 39 additions & 2 deletions tests/ludwig/utils/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NUMBER,
PREPROCESSING,
SCHEDULER,
SPLIT,
TRAINER,
TYPE,
)
Expand Down Expand Up @@ -121,8 +122,8 @@ def test_missing_outputs_drop_rows():

def test_deprecated_field_aliases():
config = {
"input_features": [{"name": "num_in", "type": "number"}],
"output_features": [{"name": "num_out", "type": "number"}],
"input_features": [{"name": "num_in", "type": "numerical"}],
"output_features": [{"name": "num_out", "type": "numerical"}],
"training": {
"epochs": 2,
"eval_batch_size": 0,
Expand Down Expand Up @@ -164,6 +165,42 @@ def test_deprecated_field_aliases():
assert "scheduler" in merged_config[HYPEROPT]["executor"]


@pytest.mark.parametrize("force_split", [None, False, True])
@pytest.mark.parametrize("stratify", [None, "cat_in"])
def test_deprecated_split_aliases(stratify, force_split):
    """Check that legacy split-related preprocessing keys are upgraded.

    A config is built using the deprecated ``force_split``,
    ``split_probabilities`` and ``stratify`` preprocessing keys and passed
    through ``merge_with_defaults``. The merged config must no longer carry
    those keys and must expose a ``split`` section whose contents reflect
    the legacy values.
    """
    legacy_probabilities = [0.6, 0.2, 0.2]
    legacy_preprocessing = {
        "force_split": force_split,
        "split_probabilities": legacy_probabilities,
        "stratify": stratify,
    }
    config = {
        "input_features": [{"name": "num_in", "type": "number"}, {"name": "cat_in", "type": "category"}],
        "output_features": [{"name": "num_out", "type": "number"}],
        "preprocessing": legacy_preprocessing,
    }

    merged_config = merge_with_defaults(config)
    merged_preprocessing = merged_config[PREPROCESSING]

    # None of the deprecated keys may survive the upgrade.
    for legacy_key in ("force_split", "split_probabilities", "stratify"):
        assert legacy_key not in merged_preprocessing

    assert SPLIT in merged_preprocessing
    split = merged_preprocessing[SPLIT]

    assert split["probabilities"] == legacy_probabilities

    # A stratify column takes precedence over whatever force_split was set to.
    if stratify is not None:
        assert split.get(TYPE) == "stratify"
        assert split.get("column") == stratify
        return

    # Without stratification the split type is driven by force_split alone:
    # unset -> no explicit type, True -> "random", False -> "fixed".
    if force_split is None:
        assert split.get(TYPE) is None
    else:
        expected_type = "random" if force_split else "fixed"
        assert split.get(TYPE) == expected_type


def test_merge_with_defaults():
# configuration with legacy parameters
legacy_config_format = {
Expand Down

0 comments on commit 599b6db

Please sign in to comment.