diff --git a/ludwig/api.py b/ludwig/api.py
index 8035a22fd51..0afea35b1e0 100644
--- a/ludwig/api.py
+++ b/ludwig/api.py
@@ -444,6 +444,16 @@ def train(
         `(training_set, validation_set, test_set)`.
         `output_directory` filepath to where training results are stored.
         """
+        # Only reset the metadata if the model has not been trained before
+        if self.training_set_metadata:
+            logger.warning(
+                "This model has been trained before. Its architecture has been defined by the original training set "
+                "(for example, the number of possible categorical outputs). The current training data will be mapped "
+                "to this architecture. If you want to change the architecture of the model, please concatenate your "
+                "new training data with the original and train a new model from scratch."
+            )
+            training_set_metadata = self.training_set_metadata
+
         if self._user_config.get(HYPEROPT):
             print_boxed("WARNING")
             logger.warning(HYPEROPT_WARNING)
diff --git a/tests/integration_tests/test_api.py b/tests/integration_tests/test_api.py
index 95879c8326f..e7a9d5102ba 100644
--- a/tests/integration_tests/test_api.py
+++ b/tests/integration_tests/test_api.py
@@ -740,3 +740,28 @@ def test_saved_weights_in_checkpoint(tmpdir):
     input_feature_encoder = saved_input_feature["encoder"]
     assert "saved_weights_in_checkpoint" in input_feature_encoder
     assert input_feature_encoder["saved_weights_in_checkpoint"]
+
+
+def test_constant_metadata(tmpdir):
+    input_features = [category_feature(encoder={"vocab_size": 5})]
+    output_features = [category_feature(name="class", decoder={"vocab_size": 5}, output_feature=True)]
+
+    data_csv1 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset1.csv"))
+    val_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "validation1.csv"))
+    test_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "test1.csv"))
+
+    config = {
+        "input_features": input_features,
+        "output_features": output_features,
+    }
+    model = LudwigModel(config)
+    model.train(training_set=data_csv1, validation_set=val_csv1, test_set=test_csv1, output_directory=tmpdir)
+    metadata1 = model.training_set_metadata
+
+    data_csv2 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset2.csv"), num_examples=10)
+    val_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "validation2.csv"))
+    test_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "test2.csv"))
+    model.train(training_set=data_csv2, validation_set=val_csv2, test_set=test_csv2, output_directory=tmpdir)
+    metadata2 = model.training_set_metadata
+
+    assert metadata1 == metadata2