feat(framework) Add FlowerTune templates to `flwr new` (#3587)

Co-authored-by: Javier <jafermarq@users.noreply.github.com>
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>

1 parent ccfef79 · commit 7424b62
Showing 11 changed files with 635 additions and 27 deletions.
56 changes: 56 additions & 0 deletions
56
src/py/flwr/cli/new/templates/app/README.flowertune.md.tpl
@@ -0,0 +1,56 @@
# FlowerTune LLM on $challenge_name Dataset

This directory conducts federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
which allows users to perform the training on a single GPU.

## Methodology

This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
The clients' models are aggregated with the FedAvg strategy.
This provides a baseline performance for the leaderboard of the $challenge_name challenge.
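
For illustration, wrapping a causal LM with LoRA adapters via 🤗PEFT looks roughly like the sketch below; the hyperparameter values are placeholders, not the template's defaults, which are set in `$project_name/conf/config.yaml`:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder LoRA hyperparameters; the template's actual values are
# configured in $project_name/conf/config.yaml.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.3")
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the low-rank adapter weights are trainable
```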

## Environment setup

Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```

## Experimental setup

The dataset is partitioned into $num_clients IID shards, with each shard serving as a client.
We randomly sample $fraction_fit of the clients to be available in each round,
and the federated fine-tuning lasts for `200` rounds.
All settings are defined in `$project_name/conf/static_config.yaml`; to keep the competition fair, this file must not be modified if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).
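
As an illustration (not part of the generated template), the IID partitioning described above can be reproduced with Flower Datasets; the dataset name and client count below are placeholders:

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

# Split the training data into equally sized IID shards, one per client.
# 20 stands in for $num_clients; the dataset string stands in for $dataset_name.
partitioner = IidPartitioner(num_partitions=20)
fds = FederatedDataset(
    dataset="your-hf-org/your-dataset",  # placeholder Hugging Face dataset ID
    partitioners={"train": partitioner},
)
partition = fds.load_partition(0)  # the shard held by client 0
```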

## Running the challenge

First, make sure that you have been granted access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging Face account. You can request access directly from the Hugging Face website.
Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:

```bash
huggingface-cli login
```
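
Alternatively, you can log in programmatically with the `huggingface_hub` library, the documented equivalent of the CLI command above:

```python
from huggingface_hub import login

# Prompts for a Hugging Face access token and stores it locally.
login()
```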

Run the challenge with the default config values.
The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.

```bash
flwr run
```

## VRAM consumption

We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:

| Challenges | GeneralNLP | Finance | Medical | Code |
| :--------: | :--------: | :-----: | :-----: | :--: |
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |

You can adjust the CPU/GPU resources assigned to each client based on your device; these resources are specified under `flower.engine.simulation` in `pyproject.toml`.
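
For reference, loading a model in 4-bit precision with `transformers` and `bitsandbytes` looks roughly like the sketch below; the values are illustrative, and the template's actual quantization settings are configured in `$project_name/conf/config.yaml`:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative 4-bit (NF4) quantization config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    quantization_config=bnb_config,
)
```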
15 changes: 15 additions & 0 deletions
15
src/py/flwr/cli/new/templates/app/code/flwr_tune/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower CLI `new` command app / code / flwr_tune templates."""
86 changes: 86 additions & 0 deletions
86
src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
@@ -0,0 +1,86 @@
"""$project_name: A Flower / FlowerTune app.""" | ||
|
||
import os | ||
import warnings | ||
from datetime import datetime | ||
|
||
from flwr_datasets import FederatedDataset | ||
from hydra import compose, initialize | ||
from hydra.utils import instantiate | ||
|
||
from flwr.client import ClientApp | ||
from flwr.common import ndarrays_to_parameters | ||
from flwr.server import ServerApp, ServerConfig | ||
|
||
from $import_name.client import gen_client_fn, get_parameters | ||
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting | ||
from $import_name.models import get_model | ||
from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config | ||
|
||
# Avoid warnings | ||
warnings.filterwarnings("ignore", category=UserWarning) | ||
os.environ["TOKENIZERS_PARALLELISM"] = "true" | ||
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1" | ||
|
||
# Initialise regular config | ||
with initialize(config_path="conf", version_base="1.1"): | ||
cfg = compose(config_name="config") | ||
|
||
# Initialise static config | ||
with initialize(config_path="conf", version_base="1.1"): | ||
cfg_static = compose(config_name="static_config") | ||
|
||
cfg.train.num_rounds = cfg_static.num_rounds | ||
|
||
# Create output directory given current timestamp | ||
current_time = datetime.now() | ||
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S") | ||
save_path = os.path.join(os.getcwd(), f"results/{folder_name}") | ||
os.makedirs(save_path, exist_ok=True) | ||
|
||
# Partition dataset and get dataloaders | ||
partitioner = instantiate(cfg_static.partitioner) | ||
fds = FederatedDataset( | ||
dataset=cfg_static.dataset.name, partitioners={"train": partitioner} | ||
) | ||
( | ||
tokenizer, | ||
data_collator, | ||
formatting_prompts_func, | ||
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name) | ||
|
||
# ClientApp for Flower Next | ||
client = ClientApp( | ||
client_fn=gen_client_fn( | ||
fds, | ||
tokenizer, | ||
formatting_prompts_func, | ||
data_collator, | ||
cfg.model, | ||
cfg.train, | ||
save_path, | ||
), | ||
) | ||
|
||
# Get initial model weights | ||
init_model = get_model(cfg.model) | ||
init_model_parameters = get_parameters(init_model) | ||
init_model_parameters = ndarrays_to_parameters(init_model_parameters) | ||
|
||
# Instantiate strategy according to config. Here we pass other arguments | ||
# that are only defined at runtime. | ||
strategy = instantiate( | ||
cfg.strategy, | ||
on_fit_config_fn=get_on_fit_config(), | ||
fit_metrics_aggregation_fn=fit_weighted_average, | ||
initial_parameters=init_model_parameters, | ||
evaluate_fn=get_evaluate_fn( | ||
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path | ||
), | ||
) | ||
|
||
# ServerApp for Flower Next | ||
server = ServerApp( | ||
config=ServerConfig(num_rounds=cfg_static.num_rounds), | ||
strategy=strategy, | ||
) |
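
For local experimentation outside the `flwr run` workflow, a `ServerApp`/`ClientApp` pair like the one defined above can also be driven directly by Flower's simulation runner. A minimal sketch, assuming the `flwr.simulation.run_simulation` API; the `num_supernodes` value is illustrative:

```python
# Sketch (not part of the template): run the apps above in-process.
from flwr.simulation import run_simulation

run_simulation(
    server_app=server,  # the ServerApp defined above
    client_app=client,  # the ClientApp defined above
    num_supernodes=20,  # illustrative; match the number of dataset shards
)
```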