Federated HuggingFace Transformers using Flower #863

Merged
merged 13 commits into from
Dec 23, 2021
61 changes: 61 additions & 0 deletions examples/transformers-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Federated HuggingFace Transformers using Flower and PyTorch

This introductory example shows how to use [HuggingFace](https://huggingface.co) Transformers with Flower and PyTorch. It extends the [quickstart_pytorch](https://flower.dev/docs/quickstart_pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), which is worth reading for a detailed explanation of the transformer pipeline.

Like `quickstart_pytorch`, this example is meant to be easy to run.

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell; it checks out only this example:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/transformers-pytorch . && rm -rf flower && cd transformers-pytorch
```

This will create a new directory called `transformers-pytorch` containing the following files:

```shell
-- pyproject.toml
-- client.py
-- server.py
-- run.sh
-- README.md
```

Project dependencies (such as `torch` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
python3 -c "import flwr"
```

If you don't see any errors you're good to go!
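
To check every dependency at once rather than only `flwr`, a short script can help. The sketch below is not part of the example project: `check_packages` is a hypothetical helper, and the package list assumes that the top-level import names match the dependency names in `pyproject.toml`.

```python
from importlib.util import find_spec
from typing import Dict, Iterable


def check_packages(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level package name to whether Python can locate it."""
    return {name: find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names for the dependencies listed in pyproject.toml.
    required = ["flwr", "torch", "transformers", "datasets"]
    for name, found in check_packages(required).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

Any package reported as `MISSING` usually means that `poetry install` did not finish or that the Poetry shell is not active.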

## Run Federated Learning with Flower

You are now ready to start the Flower server and the clients. Start the server in a terminal as follows:

```shell
python3 server.py
```
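
The server in `server.py` uses the `FedAvg` strategy with `fraction_fit=0.5`. The sketch below illustrates how such a strategy might decide how many clients train in each round. It assumes the strategy samples a fraction of the connected clients but never fewer than a minimum of two, which is our understanding of the default; check the Flower documentation for your version. The function `num_fit_clients` and its parameters are illustrative, not Flower's API.

```python
def num_fit_clients(
    num_available: int, fraction_fit: float, min_fit_clients: int = 2
) -> int:
    """Number of clients sampled for training in one round.

    Assumption: a fraction of the connected clients is sampled, with a
    lower bound of `min_fit_clients` (assumed to default to 2).
    """
    return max(int(num_available * fraction_fit), min_fit_clients)


if __name__ == "__main__":
    # With the two clients of this example, both take part in every round.
    print(num_fit_clients(num_available=2, fraction_fit=0.5))
    # With ten connected clients, half of them would be sampled.
    print(num_fit_clients(num_available=10, fraction_fit=0.5))
```

Under this assumption, the example needs two connected clients before training can begin.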

Next, start the Flower clients that will participate in the training. Open two more terminal windows and run the following commands.

Start client 1 in the first terminal:

```shell
python3 client.py
```

Start client 2 in the second terminal:

```shell
python3 client.py
```

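Once the clients have connected, federated training starts and runs for three rounds, as configured in `server.py`. In each round the clients fine-tune the model locally with PyTorch and the server averages the returned weights using `FedAvg`.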
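
The aggregation step can be pictured as a weighted average of the client parameters, where each client's weight is the number of examples it reports. The following is a minimal sketch of that idea, not Flower's implementation: plain Python lists stand in for the NumPy arrays that Flower exchanges, and `weighted_average` is a hypothetical name.

```python
from typing import List, Tuple

Layer = List[float]   # one flattened parameter tensor
Params = List[Layer]  # all layers of one model


def weighted_average(results: List[Tuple[Params, int]]) -> Params:
    """Average client parameters, weighting each client by its example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    first_params = results[0][0]
    averaged: Params = []
    for layer_idx, first_layer in enumerate(first_params):
        # Accumulate the example-weighted sum of this layer over all clients.
        layer_sum = [0.0] * len(first_layer)
        for params, num_examples in results:
            for i, value in enumerate(params[layer_idx]):
                layer_sum[i] += value * num_examples
        averaged.append([s / total_examples for s in layer_sum])
    return averaged


if __name__ == "__main__":
    client_a = ([[1.0, 2.0]], 100)  # (parameters, number of examples)
    client_b = ([[5.0, 6.0]], 300)
    print(weighted_average([client_a, client_b]))  # [[4.0, 5.0]]
```

A client that reports more examples pulls the average further towards its own parameters, which is why the second value returned by `fit` matters.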
123 changes: 123 additions & 0 deletions examples/transformers-pytorch/client.py
@@ -0,0 +1,123 @@
from collections import OrderedDict
import warnings

import flwr as fl
import torch
import numpy as np

import random
from torch.utils.data import DataLoader

from datasets import load_dataset, load_metric

from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AdamW

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint

def load_data():
    """Load IMDB data (training and eval)"""
    raw_datasets = load_dataset("imdb")
    raw_datasets = raw_datasets.shuffle(seed=42)

    # remove unnecessary data split
    del raw_datasets["unsupervised"]

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True)

    # random 100 samples
    population = random.sample(range(len(raw_datasets["train"])), 100)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets["train"] = tokenized_datasets["train"].select(population)
    tokenized_datasets["test"] = tokenized_datasets["test"].select(population)

    tokenized_datasets = tokenized_datasets.remove_columns("text")
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=32,
        collate_fn=data_collator,
    )

    testloader = DataLoader(
        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
    )

    return trainloader, testloader


def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test(net, testloader):
    metric = load_metric("accuracy")
    loss = 0
    net.eval()
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy


def main():
    net = AutoModelForSequenceClassification.from_pretrained(
        CHECKPOINT, num_labels=2
    ).to(DEVICE)

    trainloader, testloader = load_data()

    # Flower client
    class IMDBClient(fl.client.NumPyClient):
        def get_parameters(self):
            return [val.cpu().numpy() for _, val in net.state_dict().items()]

        def set_parameters(self, parameters):
            params_dict = zip(net.state_dict().keys(), parameters)
            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

        def fit(self, parameters, config):
            self.set_parameters(parameters)
            print("Training Started...")
            train(net, trainloader, epochs=1)
            print("Training Finished.")
            return self.get_parameters(), len(trainloader), {}

        def evaluate(self, parameters, config):
            self.set_parameters(parameters)
            loss, accuracy = test(net, testloader)
            return float(loss), len(testloader), {"accuracy": float(accuracy)}

    # Start client
    fl.client.start_numpy_client("[::]:9999", client=IMDBClient())


if __name__ == "__main__":
    main()
23 changes: 23 additions & 0 deletions examples/transformers-pytorch/pyproject.toml
@@ -0,0 +1,23 @@
[build-system]
requires = [
    "poetry==1.1.10",
]
build-backend = "poetry.masonry.api"

[tool.poetry]
name = "transformers_pytorch"
version = "0.1.0"
description = "HuggingFace Transformers Federated Learning Quickstart with Flower"
authors = [
    "The Flower Authors <enquiries@flower.dev>",
    "Kaushik Amar Das <kaushik.das@iiitg.ac.in>",
]

[tool.poetry.dependencies]
python = "^3.8"
flwr = "0.17.0"
# flwr = { path = "../../", develop = true } # Development
torch = "1.9.0"
transformers = "4.11.3"
datasets = "1.12.1"
scikit-learn = "1.0"
14 changes: 14 additions & 0 deletions examples/transformers-pytorch/run.sh
@@ -0,0 +1,14 @@
#!/bin/bash

python server.py &
sleep 2 # Sleep for 2s to give the server enough time to start

for i in `seq 0 1`; do
    echo "Starting client $i"
    python client.py &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
17 changes: 17 additions & 0 deletions examples/transformers-pytorch/server.py
@@ -0,0 +1,17 @@
import flwr as fl


if __name__ == "__main__":

    # Define strategy
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=0.5,
        fraction_eval=0.5,
    )

    # Start server
    fl.server.start_server(
        server_address="[::]:9999",
        config={"num_rounds": 3},
        strategy=strategy,
    )