diff --git a/ludwig/constants.py b/ludwig/constants.py index bd0043cde2c..00c594bde23 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -41,6 +41,7 @@ INFER_IMAGE_MAX_WIDTH = "infer_image_max_width" INFER_IMAGE_SAMPLE_SIZE = "infer_image_sample_size" NUM_CHANNELS = "num_channels" +CLASS_WEIGHTS = "class_weights" LOSS = "loss" ROC_AUC = "roc_auc" EVAL_LOSS = "eval_loss" diff --git a/ludwig/data/dataset_synthesizer.py b/ludwig/data/dataset_synthesizer.py index f041e910356..9c3cd267552 100644 --- a/ludwig/data/dataset_synthesizer.py +++ b/ludwig/data/dataset_synthesizer.py @@ -20,9 +20,10 @@ import string import sys import uuid -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np +import pandas as pd import torch import torchaudio import yaml @@ -37,8 +38,10 @@ ENCODER, H3, IMAGE, + INPUT_FEATURES, NAME, NUMBER, + OUTPUT_FEATURES, PREPROCESSING, SEQUENCE, SET, @@ -157,6 +160,13 @@ def build_feature_parameters(features): } +def build_synthetic_dataset_df(dataset_size: int, config: Dict[str, Any]) -> pd.DataFrame: + features = config[INPUT_FEATURES] + config[OUTPUT_FEATURES] + df = build_synthetic_dataset(dataset_size, features) + data = [next(df) for _ in range(dataset_size + 1)] + return pd.DataFrame(data[1:], columns=data[0]) + + def build_synthetic_dataset(dataset_size: int, features: List[dict], outdir: str = "."): """Synthesizes a dataset for testing purposes. diff --git a/ludwig/data/preprocessing.py b/ludwig/data/preprocessing.py index 8511bcd2e66..7d44cabd3c3 100644 --- a/ludwig/data/preprocessing.py +++ b/ludwig/data/preprocessing.py @@ -56,6 +56,7 @@ from ludwig.features.feature_registries import base_type_registry from ludwig.features.feature_utils import compute_feature_hash from ludwig.utils import data_utils, strings_utils +from ludwig.utils.backward_compatibility import upgrade_metadata from ludwig.utils.config_utils import merge_config_preprocessing_with_feature_specific_defaults from ludwig.utils.data_utils import ( CACHEABLE_FORMATS, @@ -1498,9 +1499,13 @@ def shuffle(df): return training_set, test_set, validation_set -def load_metadata(metadata_file_path): +def load_metadata(metadata_file_path: str) -> Dict[str, Any]: logging.info(f"Loading metadata from: {metadata_file_path}") - return data_utils.load_json(metadata_file_path) + training_set_metadata = data_utils.load_json(metadata_file_path) + # TODO(travis): decouple config from training_set_metadata so we don't need to + # upgrade it over time. + training_set_metadata = upgrade_metadata(training_set_metadata) + return training_set_metadata def preprocess_for_training( diff --git a/ludwig/utils/backward_compatibility.py b/ludwig/utils/backward_compatibility.py index d39e7a003ba..a3947cc49d1 100644 --- a/ludwig/utils/backward_compatibility.py +++ b/ludwig/utils/backward_compatibility.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - +import copy import warnings from typing import Any, Callable, Dict, List, Union from ludwig.constants import ( AUDIO, BIAS, + CLASS_WEIGHTS, COLUMN, CONV_BIAS, CONV_USE_BIAS, @@ -32,6 +33,7 @@ EXECUTOR, FORCE_SPLIT, INPUT_FEATURES, + LOSS, MISSING_VALUE_STRATEGY, NAME, NUM_SAMPLES, @@ -55,6 +57,7 @@ ) from ludwig.features.feature_registries import base_type_registry from ludwig.globals import LUDWIG_VERSION +from ludwig.utils.metric_utils import TrainerMetric from ludwig.utils.misc_utils import merge_dict from ludwig.utils.version_transformation import VersionTransformation, VersionTransformationRegistry @@ -83,7 +86,7 @@ def wrap(fn: Callable[[Dict], Dict]): return wrap -def upgrade_to_latest_version(config: Dict): +def upgrade_to_latest_version(config: Dict) -> Dict: """Updates config from an older version of Ludwig to the current version. If config does not have a "ludwig_version" key, all updates are applied. @@ -98,6 +101,49 @@ def upgrade_to_latest_version(config: Dict): ) +def upgrade_model_progress(model_progress: Dict) -> Dict: + """Updates model progress info to be compatible with latest ProgressTracker implementation. + + Notably, we convert epoch-based stats to their step-based equivalents and reformat metrics into `TrainerMetric` + tuples. + """ + ret = copy.deepcopy(model_progress) + + if "last_improvement_epoch" in ret: + ret["last_improvement_steps"] = ret["last_improvement_epoch"] * ret["batch_size"] + del ret["last_improvement_epoch"] + + if "last_learning_rate_reduction_epoch" in ret: + ret["last_learning_rate_reduction_steps"] = ret["last_learning_rate_reduction_epoch"] * ret["batch_size"] + del ret["last_learning_rate_reduction_epoch"] + + if "last_increase_batch_size_epoch" in ret: + ret["last_increase_batch_size_steps"] = ret["last_increase_batch_size_epoch"] * ret["batch_size"] + del ret["last_increase_batch_size_epoch"] + + if "vali_metrics" in ret: + ret["validation_metrics"] = ret["vali_metrics"] + del ret["vali_metrics"] + + for metric_group in ("train_metrics", "test_metrics", "validation_metrics"): + for tgt in ret[metric_group]: + for metric in ret[metric_group][tgt]: + if len(ret[metric_group][tgt][metric]) == 0 or isinstance( + ret[metric_group][tgt][metric][0], (tuple, list) + ): + continue + + ret[metric_group][tgt][metric] = [ + TrainerMetric(i + 1, (i + 1) * ret["batch_size"], val) + for i, val in enumerate(ret[metric_group][tgt][metric]) + ] + + if "tune_checkpoint_num" not in ret: + ret["tune_checkpoint_num"] = 0 + + return ret + + def _traverse_dicts(config: Any, f: Callable[[Dict], None]): """Recursively applies function f to every dictionary contained in config. 
@@ -112,6 +158,17 @@ def _traverse_dicts(config: Any, f: Callable[[Dict], None]):
             _traverse_dicts(v, f)
 
 
+@register_config_transformation("0.4", ["output_features"])
+def update_class_weights_in_features(feature: Dict[str, Any]) -> Dict[str, Any]:
+    if LOSS in feature:
+        class_weights = feature[LOSS].get(CLASS_WEIGHTS, None)
+        if not isinstance(class_weights, list):
+            class_weights = None
+        feature[LOSS][CLASS_WEIGHTS] = class_weights
+
+    return feature
+
+
 @register_config_transformation("0.5")
 def rename_training_to_trainer(config: Dict[str, Any]) -> Dict[str, Any]:
     if TRAINING in config:
@@ -466,35 +523,51 @@ def update_training(config: Dict[str, Any]) -> Dict[str, Any]:
 
 @register_config_transformation("0.6")
 def upgrade_missing_value_strategy(config: Dict[str, Any]) -> Dict[str, Any]:
-    def __is_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        if PREPROCESSING not in feature_config:
-            return False
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None)
-        if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"):
-            return False
-        return True
-
-    def __update_old_missing_value_strategy(feature_config: Dict[str, Any]):
-        missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
-        replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
-        feature_name = feature_config.get(NAME)
-        warnings.warn(
-            f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy"
-            f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8",
-            DeprecationWarning,
-        )
-        feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy})
-
-    for input_feature in config.get(INPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(input_feature):
-            __update_old_missing_value_strategy(input_feature)
+    for input_feature in config.get(INPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(input_feature):
+            _update_old_missing_value_strategy(input_feature)
 
-    for output_feature in config.get(OUTPUT_FEATURES, {}):
-        if __is_old_missing_value_strategy(output_feature):
-            __update_old_missing_value_strategy(output_feature)
+    for output_feature in config.get(OUTPUT_FEATURES, []):
+        if _is_old_missing_value_strategy(output_feature):
+            _update_old_missing_value_strategy(output_feature)
 
     for feature, feature_defaults in config.get(DEFAULTS, {}).items():
-        if __is_old_missing_value_strategy(feature_defaults):
-            __update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))
+        if _is_old_missing_value_strategy(feature_defaults):
+            _update_old_missing_value_strategy(config.get(DEFAULTS).get(feature))
 
     return config
+
+
+def upgrade_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
+    # TODO(travis): stopgap solution, we should make it so we don't need to do this
+    # by decoupling config and metadata
+    metadata = copy.deepcopy(metadata)
+    _upgrade_metadata_missing_values(metadata)
+    return metadata
+
+
+def _upgrade_metadata_missing_values(metadata: Dict[str, Any]):
+    for k, v in metadata.items():
+        if isinstance(v, dict) and _is_old_missing_value_strategy(v):
+            _update_old_missing_value_strategy(v)
+
+
+def _update_old_missing_value_strategy(feature_config: Dict[str, Any]):
+    missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY)
+    replacement_strategy = "bfill" if missing_value_strategy == "backfill" else "ffill"
+    feature_name = feature_config.get(NAME)
+    
warnings.warn( + f"Using `{replacement_strategy}` instead of `{missing_value_strategy}` as the missing value strategy" + f" for `{feature_name}`. These are identical. `{missing_value_strategy}` will be removed in v0.8", + DeprecationWarning, + ) + feature_config[PREPROCESSING].update({MISSING_VALUE_STRATEGY: replacement_strategy}) + + +def _is_old_missing_value_strategy(feature_config: Dict[str, Any]): + if PREPROCESSING not in feature_config: + return False + missing_value_strategy = feature_config.get(PREPROCESSING).get(MISSING_VALUE_STRATEGY, None) + if not missing_value_strategy or missing_value_strategy not in ("backfill", "pad"): + return False + return True diff --git a/ludwig/utils/trainer_utils.py b/ludwig/utils/trainer_utils.py index cf8be05244f..8f76392d2b4 100644 --- a/ludwig/utils/trainer_utils.py +++ b/ludwig/utils/trainer_utils.py @@ -117,6 +117,11 @@ def save(self, filepath): @staticmethod def load(filepath): loaded = load_json(filepath) + + from ludwig.utils.backward_compatibility import upgrade_model_progress + + loaded = upgrade_model_progress(loaded) + return ProgressTracker(**loaded) def log_metrics(self): @@ -142,6 +147,8 @@ def log_metrics(self): if metrics_tuples: # For logging, get the latest metrics. The second "-1" indexes into the TrainerMetric # namedtuple. The last element of the TrainerMetric namedtuple is the actual metric value. + # + # TODO: when loading an existing model, this loses metric values for all but the last epoch. log_metrics[f"{metrics_dict_name}.{feature_name}.{metric_name}"] = metrics_tuples[-1][-1] return log_metrics diff --git a/tests/ludwig/utils/test_backward_compatibility.py b/tests/ludwig/utils/test_backward_compatibility.py index c0705881f9c..6ca999232dc 100644 --- a/tests/ludwig/utils/test_backward_compatibility.py +++ b/tests/ludwig/utils/test_backward_compatibility.py @@ -1,13 +1,16 @@ import copy +import math import pytest from ludwig.constants import ( BFILL, + CLASS_WEIGHTS, EVAL_BATCH_SIZE, EXECUTOR, HYPEROPT, INPUT_FEATURES, + LOSS, NUMBER, OUTPUT_FEATURES, PREPROCESSING, @@ -23,6 +26,7 @@ _upgrade_feature, _upgrade_preprocessing_split, upgrade_missing_value_strategy, + upgrade_model_progress, upgrade_to_latest_version, ) from ludwig.utils.defaults import merge_with_defaults @@ -490,3 +494,134 @@ def test_update_missing_value_strategy(missing_value_strategy: str): expected_config["input_features"][0]["preprocessing"]["missing_value_strategy"] == "ffill" assert updated_config == expected_config + + +def test_old_class_weights_default(): + old_config = { + "input_features": [ + { + "name": "input_feature_1", + "type": "category", + } + ], + "output_features": [ + {"name": "output_feature_1", "type": "category", "loss": {"class_weights": 1}}, + ], + } + + new_config = { + "input_features": [ + { + "name": "input_feature_1", + "type": "category", + } + ], + "output_features": [ + {"name": "output_feature_1", "type": "category", "loss": {"class_weights": None}}, + ], + } + + upgraded_config = upgrade_to_latest_version(old_config) + del upgraded_config["ludwig_version"] + assert new_config == upgraded_config + + old_config[OUTPUT_FEATURES][0][LOSS][CLASS_WEIGHTS] = [0.5, 0.8, 1] + new_config[OUTPUT_FEATURES][0][LOSS][CLASS_WEIGHTS] = [0.5, 0.8, 1] + + upgraded_config = upgrade_to_latest_version(old_config) + del upgraded_config["ludwig_version"] + assert new_config == upgraded_config + + +def test_upgrade_model_progress(): + old_model_progress = { + "batch_size": 64, + "best_eval_metric": 0.5, + 
"best_increase_batch_size_eval_metric": math.inf, + "best_reduce_learning_rate_eval_metric": math.inf, + "epoch": 2, + "last_improvement": 1, + "last_improvement_epoch": 1, + "last_increase_batch_size": 0, + "last_increase_batch_size_epoch": 0, + "last_increase_batch_size_eval_metric_improvement": 0, + "last_learning_rate_reduction": 0, + "last_learning_rate_reduction_epoch": 0, + "last_reduce_learning_rate_eval_metric_improvement": 0, + "learning_rate": 0.001, + "num_increases_batch_size": 0, + "num_reductions_learning_rate": 0, + "steps": 224, + "test_metrics": { + "combined": {"loss": [0.59, 0.56]}, + "delinquent": { + "accuracy": [0.77, 0.78], + }, + }, + "train_metrics": {"combined": {"loss": [0.58, 0.55]}, "delinquent": {"roc_auc": [0.53, 0.54]}}, + "vali_metrics": {"combined": {"loss": [0.59, 0.60]}, "delinquent": {"roc_auc": [0.53, 0.44]}}, + } + + new_model_progress = upgrade_model_progress(old_model_progress) + + for stat in ("improvement", "increase_batch_size", "learning_rate_reduction"): + assert f"last_{stat}_epoch" not in new_model_progress + assert f"last_{stat}_steps" in new_model_progress + assert ( + new_model_progress[f"last_{stat}_steps"] + == old_model_progress[f"last_{stat}_epoch"] * old_model_progress["batch_size"] + ) + + assert "tune_checkpoint_num" in new_model_progress + + assert "vali_metrics" not in new_model_progress + assert "validation_metrics" in new_model_progress + + metric = new_model_progress["validation_metrics"]["combined"]["loss"][0] + assert len(metric) == 3 + assert metric[-1] == 0.59 + + # Verify that we don't make changes to already-valid model progress dicts. + # To do so, we modify the batch size value and re-run the upgrade on the otherwise-valid `new_model_progress` dict. + new_model_progress["batch_size"] = 1 + unchanged_model_progress = upgrade_model_progress(new_model_progress) + assert unchanged_model_progress == new_model_progress + + +def test_upgrade_model_progress_already_valid(): + # Verify that we don't make changes to already-valid model progress dicts. 
+    valid_model_progress = {
+        "batch_size": 128,
+        "best_eval_metric": 5.541325569152832,
+        "best_increase_batch_size_eval_metric": math.inf,
+        "best_reduce_learning_rate_eval_metric": math.inf,
+        "epoch": 5,
+        "last_improvement": 0,
+        "last_improvement_steps": 25,
+        "last_increase_batch_size": 0,
+        "last_increase_batch_size_eval_metric_improvement": 0,
+        "last_increase_batch_size_steps": 0,
+        "last_learning_rate_reduction": 0,
+        "last_learning_rate_reduction_steps": 0,
+        "last_reduce_learning_rate_eval_metric_improvement": 0,
+        "learning_rate": 0.001,
+        "num_increases_batch_size": 0,
+        "num_reductions_learning_rate": 0,
+        "steps": 25,
+        "test_metrics": {
+            "Survived": {"accuracy": [[0, 5, 0.39], [1, 10, 0.38]], "loss": [[0, 5, 7.35], [1, 10, 7.08]]},
+            "combined": {"loss": [[0, 5, 7.35], [1, 10, 6.24]]},
+        },
+        "train_metrics": {
+            "Survived": {"accuracy": [[0, 5, 0.39], [1, 10, 0.40]], "loss": [[0, 5, 7.67], [1, 10, 6.57]]},
+            "combined": {"loss": [[0, 5, 7.67], [1, 10, 6.57]]},
+        },
+        "validation_metrics": {
+            "Survived": {"accuracy": [[0, 5, 0.38], [1, 10, 0.38]], "loss": [[0, 5, 6.56], [1, 10, 5.54]]},
+            "combined": {"loss": [[0, 5, 6.56], [1, 10, 5.54]]},
+        },
+        "tune_checkpoint_num": 0,
+    }
+
+    unchanged_model_progress = upgrade_model_progress(valid_model_progress)
+    assert unchanged_model_progress == valid_model_progress
diff --git a/tests/regression_tests/model/test_old_models.py b/tests/regression_tests/model/test_old_models.py
index eea9db873a5..2c6427ea036 100644
--- a/tests/regression_tests/model/test_old_models.py
+++ b/tests/regression_tests/model/test_old_models.py
@@ -2,9 +2,13 @@
 import zipfile
 
 import pandas as pd
+import pytest
 import wget
 
 from ludwig.api import LudwigModel
+from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df
+
+NUM_EXAMPLES = 25
 
 
 def test_model_loaded_from_old_config_prediction_works(tmpdir):
@@ -32,3 +36,27 @@ def test_model_loaded_from_old_config_prediction_works(tmpdir):
 
     predictions, _ = ludwig_model.predict(dataset=test_set)
     assert predictions.to_dict()["Survived_predictions"] == {0: False}
+
+
+@pytest.mark.parametrize(
+    "model_url",
+    [
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/twitter_bots_v05.zip",
+        "https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/ludwig_unit_tests/respiratory_v05.zip",
+    ],
+    ids=["twitter_bots", "respiratory"],
+)
+def test_predict_deprecated_model(model_url, tmpdir):
+    model_dir = os.path.join(tmpdir, "model")
+    os.makedirs(model_dir)
+
+    archive_path = wget.download(model_url, tmpdir)
+    with zipfile.ZipFile(archive_path, "r") as zip_ref:
+        zip_ref.extractall(model_dir)
+
+    ludwig_model = LudwigModel.load(model_dir)
+    df = build_synthetic_dataset_df(NUM_EXAMPLES, ludwig_model.config)
+
+    # predict() returns (predictions, output_directory), so unpack before asserting.
+    pred_df, _ = ludwig_model.predict(df)
+    assert len(pred_df) > 0
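
As a quick illustration of the upgrade path this patch adds, the sketch below runs `upgrade_model_progress` on a minimal epoch-based progress dict; it is a minimal sketch using only code from this diff, with values borrowed from the `test_upgrade_model_progress` fixture above. Since `TrainerMetric` is a NamedTuple, it compares equal to a plain tuple.

    from ludwig.utils.backward_compatibility import upgrade_model_progress

    # Minimal epoch-based progress dict, as written by older Ludwig versions.
    old = {
        "batch_size": 64,
        "last_improvement_epoch": 1,
        "train_metrics": {"combined": {"loss": [0.58, 0.55]}},
        "test_metrics": {"combined": {"loss": [0.59, 0.56]}},
        "vali_metrics": {"combined": {"loss": [0.59, 0.60]}},
    }

    new = upgrade_model_progress(old)

    # Epoch counters become step counters (epoch * batch_size).
    assert new["last_improvement_steps"] == 64
    # "vali_metrics" is renamed, and bare floats become
    # TrainerMetric (epoch, step, value) triples.
    assert new["validation_metrics"]["combined"]["loss"][0] == (1, 64, 0.59)
    # Already-upgraded dicts pass through unchanged, so re-running is a no-op.
    assert upgrade_model_progress(new) == new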
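Likewise, the new `build_synthetic_dataset_df` helper can be exercised without a saved model. A minimal sketch, assuming a toy config (the two feature specs here are hypothetical; any feature types the synthesizer supports would do):

    from ludwig.data.dataset_synthesizer import build_synthetic_dataset_df

    # Hypothetical config with the input_features/output_features shape
    # that build_synthetic_dataset_df expects.
    config = {
        "input_features": [{"name": "x", "type": "number"}],
        "output_features": [{"name": "y", "type": "binary"}],
    }

    # The underlying generator yields a header row first, which
    # build_synthetic_dataset_df consumes as the column names.
    df = build_synthetic_dataset_df(10, config)
    assert list(df.columns) == ["x", "y"]
    assert len(df) == 10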