Skip to content

Commit

Permalink
make style
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jul 27, 2022
1 parent 6ff64d1 commit 8238406
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 4 additions & 2 deletions examples/flax/text-classification/run_flax_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate, pad_shard_unpad
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
Expand Down Expand Up @@ -622,7 +622,9 @@ def eval_step(state, batch):
position=2,
):
labels = batch.pop("labels")
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size
)
metric.add_batch(predictions=np.array(predictions), references=labels)

eval_metric = metric.compute()
Expand Down
2 changes: 1 addition & 1 deletion examples/flax/token-classification/run_flax_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate, pad_shard_unpad
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
Expand Down
6 changes: 4 additions & 2 deletions examples/flax/vision/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate, pad_shard_unpad
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
Expand Down Expand Up @@ -534,7 +534,9 @@ def eval_step(params, batch):
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(state.params, batch, min_device_batch=per_device_eval_batch_size)
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)

eval_step_progress_bar.update(1)
Expand Down

0 comments on commit 8238406

Please sign in to comment.