diff --git a/examples/imagenet/configs/default.py b/examples/imagenet/configs/default.py index f6ebb6d2a8..a105bfdac2 100644 --- a/examples/imagenet/configs/default.py +++ b/examples/imagenet/configs/default.py @@ -56,6 +56,10 @@ def get_config(): # num_epochs using the entire dataset. Similarly for steps_per_eval. config.num_train_steps = -1 config.steps_per_eval = -1 + + # whether to profile the training loop + config.profile = True + return config diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index 2dc2f04b54..d2f8d01ccc 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -371,8 +371,8 @@ def train_and_evaluate( train_metrics = [] hooks = [] - if jax.process_index() == 0: - hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)] + if jax.process_index() == 0 and config.profile: + hooks += [periodic_actions.Profile(num_profile_steps=3, logdir=workdir)] train_metrics_last_t = time.time() logging.info('Initial compilation, this might take some minutes...') for step, batch in zip(range(step_offset, num_steps), train_iter):