From 4b0825bd4be74d31c969213cedfb88b862481f8e Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Sat, 10 Sep 2022 18:31:46 -0700
Subject: [PATCH] Fixed backwards compatibility for training_set_metadata and bfill (#2472)

---
 ludwig/data/dataset_synthesizer.py            | 12 +++-
 ludwig/data/preprocessing.py                  |  9 ++-
 ludwig/utils/backward_compatibility.py        | 70 ++++++++++++-------
 .../regression_tests/model/test_old_models.py | 27 +++++++
 4 files changed, 88 insertions(+), 30 deletions(-)

diff --git a/ludwig/data/dataset_synthesizer.py b/ludwig/data/dataset_synthesizer.py
index f041e910356..9c3cd267552 100644
--- a/ludwig/data/dataset_synthesizer.py
+++ b/ludwig/data/dataset_synthesizer.py
@@ -20,9 +20,10 @@
 import string
 import sys
 import uuid
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
+import pandas as pd
 import torch
 import torchaudio
 import yaml
@@ -37,8 +38,10 @@
     ENCODER,
     H3,
     IMAGE,
+    INPUT_FEATURES,
     NAME,
     NUMBER,
+    OUTPUT_FEATURES,
     PREPROCESSING,
     SEQUENCE,
     SET,
@@ -157,6 +160,13 @@ def build_feature_parameters(features):
     }
 
 
+def build_synthetic_dataset_df(dataset_size: int, config: Dict[str, Any]) -> pd.DataFrame:
+    features = config[INPUT_FEATURES] + config[OUTPUT_FEATURES]
+    df = build_synthetic_dataset(dataset_size, features)
+    data = [next(df) for _ in range(dataset_size + 1)]
+    return pd.DataFrame(data[1:], columns=data[0])
+
+
 def build_synthetic_dataset(dataset_size: int, features: List[dict], outdir: str = "."):
     """Synthesizes a dataset for testing purposes.
 
diff --git a/ludwig/data/preprocessing.py b/ludwig/data/preprocessing.py
index 8511bcd2e66..7d44cabd3c3 100644
--- a/ludwig/data/preprocessing.py
+++ b/ludwig/data/preprocessing.py
@@ -56,6 +56,7 @@
 from ludwig.features.feature_registries import base_type_registry
 from ludwig.features.feature_utils import compute_feature_hash
 from ludwig.utils import data_utils, strings_utils
+from ludwig.utils.backward_compatibility import upgrade_metadata
 from ludwig.utils.config_utils import merge_config_preprocessing_with_feature_specific_defaults
 from ludwig.utils.data_utils import (
     CACHEABLE_FORMATS,
@@ -1498,9 +1499,13 @@ def shuffle(df):
     return training_set, test_set, validation_set
 
 
-def load_metadata(metadata_file_path):
+def load_metadata(metadata_file_path: str) -> Dict[str, Any]:
     logging.info(f"Loading metadata from: {metadata_file_path}")
-    return data_utils.load_json(metadata_file_path)
+    training_set_metadata = data_utils.load_json(metadata_file_path)
+    # TODO(travis): decouple config from training_set_metadata so we don't need to
+    # upgrade it over time.
+    training_set_metadata = upgrade_metadata(training_set_metadata)
+    return training_set_metadata
 
 
 def preprocess_for_training(
diff --git a/ludwig/utils/backward_compatibility.py b/ludwig/utils/backward_compatibility.py
index fd410f39852..a3947cc49d1 100644
--- a/ludwig/utils/backward_compatibility.py
+++ b/ludwig/utils/backward_compatibility.py
@@ -523,35 +523,51 @@ def update_training(config: Dict[str, Any]) -> Dict[str, Any]:
 
 @register_config_transformation("0.6")
 def upgrade_missing_value_strategy(config: Dict[str, Any]) -> Dict[str, Any]:
-    def __is_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        if PREPROCESSING not in feature_config:
-            return False
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None)
-        if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"):
-            return False
-        return True
-
-    def __update_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
-        replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
-        feature_name = feature_config.get(NAME)
-        warnings.warn(
-            f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy"
-            f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8",
-            DeprecationWarning,
-        )
-        feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy})
-
-    for input_feature in config.get(INPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(input_feature):
-            __update_old_missing_value_strategy(input_feature)
+    for input_feature in config.get(INPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(input_feature):
+            _update_old_missing_value_strategy(input_feature)
 
-    for output_feature in config.get(OUTPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(output_feature):
-            __update_old_missing_value_strategy(output_feature)
+    for output_feature in config.get(OUTPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(output_feature):
+            _update_old_missing_value_strategy(output_feature)
 
     for feature, feature_defaults in config.get(DEFAULTS, {}).items():
-        if __is_old_missing_value_strategy(feature_defaults):
-            __update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))
+        if _is_old_missing_value_strategy(feature_defaults):
+            _update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))
 
     return config
+
+
+def upgrade_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
+    # TODO(travis): stopgap solution, we should make it so we don't need to do this
+    # by decoupling config and metadata
+    metadata = copy.deepcopy(metadata)
+    _upgrade_metadata_mising_values(metadata)
+    return metadata
+
+
+def _upgrade_metadata_mising_values(metadata: Dict[str, Any]):
+    for k, v in metadata.items():
+        if isinstance(v, dict) and _is_old_missing_value_strategy(v):
+            _update_old_missing_value_strategy(v)
+
+
+def _update_old_missing_value_strategy(feature_config: Dict[str, Any]):
+    missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
+    replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
+    feature_name = feature_config.get(NAME)
+    warnings.warn(
+        f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy"
+        f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8",
+        DeprecationWarning,
+    )
+    feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy})
+
+
+def _is_old_missing_value_strategy(feature_config: Dict[str, Any]):
+    if PREPROCESSING not in feature_config:
+        return False
+    missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None)
+    if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"):
+        return False
+    return True
diff --git a/tests/regression_tests/model/test_old_models.py b/tests/regression_tests/model/test_old_models.py
index eea9db873a5..2c6427ea036 100644
--- a/tests/regression_tests/model/test_old_models.py
+++ b/tests/regression_tests/model/test_old_models.py
@@ -2,9 +2,13 @@
 import zipfile
 
 import pandas as pd
+import pytest
 import wget
 
 from ludwig.api import LudwigModel
+from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df
+
+NUM_EXAMPLES = 25
 
 
 def test_model_loaded_from_old_config_prediction_works(tmpdir):
@@ -32,3 +36,26 @@ def test_model_loaded_from_old_config_prediction_works(tmpdir):
     predictions, _ = ludwig_model.predict(dataset=test_set)
 
     assert predictions.to_dict()["Survived_predictions"] == {0: False}
+
+
+@pytest.mark.parametrize(
+    "model_url",
+    [
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/twitter_bots_v05.zip",
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/respiratory_v05.zip",
+    ],
+    ids=["twitter_bots", "respiratory"],
+)
+def test_predict_deprecated_model(model_url, tmpdir):
+    model_dir = os.path.join(tmpdir, "model")
+    os.makedirs(model_dir)
+
+    archive_path = wget.download(model_url, tmpdir)
+    with zipfile.ZipFile(archive_path, "r") as zip_ref:
+        zip_ref.extractall(model_dir)
+
+    ludwig_model = LudwigModel.load(model_dir)
+    df = build_synthetic_dataset_df(NUM_EXAMPLES, ludwig_model.config)
+
+    pred_df = ludwig_model.predict(df)
+    assert len(pred_df) > 0
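A minimal sketch (not part of the patch) of what the new upgrade_metadata hook does to a legacy training_set_metadata dict when an old model is loaded. The toy metadata below is an illustrative assumption; real metadata files carry many more keys per feature.

import warnings

from ludwig.utils.backward_compatibility import upgrade_metadata

# Hypothetical, stripped-down metadata entry that still uses the deprecated strategy name.
old_metadata = {
    "Survived": {
        "name": "Survived",
        "preprocessing": {"missing_value_strategy": "backfill"},
    }
}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    new_metadata = upgrade_metadata(old_metadata)

# upgrade_metadata deep-copies its input, so the loaded dict is left untouched...
assert old_metadata["Survived"]["preprocessing"]["missing_value_strategy"] == "backfill"
# ...while the returned copy uses the replacement strategy and a DeprecationWarning was emitted.
assert new_metadata["Survived"]["preprocessing"]["missing_value_strategy"] == "bfill"
assert any(issubclass(w.category, DeprecationWarning) for w in caught)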