Fixed backwards compatibility for training_set_metadata and bfill (#2472)
tgaddair authored Sep 11, 2022
1 parent e7c9951 commit 4b0825b
Showing 4 changed files with 88 additions and 30 deletions.
12 changes: 11 additions & 1 deletion ludwig/data/dataset_synthesizer.py
@@ -20,9 +20,10 @@
 import string
 import sys
 import uuid
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
+import pandas as pd
 import torch
 import torchaudio
 import yaml
@@ -37,8 +38,10 @@
     ENCODER,
     H3,
     IMAGE,
+    INPUT_FEATURES,
     NAME,
     NUMBER,
+    OUTPUT_FEATURES,
     PREPROCESSING,
     SEQUENCE,
     SET,
@@ -157,6 +160,13 @@ def build_feature_parameters(features):
 }


+def build_synthetic_dataset_df(dataset_size: int, config: Dict[str, Any]) -> pd.DataFrame:
+    features = config[INPUT_FEATURES] + config[OUTPUT_FEATURES]
+    df = build_synthetic_dataset(dataset_size, features)
+    data = [next(df) for _ in range(dataset_size + 1)]
+    return pd.DataFrame(data[1:], columns=data[0])
+
+
 def build_synthetic_dataset(dataset_size: int, features: List[dict], outdir: str = "."):
     """Synthesizes a dataset for testing purposes.
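
Note: build_synthetic_dataset_df returns a pandas DataFrame whose columns are the feature names taken from the config's input_features and output_features. A minimal usage sketch (the config below is illustrative and not part of this commit; some feature types may need extra synthesizer parameters):

    from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df

    config = {
        "input_features": [{"name": "text_1", "type": "text"}],
        "output_features": [{"name": "label", "type": "category"}],
    }

    df = build_synthetic_dataset_df(10, config)  # 10 synthetic rows
    print(list(df.columns))  # ["text_1", "label"]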
9 changes: 7 additions & 2 deletions ludwig/data/preprocessing.py
@@ -56,6 +56,7 @@
 from ludwig.features.feature_registries import base_type_registry
 from ludwig.features.feature_utils import compute_feature_hash
 from ludwig.utils import data_utils, strings_utils
+from ludwig.utils.backward_compatibility import upgrade_metadata
 from ludwig.utils.config_utils import merge_config_preprocessing_with_feature_specific_defaults
 from ludwig.utils.data_utils import (
     CACHEABLE_FORMATS,
@@ -1498,9 +1499,13 @@ def shuffle(df):
     return training_set, test_set, validation_set


-def load_metadata(metadata_file_path):
+def load_metadata(metadata_file_path: str) -> Dict[str, Any]:
     logging.info(f"Loading metadata from: {metadata_file_path}")
-    return data_utils.load_json(metadata_file_path)
+    training_set_metadata = data_utils.load_json(metadata_file_path)
+    # TODO(travis): decouple config from training_set_metadata so we don't need to
+    # upgrade it over time.
+    training_set_metadata = upgrade_metadata(training_set_metadata)
+    return training_set_metadata


 def preprocess_for_training(
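
With this change, load_metadata upgrades legacy entries as they are read, so training_set_metadata written by older Ludwig versions keeps working with current code. A rough sketch of the intended behavior (the path below is hypothetical):

    from ludwig.data.preprocessing import load_metadata

    # Hypothetical location of metadata saved by an older Ludwig version.
    metadata = load_metadata("results/api_experiment_run/model/training_set_metadata.json")
    # Entries whose preprocessing still uses "backfill" or "pad" come back as
    # "bfill" / "ffill", with a DeprecationWarning.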
70 changes: 43 additions & 27 deletions ludwig/utils/backward_compatibility.py
@@ -523,35 +523,51 @@ def update_training(config: Dict[str, Any]) -> Dict[str, Any]:

 @register_config_transformation("0.6")
 def upgrade_missing_value_strategy(config: Dict[str, Any]) -> Dict[str, Any]:
-    def __is_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        if PREPROCESSING not in feature_config:
-            return False
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None)
-        if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"):
-            return False
-        return True
-
-    def __update_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
-        replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
-        feature_name = feature_config.get(NAME)
-        warnings.warn(
-            f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy"
-            f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8",
-            DeprecationWarning,
-        )
-        feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy})
-
-    for input_feature in config.get(INPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(input_feature):
-            __update_old_missing_value_strategy(input_feature)
+    for input_feature in config.get(INPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(input_feature):
+            _update_old_missing_value_strategy(input_feature)

-    for output_feature in config.get(OUTPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(output_feature):
-            __update_old_missing_value_strategy(output_feature)
+    for output_feature in config.get(OUTPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(output_feature):
+            _update_old_missing_value_strategy(output_feature)

     for feature, feature_defaults in config.get(DEFAULTS, {}).items():
-        if __is_old_missing_value_strategy(feature_defaults):
-            __update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))
+        if _is_old_missing_value_strategy(feature_defaults):
+            _update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))

     return config
+
+
+def upgrade_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
+    # TODO(travis): stopgap solution, we should make it so we don't need to do this
+    # by decoupling config and metadata
+    metadata = copy.deepcopy(metadata)
+    _upgrade_metadata_mising_values(metadata)
+    return metadata
+
+
+def _upgrade_metadata_mising_values(metadata: Dict[str, Any]):
+    for k, v in metadata.items():
+        if isinstance(v, dict) and _is_old_missing_value_strategy(v):
+            _update_old_missing_value_strategy(v)
+
+
+def _update_old_missing_value_strategy(feature_config: Dict[str, Any]):
+    missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
+    replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
+    feature_name = feature_config.get(NAME)
+    warnings.warn(
+        f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy"
+        f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8",
+        DeprecationWarning,
+    )
+    feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy})
+
+
+def _is_old_missing_value_strategy(feature_config: Dict[str, Any]):
+    if PREPROCESSING not in feature_config:
+        return False
+    missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None)
+    if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"):
+        return False
+    return True
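
The upgrade rewrites the deprecated strategies "backfill" -> "bfill" and "pad" -> "ffill" and emits a DeprecationWarning; upgrade_metadata deep-copies its input, so the caller's dict is not mutated. A minimal sketch (the feature name and metadata layout are made up for illustration):

    from ludwig.utils.backward_compatibility import upgrade_metadata

    old_metadata = {"Survived": {"preprocessing": {"missing_value_strategy": "backfill"}}}
    new_metadata = upgrade_metadata(old_metadata)  # emits a DeprecationWarning

    assert new_metadata["Survived"]["preprocessing"]["missing_value_strategy"] == "bfill"
    assert old_metadata["Survived"]["preprocessing"]["missing_value_strategy"] == "backfill"  # input left untouched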
27 changes: 27 additions & 0 deletions tests/regression_tests/model/test_old_models.py
@@ -2,9 +2,13 @@
 import zipfile

 import pandas as pd
+import pytest
 import wget

 from ludwig.api import LudwigModel
+from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df
+
+NUM_EXAMPLES = 25


 def test_model_loaded_from_old_config_prediction_works(tmpdir):
@@ -32,3 +36,26 @@ def test_model_loaded_from_old_config_prediction_works(tmpdir):
     predictions, _ = ludwig_model.predict(dataset=test_set)

     assert predictions.to_dict()["Survived_predictions"] == {0: False}
+
+
+@pytest.mark.parametrize(
+    "model_url",
+    [
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/twitter_bots_v05.zip",
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/respiratory_v05.zip",
+    ],
+    ids=["twitter_bots", "respiratory"],
+)
+def test_predict_deprecated_model(model_url, tmpdir):
+    model_dir = os.path.join(tmpdir, "model")
+    os.makedirs(model_dir)
+
+    archive_path = wget.download(model_url, tmpdir)
+    with zipfile.ZipFile(archive_path, "r") as zip_ref:
+        zip_ref.extractall(model_dir)
+
+    ludwig_model = LudwigModel.load(model_dir)
+    df = build_synthetic_dataset_df(NUM_EXAMPLES, ludwig_model.config)
+
+    pred_df = ludwig_model.predict(df)
+    assert len(pred_df) > 0

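To exercise only the new regression test locally (it downloads the archived models from S3, so network access is required), something like the following should work:

    # Sketch: run just the deprecated-model test through pytest's Python API.
    import pytest

    pytest.main(["tests/regression_tests/model/test_old_models.py", "-k", "test_predict_deprecated_model"])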