feat(framework) Add FlowerTune templates to flwr new (#3587)
Co-authored-by: Javier <jafermarq@users.noreply.github.com>
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
3 people authored Jun 21, 2024
1 parent ccfef79 commit 7424b62
Showing 11 changed files with 635 additions and 27 deletions.
130 changes: 103 additions & 27 deletions src/py/flwr/cli/new/new.py
@@ -41,6 +41,16 @@ class MlFramework(str, Enum):
HUGGINGFACE = "HF"
MLX = "MLX"
SKLEARN = "sklearn"
FLOWERTUNE = "FlowerTune"


class LlmChallengeName(str, Enum):
"""Available LLM challenges."""

GENERALNLP = "GeneralNLP"
FINANCE = "Finance"
MEDICAL = "Medical"
CODE = "Code"


class TemplateNotFound(Exception):
@@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
create_file(file_path, content)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def new(
project_name: Annotated[
Optional[str],
@@ -125,6 +136,19 @@ def new(

framework_str = framework_str.lower()

if framework_str == "flowertune":
llm_challenge_value = prompt_options(
"Please select LLM challenge by typing in the number",
sorted([challenge.value for challenge in LlmChallengeName]),
)
selected_value = [
name
for name, value in vars(LlmChallengeName).items()
if value == llm_challenge_value
]
llm_challenge_str = selected_value[0]
llm_challenge_str = llm_challenge_str.lower()

print(
typer.style(
f"\n🔨 Creating Flower project {project_name}...",
@@ -139,40 +163,92 @@
import_name = package_name.replace("-", "_")
project_dir = os.path.join(cwd, package_name)

# List of files to render
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

context = {
"project_name": project_name,
"package_name": package_name,
"import_name": import_name.replace("-", "_"),
"username": username,
}

# List of files to render
if framework_str == "flowertune":
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": "app/code/flwr_tune/server.py.tpl"
},
f"{import_name}/client.py": {
"template": "app/code/flwr_tune/client.py.tpl"
},
f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
f"{import_name}/models.py": {
"template": "app/code/flwr_tune/models.py.tpl"
},
f"{import_name}/dataset.py": {
"template": "app/code/flwr_tune/dataset.py.tpl"
},
f"{import_name}/conf/config.yaml": {
"template": "app/code/flwr_tune/config.yaml.tpl"
},
f"{import_name}/conf/static_config.yaml": {
"template": "app/code/flwr_tune/static_config.yaml.tpl"
},
}

# Challenge specific context
fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
if llm_challenge_str == "generalnlp":
challenge_name = "General NLP"
num_clients = "20"
dataset_name = "vicgalle/alpaca-gpt4"
elif llm_challenge_str == "finance":
challenge_name = "Finance"
num_clients = "50"
dataset_name = "FinGPT/fingpt-sentiment-train"
elif llm_challenge_str == "medical":
challenge_name = "Medical"
num_clients = "20"
dataset_name = "medalpaca/medical_meadow_medical_flashcards"
else:
challenge_name = "Code"
num_clients = "10"
dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"

context["llm_challenge_str"] = llm_challenge_str
context["fraction_fit"] = fraction_fit
context["challenge_name"] = challenge_name
context["num_clients"] = num_clients
context["dataset_name"] = dataset_name
else:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

for file_path, value in files.items():
render_and_create(
file_path=os.path.join(project_dir, file_path),
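A minimal standalone sketch of the value-to-name lookup in `new()` above: it works because `LlmChallengeName` subclasses `str`, so each member compares equal to its string value (the `selection` value here is hypothetical):

```python
from enum import Enum


class LlmChallengeName(str, Enum):
    """Available LLM challenges."""

    GENERALNLP = "GeneralNLP"
    FINANCE = "Finance"
    MEDICAL = "Medical"
    CODE = "Code"


selection = "Medical"  # e.g. the value returned by prompt_options

# Lookup as written in the diff: scan the class dict for a member
# whose (string) value equals the selection
name = [n for n, v in vars(LlmChallengeName).items() if v == selection][0]

# Equivalent, via Enum's by-value constructor
assert name == LlmChallengeName(selection).name  # "MEDICAL"
print(name.lower())  # -> "medical", used to pick the FlowerTune templates
```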
56 changes: 56 additions & 0 deletions src/py/flwr/cli/new/templates/app/README.flowertune.md.tpl
@@ -0,0 +1,56 @@
# FlowerTune LLM on $challenge_name Dataset

This project performs federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition, and preprocess the dataset.
Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
which allows users to perform the training on a single GPU.


## Methodology

This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
The clients' models are aggregated with the FedAvg strategy.
This provides baseline performance for the $challenge_name challenge leaderboard.


## Environment setup

Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```

## Experimental setup

The dataset is partitioned into $num_clients IID shards, one per client.
In each round we randomly sample a fraction ($fraction_fit) of the clients to participate,
and the federated fine-tuning lasts for `200` rounds.
All settings are defined in `$project_name/conf/static_config.yaml`, which must not be modified to ensure a fair comparison if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
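
For illustration, this partitioning can be sketched with Flower Datasets as follows (a minimal sketch only; the generated project builds it from the config files):

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# One IID shard per client, matching the experimental setup above
partitioner = IidPartitioner(num_partitions=$num_clients)
fds = FederatedDataset(
    dataset="$dataset_name",
    partitioners={"train": partitioner},
)
partition = fds.load_partition(0)  # one client's shard
```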


## Running the challenge

First, make sure that you have access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging Face account. You can request access directly from the Hugging Face website.
Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:

```bash
huggingface-cli login
```

Run the challenge with the default config values.
The configs in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml` are loaded automatically:

```bash
flwr run
```

## VRAM consumption

We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:

| Challenges | GeneralNLP | Finance | Medical | Code |
| :--------: | :--------: | :--------: | :--------: | :--------: |
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |

You can adjust the CPU/GPU resources assigned to each client based on your device; they are specified under `flower.engine.simulation` in `pyproject.toml`.
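
If you drive the simulation from Python instead of `flwr run`, the same knob is exposed via `backend_config` (a sketch under assumptions: `client` and `server` are the `ClientApp`/`ServerApp` defined in `$project_name/app.py`, and the import path may differ in your project):

```python
from flwr.simulation import run_simulation

from $project_name.app import client, server  # hypothetical import path

run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=$num_clients,
    backend_config={"client_resources": {"num_cpus": 2, "num_gpus": 0.25}},
)
```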
15 changes: 15 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower CLI `new` command app / code / flwr_tune templates."""
86 changes: 86 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
@@ -0,0 +1,86 @@
"""$project_name: A Flower / FlowerTune app."""

import os
import warnings
from datetime import datetime

from flwr_datasets import FederatedDataset
from hydra import compose, initialize
from hydra.utils import instantiate

from flwr.client import ClientApp
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig

from $import_name.client import gen_client_fn, get_parameters
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
from $import_name.models import get_model
from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config

# Avoid warnings
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"

# Initialise regular config
with initialize(config_path="conf", version_base="1.1"):
cfg = compose(config_name="config")

# Initialise static config
with initialize(config_path="conf", version_base="1.1"):
cfg_static = compose(config_name="static_config")

cfg.train.num_rounds = cfg_static.num_rounds

# Create output directory given current timestamp
current_time = datetime.now()
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
os.makedirs(save_path, exist_ok=True)

# Partition dataset and get dataloaders
partitioner = instantiate(cfg_static.partitioner)
fds = FederatedDataset(
dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
)
(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)

# ClientApp for Flower Next
client = ClientApp(
client_fn=gen_client_fn(
fds,
tokenizer,
formatting_prompts_func,
data_collator,
cfg.model,
cfg.train,
save_path,
),
)

# Get initial model weights
init_model = get_model(cfg.model)
init_model_parameters = get_parameters(init_model)
init_model_parameters = ndarrays_to_parameters(init_model_parameters)

# Instantiate strategy according to config. Here we pass other arguments
# that are only defined at runtime.
strategy = instantiate(
cfg.strategy,
on_fit_config_fn=get_on_fit_config(),
fit_metrics_aggregation_fn=fit_weighted_average,
initial_parameters=init_model_parameters,
evaluate_fn=get_evaluate_fn(
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
),
)

# ServerApp for Flower Next
server = ServerApp(
config=ServerConfig(num_rounds=cfg_static.num_rounds),
strategy=strategy,
)
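
A minimal sketch of the Hydra `initialize`/`compose` pattern used above: `initialize()` sets the config search path relative to the calling module, and `compose()` loads a named YAML into an OmegaConf `DictConfig`; `overrides` tweak individual keys without editing the file (the `train.num_rounds` key is assumed from the config handling above):

```python
from hydra import compose, initialize

with initialize(config_path="conf", version_base="1.1"):
    cfg = compose(config_name="config", overrides=["train.num_rounds=10"])

print(cfg.train.num_rounds)  # -> 10
```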
