diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py
index f405dd9fc767..f3453926fec9 100755
--- a/examples/flax/text-classification/run_flax_glue.py
+++ b/examples/flax/text-classification/run_flax_glue.py
@@ -119,12 +119,6 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
     )
@@ -457,13 +451,13 @@ def eval_step(state, batch):
     logger.info(f"===== Starting training ({num_epochs} epochs) =====")
     train_time = 0
 
+    # make sure weights are replicated on each device
+    state = replicate(state)
+
     for epoch in range(1, num_epochs + 1):
         logger.info(f"Epoch {epoch}")
         logger.info(" Training...")
 
-        # make sure weights are replicated on each device
-        state = replicate(state)
-
         train_start = time.time()
         train_metrics = []
         rng, input_rng, dropout_rng = jax.random.split(rng, 3)
@@ -501,6 +495,9 @@ def eval_step(state, batch):
             predictions = eval_step(state, batch)
             metric.add_batch(predictions=predictions, references=labels)
 
+        # make sure weights are replicated on each device
+        state = replicate(state)
+
         eval_metric = metric.compute()
         logger.info(f" Done! Eval metrics: {eval_metric}")
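
For context, a minimal sketch (not taken from this patch) of the replicate-once pattern that `flax.jax_utils.replicate()` supports: the train state is copied to every local device a single time before the epoch loop, and the `pmap`ped step keeps it replicated thereafter. The toy loss, shapes, and `train_step` below are illustrative assumptions only, not code from `run_flax_glue.py`.

```python
import jax
import jax.numpy as jnp
import optax
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state


def loss_fn(params, batch):
    # Toy linear regression loss (assumption, for illustration only).
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


def train_step(state, batch):
    grads = jax.grad(loss_fn)(state.params, batch)
    # Average gradients across devices before applying them.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return state.apply_gradients(grads=grads)


p_train_step = jax.pmap(train_step, axis_name="batch")

state = train_state.TrainState.create(
    apply_fn=None, params={"w": jnp.zeros((4, 1))}, tx=optax.sgd(0.1)
)

n_dev = jax.local_device_count()
batch = {
    "x": jnp.ones((n_dev, 8, 4)),  # leading axis = number of local devices
    "y": jnp.ones((n_dev, 8, 1)),
}

# Replicate once, before the epoch loop; the pmapped step keeps the state
# replicated across devices on every iteration.
state = replicate(state)
for epoch in range(2):
    state = p_train_step(state, batch)

# Pull a single copy back to the host for evaluation or checkpointing.
state = unreplicate(state)
```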