Skip to content

Commit

Permalink
correct example script (#11726)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored May 14, 2021
1 parent bd3b599 commit 113eaa7
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions examples/flax/text-classification/run_flax_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,6 @@ def parse_args():
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
Expand Down Expand Up @@ -457,13 +451,13 @@ def eval_step(state, batch):
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0

# make sure weights are replicated on each device
state = replicate(state)

for epoch in range(1, num_epochs + 1):
logger.info(f"Epoch {epoch}")
logger.info(" Training...")

# make sure weights are replicated on each device
state = replicate(state)

train_start = time.time()
train_metrics = []
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
Expand Down Expand Up @@ -501,6 +495,9 @@ def eval_step(state, batch):
predictions = eval_step(state, batch)
metric.add_batch(predictions=predictions, references=labels)

# make sure weights are replicated on each device
state = replicate(state)

eval_metric = metric.compute()
logger.info(f" Done! Eval metrics: {eval_metric}")

Expand Down

0 comments on commit 113eaa7

Please sign in to comment.