Commit

Merge pull request huggingface#9 from ROCmSoftwarePlatform/bert-tf2
Bert tf2
stevenireeves authored Feb 22, 2022
2 parents 8682754 + 65cf0d6 commit 25329fb
Showing 2 changed files with 49 additions and 0 deletions.
37 changes: 37 additions & 0 deletions scripts/bert/bert_train.py
@@ -0,0 +1,37 @@
import tensorflow as tf
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    raw_datasets = load_dataset("imdb")
    tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

    small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
    small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

    model = TFAutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)
    tf_train_dataset = small_train_dataset.remove_columns(["text"]).with_format("tensorflow")
    tf_eval_dataset = small_eval_dataset.remove_columns(["text"]).with_format("tensorflow")

    train_features = {x: tf_train_dataset[x] for x in tokenizer.model_input_names}
    train_tf_dataset = tf.data.Dataset.from_tensor_slices((train_features, tf_train_dataset["label"]))
    train_tf_dataset = train_tf_dataset.shuffle(len(tf_train_dataset)).batch(8)

    eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}
    eval_tf_dataset = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset["label"]))
    eval_tf_dataset = eval_tf_dataset.batch(8)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=tf.metrics.SparseCategoricalAccuracy(),
    )

    print("==================================== Training and Evaluating Model ====================================")
    model.fit(train_tf_dataset, validation_data=eval_tf_dataset, epochs=3)
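Note on the hard-coded device list: tf.distribute.MirroredStrategy above is pinned to GPU:0 through GPU:3, so the script assumes a 4-GPU node. A minimal sketch of an adaptation (illustrative only, not part of this commit) that sizes the strategy to whatever accelerators are actually visible:

import tensorflow as tf

# Illustrative sketch, not part of this commit: build the device list from the
# GPUs TensorFlow can see instead of hard-coding four of them.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    strategy = tf.distribute.MirroredStrategy(devices=[f"GPU:{i}" for i in range(len(gpus))])
else:
    strategy = tf.distribute.get_strategy()  # default (single-device) strategy on CPU-only machines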
12 changes: 12 additions & 0 deletions scripts/bert/bert_train.sh
@@ -0,0 +1,12 @@
#!/bin/bash
set -e
set -x
pip3 install transformers datasets


cd ~ && git clone --branch bert-tf2 https://github.com/ROCmSoftwarePlatform/transformers
# Fine-tune BERT-large for sequence classification on a small IMDB subset
python3 transformers/scripts/bert/bert_train.py > log.txt
cat log.txt | tail -n 1
cat log.txt | tail -n 1 | awk '{ print "Accuracy: " $(NF) }'

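The last two pipelines in bert_train.sh only re-read the training log: the first echoes the final line of log.txt, and the second reprints that line's last whitespace-separated field prefixed with "Accuracy: ". Assuming the last line Keras writes is the final epoch's progress line and its last reported metric is the validation accuracy, this surfaces the end-of-run accuracy. With an awk that keeps the last record's fields available in its END block, an equivalent single command (a sketch, not part of this commit) would be:

awk 'END { print "Accuracy: " $(NF) }' log.txt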