Skip to content

Commit

Permalink
Set the metadata only during first training run (#3684)
Browse files Browse the repository at this point in the history
Co-authored-by: Justin Zhao <justinxzhao@gmail.com>
  • Loading branch information
Infernaught and justinxzhao authored Oct 11, 2023
1 parent 2772e9a commit 626d9fc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,16 @@ def train(
`(training_set, validation_set, test_set)`.
`output_directory` filepath to where training results are stored.
"""
# Only reset the metadata if the model has not been trained before
if self.training_set_metadata:
logger.warning(
"This model has been trained before. Its architecture has been defined by the original training set "
"(for example, the number of possible categorical outputs). The current training data will be mapped "
"to this architecture. If you want to change the architecture of the model, please concatenate your "
"new training data with the original and train a new model from scratch."
)
training_set_metadata = self.training_set_metadata

if self._user_config.get(HYPEROPT):
print_boxed("WARNING")
logger.warning(HYPEROPT_WARNING)
Expand Down
25 changes: 25 additions & 0 deletions tests/integration_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,28 @@ def test_saved_weights_in_checkpoint(tmpdir):
input_feature_encoder = saved_input_feature["encoder"]
assert "saved_weights_in_checkpoint" in input_feature_encoder
assert input_feature_encoder["saved_weights_in_checkpoint"]


def test_constant_metadata(tmpdir):
input_features = [category_feature(encoder={"vocab_size": 5})]
output_features = [category_feature(name="class", decoder={"vocab_size": 5}, output_feature=True)]

data_csv1 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset1.csv"))
val_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "validation1.csv"))
test_csv1 = shutil.copyfile(data_csv1, os.path.join(tmpdir, "test1.csv"))

config = {
"input_features": input_features,
"output_features": output_features,
}
model = LudwigModel(config)
model.train(training_set=data_csv1, validation_set=val_csv1, test_set=test_csv1, output_directory=tmpdir)
metadata1 = model.training_set_metadata

data_csv2 = generate_data(input_features, output_features, os.path.join(tmpdir, "dataset2.csv"), num_examples=10)
val_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "validation2.csv"))
test_csv2 = shutil.copyfile(data_csv2, os.path.join(tmpdir, "test2.csv"))
model.train(training_set=data_csv2, validation_set=val_csv2, test_set=test_csv2, output_directory=tmpdir)
metadata2 = model.training_set_metadata

assert metadata1 == metadata2

0 comments on commit 626d9fc

Please sign in to comment.