This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

AttributeError: 'DummyModule' object has no attribute 'load_checkpoint' #1911

Open
almostimplemented opened this issue May 26, 2022 · 0 comments


Description

In the TensorFlow 2 code path, a bug prevents the model from loading a checkpoint.

The cause is clear from the code. In tensor2tensor/utils/contrib.py, when tensorflow.contrib is unavailable (as in TensorFlow 2), framework() returns a DummyModule. Then tensor2tensor/utils/t2t_model.py tries to load the checkpoint through that object, even though DummyModule has no load_checkpoint attribute:

reader = contrib.framework().load_checkpoint(ckpt_dir)
variable_map = {}
for var in contrib.framework().get_trainable_variables():
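The failure can be reproduced without TensorFlow. The sketch below assumes that DummyModule only exposes the attributes passed to its constructor; the class here is a hypothetical stand-in, and the attributes given to it are illustrative, not the real list from contrib.py.

```python
class DummyModule(object):
    """Hypothetical stand-in: only keyword arguments become attributes."""

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)


# Roughly what framework() hands back when tensorflow.contrib is missing.
# The attribute names below are illustrative only.
dummy_framework = DummyModule(argsort=sorted)


def load_error_message(framework_module, ckpt_dir):
    """Return the AttributeError text from the failing call, or None."""
    try:
        framework_module.load_checkpoint(ckpt_dir)
    except AttributeError as err:
        return str(err)
    return None
```

Calling load_error_message(dummy_framework, "/tmp/model.ckpt") returns a message matching the last line of the traceback under Error logs.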

I will open a pull request with my own fix, and I am open to changing it to best fit the project.
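For illustration only (this is not the fix in the pull request), one option is to stop assuming contrib.framework() provides these methods and fall back to TF2 APIs that do exist, tf.train.load_checkpoint and tf.compat.v1.trainable_variables. The helper name checkpoint_api is hypothetical.

```python
def checkpoint_api(framework_module, fallback_load, fallback_get_vars):
    """Pick checkpoint helpers, tolerating a DummyModule that lacks them.

    Prefers the attributes on framework_module (the real contrib module under
    TF1) and otherwise returns the supplied fallbacks, for example
    tf.train.load_checkpoint and tf.compat.v1.trainable_variables under TF2.
    """
    load = getattr(framework_module, "load_checkpoint", fallback_load)
    get_vars = getattr(framework_module, "get_trainable_variables", fallback_get_vars)
    return load, get_vars
```

Assuming `import tensorflow as tf`, initialize_from_ckpt could call checkpoint_api(contrib.framework(), tf.train.load_checkpoint, tf.compat.v1.trainable_variables) and use the returned pair in place of the two contrib.framework() calls. Because the existing code calls get_trainable_variables() with no arguments, the fallback only needs to accept an empty call.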

...

Environment information

OS: Linux

(venv) ace01@wynne:~/forks/magenta$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
-e git+https://github.com/tensorflow/tensor2tensor@c8fe559e0b357389d8754474e1306b6ca9afc4f3#egg=tensor2tensor
tensorboard==2.9.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.1
tensorflow-addons==0.17.0
tensorflow-datasets==4.5.2
tensorflow-estimator==2.9.0
tensorflow-gan==2.1.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.8.0
tensorflow-probability==0.7.0


(venv) ace01@wynne:~/forks/magenta$ python -V
Python 3.7.7

Reproduction notes

This is not a minimal reproduction; it is simply how I encountered the bug while fine-tuning a Magenta music (note sequence) model.

Because the bug is evident from the code, I have not written a minimal repro, but I can provide one if needed.

(venv) ace01@wynne:~/forks/magenta$ cat finetune_midi_transformer.sh
DATA_DIR=/homes/ace01/forks/magenta/jazz_piano_datagen
HPARAMS_SET=transformer_tpu
MODEL=transformer
PROBLEM=score2perf_jazz_piano_language_uncropped_aug
TRAIN_DIR=/homes/ace01/forks/magenta/finetune_dir
UNCONDITIONAL_CHECKPOINT=/homes/ace01/forks/magenta/unconditional_model_16.ckpt


HPARAMS=\
"label_smoothing=0.0,"\
"max_length=0,"\
"max_target_seq_length=2048,"\
"num_hidden_layers=16,"\
"learning_rate=0.005"


t2t_trainer \
  --data_dir="${DATA_DIR}" \
  --hparams="${HPARAMS}" \
  --hparams_set="${HPARAMS_SET}" \
  --model="${MODEL}" \
  --output_dir="${TRAIN_DIR}" \
  --problem="${PROBLEM}" \
  --train_steps=1000000 \
  --warm_start_from="${UNCONDITIONAL_CHECKPOINT}"

Error logs

INFO:tensorflow:Checkpoint dir: /homes/ace01/forks/magenta/unconditional_model_16.ckpt
I0526 11:08:49.628702 140498416027456 t2t_model.py:2341] Checkpoint dir: /homes/ace01/forks/magenta/unconditional_model_16.ckpt
Traceback (most recent call last):
  File "/homes/ace01/forks/magenta/venv/bin/t2t_trainer", line 11, in <module>
    load_entry_point('magenta', 'console_scripts', 't2t_trainer')()
  File "/homes/ace01/forks/magenta/magenta/tensor2tensor/t2t_trainer.py", line 32, in console_entry_point
    tf.app.run(main)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow/python/platform/app.py", line 36, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/homes/ace01/forks/magenta/magenta/tensor2tensor/t2t_trainer.py", line 26, in main
    t2t_trainer.main(argv)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/bin/t2t_trainer.py", line 419, in main
    execute_schedule(exp)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/bin/t2t_trainer.py", line 372, in execute_schedule
    getattr(exp, FLAGS.schedule)()
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/trainer_lib.py", line 469, in continuous_train_and_eval
    self._eval_spec)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 504, in train_and_evaluate
    return executor.run()
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 645, in run
    return self.run_local()
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/training.py", line 746, in run_local
    saving_listeners=saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 360, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1186, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1215, in _train_model_default
    self.config)
  File "/homes/ace01/forks/magenta/venv/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1174, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1422, in wrapping_model_fn
    use_tpu=use_tpu)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1549, in estimator_model_fn
    loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1592, in estimator_spec_train
    self.initialize_from_ckpt(self._hparams.warm_start_from)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 1552, in initialize_from_ckpt
    return initialize_from_ckpt(ckpt_dir=ckpt_dir, hparams=self._hparams)
  File "/homes/ace01/non_forks/tensor2tensor/tensor2tensor/utils/t2t_model.py", line 2342, in initialize_from_ckpt
    reader = contrib.framework().load_checkpoint(ckpt_dir)
AttributeError: 'DummyModule' object has no attribute 'load_checkpoint'