feat(framework) Add FlowerTune templates to flwr new #3587

Merged
51 commits merged on Jun 21, 2024
Commits
41e231b
Init flwrtune to flwr new
yan-gao-GY Jun 11, 2024
39a448f
Update flwr run for backend_config passing
yan-gao-GY Jun 11, 2024
ab19093
Update
yan-gao-GY Jun 11, 2024
c828db8
Init flwr new with 4 LLM tasks
yan-gao-GY Jun 12, 2024
cd7a896
Formatting
yan-gao-GY Jun 12, 2024
df70a0f
Update src/py/flwr/cli/new/templates/app/pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
7787d51
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
70e5d7a
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
4b6720b
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 12, 2024
421ddce
Fix config files
yan-gao-GY Jun 12, 2024
922c0c5
Update pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
8e60056
Update pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 12, 2024
51ba186
Fix
yan-gao-GY Jun 13, 2024
a318ac4
Fix
yan-gao-GY Jun 13, 2024
ca33364
Fix
yan-gao-GY Jun 13, 2024
5817d47
Avoid warnings
yan-gao-GY Jun 13, 2024
9b04e87
Fix
yan-gao-GY Jun 13, 2024
b9ab673
Update readme and formatting
yan-gao-GY Jun 13, 2024
f4cc70a
Formatting
yan-gao-GY Jun 13, 2024
2fc3c58
Update src/py/flwr/cli/new/templates/app/pyproject.flwrtune.toml.tpl
yan-gao-GY Jun 13, 2024
56ebd7a
Formatting
yan-gao-GY Jun 13, 2024
8bbdff3
Update src/py/flwr/cli/new/templates/app/code/flwrtune/config.yaml.tpl
jafermarq Jun 14, 2024
fdc867a
Merge branch 'main' into add-flwrtune-flwrnew
jafermarq Jun 14, 2024
83b30e6
Merge branch 'main' into add-flwrtune-flwrnew
jafermarq Jun 18, 2024
27cae47
Update readme
yan-gao-GY Jun 18, 2024
3bcf6dd
Replace `task` with `challenge`
yan-gao-GY Jun 18, 2024
9fd9a3f
Formatting
yan-gao-GY Jun 18, 2024
95e65dd
Update readme
yan-gao-GY Jun 18, 2024
29d78cd
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
5648a79
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
bfe3e1a
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
fabe74a
Update src/py/flwr/cli/new/templates/app/README.flwrtune.md.tpl
yan-gao-GY Jun 18, 2024
2a9ea73
Update readme
yan-gao-GY Jun 18, 2024
b634864
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 18, 2024
28c5602
Update readme
yan-gao-GY Jun 18, 2024
c0085ed
Update src/py/flwr/cli/new/new.py
yan-gao-GY Jun 20, 2024
5056153
Update src/py/flwr/cli/new/new.py
yan-gao-GY Jun 20, 2024
8ccf93f
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
cd6d02e
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
4cc871a
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
69a6693
Update src/py/flwr/cli/new/templates/app/code/flwrtune/client.py.tpl
yan-gao-GY Jun 20, 2024
dc9e08e
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
6368699
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
6ec4288
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 20, 2024
bf07c4f
Formatting
yan-gao-GY Jun 20, 2024
76093e3
Change model parameter init method & formatting
yan-gao-GY Jun 20, 2024
061275b
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 20, 2024
dc34b16
Update src/py/flwr/cli/new/templates/app/code/flwrtune/app.py.tpl
yan-gao-GY Jun 21, 2024
d48c2a2
Update FlowerTune names
yan-gao-GY Jun 21, 2024
4a61686
Formatting
yan-gao-GY Jun 21, 2024
e4f66ed
Merge branch 'main' into add-flwrtune-flwrnew
yan-gao-GY Jun 21, 2024
130 changes: 103 additions & 27 deletions src/py/flwr/cli/new/new.py
@@ -41,6 +41,16 @@ class MlFramework(str, Enum):
HUGGINGFACE = "HF"
MLX = "MLX"
SKLEARN = "sklearn"
FLOWERTUNE = "FlowerTune"


class LlmChallengeName(str, Enum):
"""Available LLM challenges."""

GENERALNLP = "GeneralNLP"
FINANCE = "Finance"
MEDICAL = "Medical"
CODE = "Code"


class TemplateNotFound(Exception):
@@ -81,6 +91,7 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) ->
create_file(file_path, content)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def new(
project_name: Annotated[
Optional[str],
Expand Down Expand Up @@ -125,6 +136,19 @@ def new(

framework_str = framework_str.lower()

if framework_str == "flowertune":
llm_challenge_value = prompt_options(
"Please select LLM challenge by typing in the number",
sorted([challenge.value for challenge in LlmChallengeName]),
)
selected_value = [
name
for name, value in vars(LlmChallengeName).items()
if value == llm_challenge_value
]
llm_challenge_str = selected_value[0]
llm_challenge_str = llm_challenge_str.lower()

print(
typer.style(
f"\n🔨 Creating Flower project {project_name}...",
@@ -139,40 +163,92 @@
import_name = package_name.replace("-", "_")
project_dir = os.path.join(cwd, package_name)

# List of files to render
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

context = {
"project_name": project_name,
"package_name": package_name,
"import_name": import_name.replace("-", "_"),
"username": username,
}

# List of files to render
if framework_str == "flowertune":
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": "app/code/flwr_tune/server.py.tpl"
},
f"{import_name}/client.py": {
"template": "app/code/flwr_tune/client.py.tpl"
},
f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"},
f"{import_name}/models.py": {
"template": "app/code/flwr_tune/models.py.tpl"
},
f"{import_name}/dataset.py": {
"template": "app/code/flwr_tune/dataset.py.tpl"
},
f"{import_name}/conf/config.yaml": {
"template": "app/code/flwr_tune/config.yaml.tpl"
},
f"{import_name}/conf/static_config.yaml": {
"template": "app/code/flwr_tune/static_config.yaml.tpl"
},
}

# Challenge specific context
fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
if llm_challenge_str == "generalnlp":
challenge_name = "General NLP"
num_clients = "20"
dataset_name = "vicgalle/alpaca-gpt4"
elif llm_challenge_str == "finance":
challenge_name = "Finance"
num_clients = "50"
dataset_name = "FinGPT/fingpt-sentiment-train"
elif llm_challenge_str == "medical":
challenge_name = "Medical"
num_clients = "20"
dataset_name = "medalpaca/medical_meadow_medical_flashcards"
else:
challenge_name = "Code"
num_clients = "10"
dataset_name = "lucasmccabe-lmi/CodeAlpaca-20k"

context["llm_challenge_str"] = llm_challenge_str
context["fraction_fit"] = fraction_fit
context["challenge_name"] = challenge_name
context["num_clients"] = num_clients
context["dataset_name"] = dataset_name
else:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
},
f"{import_name}/client.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
}

for file_path, value in files.items():
render_and_create(
file_path=os.path.join(project_dir, file_path),
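A note on the selection logic above: reverse-mapping the prompted value through `vars(LlmChallengeName)` and branching with `if`/`elif` for the challenge-specific defaults works, but the same mapping can be expressed with a direct enum lookup and a dictionary. A minimal sketch for illustration only, not part of this PR:

```python
from enum import Enum


class LlmChallengeName(str, Enum):
    """Available LLM challenges (mirrors the enum added in new.py)."""

    GENERALNLP = "GeneralNLP"
    FINANCE = "Finance"
    MEDICAL = "Medical"
    CODE = "Code"


# Challenge-specific defaults from the diff above, keyed by the lowercased
# member name: (challenge_name, num_clients, dataset_name).
CHALLENGE_DEFAULTS = {
    "generalnlp": ("General NLP", "20", "vicgalle/alpaca-gpt4"),
    "finance": ("Finance", "50", "FinGPT/fingpt-sentiment-train"),
    "medical": ("Medical", "20", "medalpaca/medical_meadow_medical_flashcards"),
    "code": ("Code", "10", "lucasmccabe-lmi/CodeAlpaca-20k"),
}

# Map the value returned by the prompt (e.g. "GeneralNLP") back to its member name.
llm_challenge_str = LlmChallengeName("GeneralNLP").name.lower()  # "generalnlp"
challenge_name, num_clients, dataset_name = CHALLENGE_DEFAULTS[llm_challenge_str]
fraction_fit = "0.2" if llm_challenge_str == "code" else "0.1"
```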
56 changes: 56 additions & 0 deletions src/py/flwr/cli/new/templates/app/README.flowertune.md.tpl
@@ -0,0 +1,56 @@
# FlowerTune LLM on $challenge_name Dataset

This directory performs federated instruction tuning with a pretrained [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model on a [$challenge_name dataset](https://huggingface.co/datasets/$dataset_name).
We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset.
Flower's Simulation Engine is used to simulate the LLM fine-tuning process in a federated way,
which allows users to run the training on a single GPU.


## Methodology

This baseline performs federated LLM fine-tuning with [LoRA](https://arxiv.org/pdf/2106.09685) using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library.
The clients' models are aggregated with the FedAvg strategy.
This provides baseline performance for the $challenge_name challenge leaderboard.


## Environment setup

Project dependencies are defined in `pyproject.toml`. Install them in an activated Python environment with:

```shell
pip install -e .
```

## Experimental setup

The dataset is partitioned into $num_clients IID shards, one per client.
We randomly sample a fraction ($fraction_fit) of clients to be available in each round,
and the federated fine-tuning lasts for `200` rounds.
All settings are defined in `$project_name/conf/static_config.yaml`, which must not be modified to ensure fair competition if you plan to participate in the [LLM leaderboard](https://flower.ai/benchmarks/llm-leaderboard).


## Running the challenge

First, make sure that you have access to the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.3) model with your Hugging Face account. You can request access directly from the Hugging Face website.
Then, follow the instructions [here](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) to log in to your account. Note that you only need to complete this step once on your development machine:

```bash
huggingface-cli login
```

Run the challenge with the default config values.
The configs are in `$project_name/conf/config.yaml` and `$project_name/conf/static_config.yaml`, and are loaded automatically.

```bash
flwr run
```

## VRAM consumption

We use the Mistral-7B model with 4-bit quantization by default. The estimated VRAM consumption per client for each challenge is shown below:

| Challenges | GeneralNLP | Finance | Medical | Code |
| :--------: | :--------: | :--------: | :--------: | :--------: |
| VRAM | ~25.50 GB | ~17.30 GB | ~22.80 GB | ~17.40 GB |

You can adjust the CPU/GPU resources assigned to each client based on your device; these resources are specified under `flower.engine.simulation` in `pyproject.toml`.
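The `$challenge_name`, `$num_clients`, `$fraction_fit`, and `$dataset_name` placeholders in this README template are filled from the `context` dict built in `new.py`. The templates use `$`-style placeholders, which suggests rendering via Python's `string.Template`; the `render_and_create` helper itself is not shown in this diff, so treat the mechanism below as an assumption. A minimal sketch of how such substitution works:

```python
from string import Template

# A fragment of README.flowertune.md.tpl, using the same $-style placeholders.
readme_tpl = Template(
    "# FlowerTune LLM on $challenge_name Dataset\n"
    "The dataset is partitioned into $num_clients shards.\n"
)

# Context values as built in new.py for the Finance challenge.
context = {"challenge_name": "Finance", "num_clients": "50"}

print(readme_tpl.substitute(context))
# # FlowerTune LLM on Finance Dataset
# The dataset is partitioned into 50 shards.
```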
15 changes: 15 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower CLI `new` command app / code / flwr_tune templates."""
86 changes: 86 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl
@@ -0,0 +1,86 @@
"""$project_name: A Flower / FlowerTune app."""

import os
import warnings
from datetime import datetime

from flwr_datasets import FederatedDataset
from hydra import compose, initialize
from hydra.utils import instantiate

from flwr.client import ClientApp
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig

from $import_name.client import gen_client_fn, get_parameters
from $import_name.dataset import get_tokenizer_and_data_collator_and_propt_formatting
from $import_name.models import get_model
from $import_name.server import fit_weighted_average, get_evaluate_fn, get_on_fit_config

# Avoid warnings
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"

# Initialise regular config
with initialize(config_path="conf", version_base="1.1"):
cfg = compose(config_name="config")

# Initialise static config
with initialize(config_path="conf", version_base="1.1"):
cfg_static = compose(config_name="static_config")

cfg.train.num_rounds = cfg_static.num_rounds

# Create output directory given current timestamp
current_time = datetime.now()
folder_name = current_time.strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(os.getcwd(), f"results/{folder_name}")
os.makedirs(save_path, exist_ok=True)

# Partition dataset and get dataloaders
partitioner = instantiate(cfg_static.partitioner)
fds = FederatedDataset(
dataset=cfg_static.dataset.name, partitioners={"train": partitioner}
)
(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)

# ClientApp for Flower Next
client = ClientApp(
client_fn=gen_client_fn(
fds,
tokenizer,
formatting_prompts_func,
data_collator,
cfg.model,
cfg.train,
save_path,
),
)

# Get initial model weights
init_model = get_model(cfg.model)
init_model_parameters = get_parameters(init_model)
init_model_parameters = ndarrays_to_parameters(init_model_parameters)

# Instantiate strategy according to config. Here we pass other arguments
# that are only defined at runtime.
strategy = instantiate(
cfg.strategy,
on_fit_config_fn=get_on_fit_config(),
fit_metrics_aggregation_fn=fit_weighted_average,
initial_parameters=init_model_parameters,
evaluate_fn=get_evaluate_fn(
cfg.model, cfg.train.save_every_round, cfg_static.num_rounds, save_path
),
)

# ServerApp for Flower Next
server = ServerApp(
config=ServerConfig(num_rounds=cfg_static.num_rounds),
strategy=strategy,
)
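The `instantiate(cfg_static.partitioner)` call above builds the dataset partitioner from the Hydra config. The actual `static_config.yaml` contents are not part of this diff, so the `_target_` below is an assumption; a minimal sketch using Flower Datasets' `IidPartitioner`:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical partitioner entry, standing in for cfg_static.partitioner.
partitioner_cfg = OmegaConf.create(
    {"_target_": "flwr_datasets.partitioner.IidPartitioner", "num_partitions": 20}
)

# Hydra resolves _target_ and calls IidPartitioner(num_partitions=20).
partitioner = instantiate(partitioner_cfg)
```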