Skip to content

Commit

Permalink
Backwards compatibility for class_weights (#2469)
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccorm authored Sep 10, 2022
1 parent 877635f commit 7870cf7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions ludwig/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
INFER_IMAGE_MAX_WIDTH = "infer_image_max_width"
INFER_IMAGE_SAMPLE_SIZE = "infer_image_sample_size"
NUM_CHANNELS = "num_channels"
CLASS_WEIGHTS = "class_weights"
LOSS = "loss"
ROC_AUC = "roc_auc"
EVAL_LOSS = "eval_loss"
Expand Down
13 changes: 13 additions & 0 deletions ludwig/utils/backward_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ludwig.constants import (
AUDIO,
BIAS,
CLASS_WEIGHTS,
COLUMN,
CONV_BIAS,
CONV_USE_BIAS,
Expand All @@ -32,6 +33,7 @@
EXECUTOR,
FORCE_SPLIT,
INPUT_FEATURES,
LOSS,
MISSING_VALUE_STRATEGY,
NAME,
NUM_SAMPLES,
Expand Down Expand Up @@ -156,6 +158,17 @@ def _traverse_dicts(config: Any, f: Callable[[Dict], None]):
_traverse_dicts(v, f)


@register_config_transformation("0.4", ["output_features"])
def update_class_weights_in_features(feature: Dict[str, Any]) -> Dict[str, Any]:
if LOSS in feature:
class_weights = feature[LOSS].get(CLASS_WEIGHTS, None)
if not isinstance(class_weights, list):
class_weights = None
feature[LOSS][CLASS_WEIGHTS] = class_weights

return feature


@register_config_transformation("0.5")
def rename_training_to_trainer(config: Dict[str, Any]) -> Dict[str, Any]:
if TRAINING in config:
Expand Down
39 changes: 39 additions & 0 deletions tests/ludwig/utils/test_backward_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from ludwig.constants import (
BFILL,
CLASS_WEIGHTS,
EVAL_BATCH_SIZE,
EXECUTOR,
HYPEROPT,
INPUT_FEATURES,
LOSS,
NUMBER,
OUTPUT_FEATURES,
PREPROCESSING,
Expand Down Expand Up @@ -494,6 +496,43 @@ def test_update_missing_value_strategy(missing_value_strategy: str):
assert updated_config == expected_config


def test_old_class_weights_default():
old_config = {
"input_features": [
{
"name": "input_feature_1",
"type": "category",
}
],
"output_features": [
{"name": "output_feature_1", "type": "category", "loss": {"class_weights": 1}},
],
}

new_config = {
"input_features": [
{
"name": "input_feature_1",
"type": "category",
}
],
"output_features": [
{"name": "output_feature_1", "type": "category", "loss": {"class_weights": None}},
],
}

upgraded_config = upgrade_to_latest_version(old_config)
del upgraded_config["ludwig_version"]
assert new_config == upgraded_config

old_config[OUTPUT_FEATURES][0][LOSS][CLASS_WEIGHTS] = [0.5, 0.8, 1]
new_config[OUTPUT_FEATURES][0][LOSS][CLASS_WEIGHTS] = [0.5, 0.8, 1]

upgraded_config = upgrade_to_latest_version(old_config)
del upgraded_config["ludwig_version"]
assert new_config == upgraded_config


def test_upgrade_model_progress():
old_model_progress = {
"batch_size": 64,
Expand Down

0 comments on commit 7870cf7

Please sign in to comment.