From 7557df1ed329555433e89e33a206fbde6e8634f7 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 2 Oct 2023 16:53:24 -0400 Subject: [PATCH 1/5] Set the metadata only during first training run --- ludwig/api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ludwig/api.py b/ludwig/api.py index c7c3f46240d..42b1d414255 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -443,6 +443,14 @@ def train( `(training_set, validation_set, test_set)`. `output_directory` filepath to where training results are stored. """ + # Only reset the metadata if the model has not been trained before + if self.training_set_metadata: + logging.warn( + "Previous metadata has been detected. Overriding `training_set_metadata` with metadata from previous " + "training run." + ) + training_set_metadata = self.training_set_metadata + if self._user_config.get(HYPEROPT): print_boxed("WARNING") logger.warning(HYPEROPT_WARNING) From 9c2fd7d88f55a026a12b4dd92437c02183e84fa9 Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Tue, 3 Oct 2023 10:42:33 -0400 Subject: [PATCH 2/5] Change warning --- ludwig/api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ludwig/api.py b/ludwig/api.py index 42b1d414255..abee878a3cf 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -445,9 +445,11 @@ def train( """ # Only reset the metadata if the model has not been trained before if self.training_set_metadata: - logging.warn( - "Previous metadata has been detected. Overriding `training_set_metadata` with metadata from previous " - "training run." + logger.warning( + "This model has been trained before. Its architecture has been defined by the original training set " + "(for example, the number of possible categorical outputs). The current training data will be mapped " + "to this architecture. If you want to change the architecture of the model, please concatenate your " + "new training data with the original and train a new model from scratch." ) training_set_metadata = self.training_set_metadata From cc3ae51da22a40b50bbbabc4eb47723d909108ba Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Mon, 2 Oct 2023 16:53:24 -0400 Subject: [PATCH 3/5] Set the metadata only during first training run --- ludwig/api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ludwig/api.py b/ludwig/api.py index c7c3f46240d..42b1d414255 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -443,6 +443,14 @@ def train( `(training_set, validation_set, test_set)`. `output_directory` filepath to where training results are stored. """ + # Only reset the metadata if the model has not been trained before + if self.training_set_metadata: + logging.warn( + "Previous metadata has been detected. Overriding `training_set_metadata` with metadata from previous " + "training run." + ) + training_set_metadata = self.training_set_metadata + if self._user_config.get(HYPEROPT): print_boxed("WARNING") logger.warning(HYPEROPT_WARNING) From 00e76c4b65abe96b02c4ece426900756884ec98e Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Tue, 3 Oct 2023 10:42:33 -0400 Subject: [PATCH 4/5] Change warning --- ludwig/api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ludwig/api.py b/ludwig/api.py index 42b1d414255..abee878a3cf 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -445,9 +445,11 @@ def train( """ # Only reset the metadata if the model has not been trained before if self.training_set_metadata: - logging.warn( - "Previous metadata has been detected. Overriding `training_set_metadata` with metadata from previous " - "training run." + logger.warning( + "This model has been trained before. Its architecture has been defined by the original training set " + "(for example, the number of possible categorical outputs). The current training data will be mapped " + "to this architecture. If you want to change the architecture of the model, please concatenate your " + "new training data with the original and train a new model from scratch." ) training_set_metadata = self.training_set_metadata From 8037e87b53681cedf677d73058559935c50542cf Mon Sep 17 00:00:00 2001 From: Infernaught <89timw@gmail.com> Date: Wed, 4 Oct 2023 10:36:59 -0400 Subject: [PATCH 5/5] Add test to verify metadata stays constant --- tests/integration_tests/test_api.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/integration_tests/test_api.py b/tests/integration_tests/test_api.py index 95879c8326f..e7a9d5102ba 100644 --- a/tests/integration_tests/test_api.py +++ b/tests/integration_tests/test_api.py @@ -740,3 +740,28 @@ def test_saved_weights_in_checkpoint(tmpdir): input_feature_encoder = saved_input_feature["encoder"] assert "saved_weights_in_checkpoint" in input_feature_encoder assert input_feature_encoder["saved_weights_in_checkpoint"] + + +def test_constant_metadata(tmpdir): + input_features = [category_feature(encoder={"vocab_size": 5})] + output_features = [category_feature(name="class", decoder={"vocab_size": 5}, output_feature=True)] + + data_csv1 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset1.csv")) + val_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "validation1.csv")) + test_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "test1.csv")) + + config = { + "input_features": input_features, + "output_features": output_features, + } + model = LudwigModel(config) + model.train(training_set=data_csv1, validation_set=val_csv1, test_set=test_csv1, output_directory=tmpdir) + metadata1 = model.training_set_metadata + + data_csv2 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset2.csv"), num_examples=10) + val_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "validation2.csv")) + test_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "test2.csv")) + model.train(training_set=data_csv2, validation_set=val_csv2, test_set=test_csv2, output_directory=tmpdir) + metadata2 = model.training_set_metadata + + assert metadata1 == metadata2