Skip to content

Commit

Permalink
Merge branch 'main' into fix_simulator_worker_workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Mar 20, 2024
2 parents a72cd98 + 6e4f444 commit 7854ee6
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"metadata": {},
"source": [
"\n",
"Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n",
"Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the streaming capability from the clients to the server with Tensorboard SummaryWriter sender syntax, but with a MLflow receiver\n",
"\n",
"> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"mlflow.note.content": "## **Hello PyTorch experiment with MLflow**"
},
"run_tags": {
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
}
},
"artifact_location": "artifacts"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"mlflow.note.content": "## **Hello PyTorch experiment with MLflow**"
},
"run_tags": {
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
"mlflow.note.content": "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server for server delivery to MLflow.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n"
}
},
"artifact_location": "artifacts"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"kwargs" : {
"project": "hello-pt-experiment",
"name": "hello-pt",
"notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n",
"notes": "Federated Experiment tracking with W&B \n Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the Flare streaming capability from the clients to the server and deliver to MLFLow.\\n\\n> **_NOTE:_** \\n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.\n",
"tags": ["baseline", "paper1"],
"job_type": "train-validate",
"config": {
Expand Down
15 changes: 8 additions & 7 deletions nvflare/app_opt/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ def __init__(self):
self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
"""
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
callbacks = trainer.callbacks
if isinstance(callbacks, list):
if isinstance(callbacks, Callback):
callbacks = [callbacks]
elif not isinstance(callbacks, list):
callbacks = []

if not any(isinstance(cb, FLCallback) for cb in callbacks):
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
callbacks.append(fl_callback)
elif isinstance(callbacks, Callback):
callbacks = [callbacks, fl_callback]
else:
callbacks = [fl_callback]

if restore_state:
if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks):
callbacks.append(RestoreState())

trainer.callbacks = callbacks
Expand Down
2 changes: 2 additions & 0 deletions nvflare/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def send(model: FLModel, clear_cache: bool = True) -> None:
model (FLModel): Sends a FLModel object.
clear_cache: clear cache after send
"""
if not isinstance(fl_model, FLModel):
raise TypeError("fl_model needs to be an instance of FLModel")
global client_api
return client_api.send(model, clear_cache)

Expand Down

0 comments on commit 7854ee6

Please sign in to comment.