Commit

Merge branch 'main' into HE_KM

ZiyueXu77 authored Jan 10, 2024
2 parents 8cb6d7d + 030b3d2 commit 3f5af59
Showing 22 changed files with 150 additions and 39 deletions.
@@ -24,7 +24,7 @@
 from torchvision.transforms import Compose, Normalize, ToTensor
 
 from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
-from nvflare.apis.fl_constant import ReservedKey, ReturnCode
+from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable, make_reply
 from nvflare.apis.signal import Signal
@@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
 """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
 Args:
 data_path (str): Path that the data will be stored at. Defaults to "~/data".
+lr (float, optional): Learning rate. Defaults to 0.01
 epochs (int, optional): Epochs. Defaults to 5
 exclude_vars (list): List of variables to exclude during model loading.
@@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
 self.loss = None
 self.device = None
 self.model = None
+
 self.data_path = data_path
 self.lr = lr
 self.epochs = epochs
@@ -147,6 +149,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha
 
 def local_train(self, fl_ctx, abort_signal):
 # Basic training
+current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
 for epoch in range(self.epochs):
 self.model.train()
 running_loss = 0.0
@@ -174,12 +177,12 @@ def local_train(self, fl_ctx, abort_signal):
 )
 
 # Stream training loss at each step
-current_step = len(self.train_loader) * epoch + i
+current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
 self.writer.log_metrics({"train_loss": cost.item(), "running_loss": running_loss}, current_step)
 
 # Stream validation accuracy at the end of each epoch
 metric = self.local_validate(abort_signal)
-self.writer.log_metric("validation_accuracy", metric, epoch)
+self.writer.log_metric("validation_accuracy", metric, current_step)
 
 def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
 run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
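The substantive change in this file (and in the two tracking variants below) is twofold: the learner now reads the federated round from the FL context, and metric steps are offset by that round so logged curves no longer restart at zero every round. A minimal sketch of the round lookup, assuming the controller stamps the standard `current_round` header on the task data (as ScatterAndGather-style workflows do):

```python
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext

def get_current_round(fl_ctx: FLContext) -> int:
    # The task Shareable sent by the server is stored in the FL context;
    # the controller stamps the round index on its "current_round" header.
    task_data = fl_ctx.get_prop(FLContextKey.TASK_DATA)
    return task_data.get_header("current_round")
```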
9 changes: 6 additions & 3 deletions examples/advanced/experiment-tracking/pt/learner_with_tb.py
@@ -43,7 +43,7 @@
 class PTLearner(Learner):
 def __init__(
 self,
-data_path="/tmp/nvflare/tensorboard-streaming",
+data_path="~/data",
 lr=0.01,
 epochs=5,
 exclude_vars=None,
@@ -52,6 +52,7 @@ def __init__(
 """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
 Args:
 data_path (str): Path that the data will be stored at. Defaults to "~/data".
+lr (float, optional): Learning rate. Defaults to 0.01
 epochs (int, optional): Epochs. Defaults to 5
 exclude_vars (list): List of variables to exclude during model loading.
@@ -71,6 +72,7 @@ def __init__(
 self.loss = None
 self.device = None
 self.model = None
+
 self.data_path = data_path
 self.lr = lr
 self.epochs = epochs
@@ -150,6 +152,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha
 
 def local_train(self, fl_ctx, abort_signal):
 # Basic training
+current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
 for epoch in range(self.epochs):
 self.model.train()
 running_loss = 0.0
@@ -173,12 +176,12 @@ def local_train(self, fl_ctx, abort_signal):
 running_loss = 0.0
 
 # Stream training loss at each step
-current_step = len(self.train_loader) * epoch + i
+current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
 self.writer.add_scalar("train_loss", cost.item(), current_step)
 
 # Stream validation accuracy at the end of each epoch
 metric = self.local_validate(abort_signal)
-self.writer.add_scalar("validation_accuracy", metric, epoch)
+self.writer.add_scalar("validation_accuracy", metric, current_step)
 
 def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
 run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
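The step arithmetic is the same in all three learners: with `n_iterations` batches per epoch (presumably `len(self.train_loader)`), each round spans `n_iterations * epochs` steps, so the round offset keeps the TensorBoard x-axis strictly increasing across rounds. A worked sketch with purely illustrative numbers:

```python
n_iterations = 782  # hypothetical: CIFAR10 with batch size 64
epochs = 5

def global_step(current_round: int, epoch: int, i: int) -> int:
    # Steps from earlier rounds, plus earlier epochs of this round,
    # plus the batch index within the current epoch.
    return n_iterations * epochs * current_round + n_iterations * epoch + i

# The last step of round 0 is immediately followed by the first step of round 1.
assert global_step(0, epochs - 1, n_iterations - 1) + 1 == global_step(1, 0, 0)
```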
@@ -24,7 +24,7 @@
 from torchvision.transforms import Compose, Normalize, ToTensor
 
 from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
-from nvflare.apis.fl_constant import ReservedKey, ReturnCode
+from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable, make_reply
 from nvflare.apis.signal import Signal
@@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
 """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
 Args:
 data_path (str): Path that the data will be stored at. Defaults to "~/data".
+lr (float, optional): Learning rate. Defaults to 0.01
 epochs (int, optional): Epochs. Defaults to 5
 exclude_vars (list): List of variables to exclude during model loading.
@@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
 self.loss = None
 self.device = None
 self.model = None
+
 self.data_path = data_path
 self.lr = lr
 self.epochs = epochs
@@ -141,6 +143,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha
 
 def local_train(self, fl_ctx, abort_signal):
 # Basic training
+current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
 for epoch in range(self.epochs):
 self.model.train()
 running_loss = 0.0
@@ -164,12 +167,12 @@ def local_train(self, fl_ctx, abort_signal):
 running_loss = 0.0
 
 # Stream training loss at each step
-current_step = len(self.train_loader) * epoch + i
+current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
 self.writer.log({"train_loss": cost.item()}, current_step)
 
 # Stream validation accuracy at the end of each epoch
 metric = self.local_validate(abort_signal)
-self.writer.log({"validation_accuracy": metric}, epoch)
+self.writer.log({"validation_accuracy": metric}, current_step)
 
 def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
 run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
@@ -62,7 +62,10 @@ def authorize(self, fl_ctx: FLContext) -> Tuple[bool, str]:
 if command in ["check_resources", "submit_job"]:
 security_items = fl_ctx.get_prop(FLContextKey.SECURITY_ITEMS)
 job_meta = security_items.get(FLContextKey.JOB_META)
-auth_tokens = job_meta.get(JobMetaKey.CUSTOM_PROPS).get("auth_tokens")
+auth_tokens = job_meta.get(JobMetaKey.CUSTOM_PROPS, {}).get("auth_tokens")
+if not auth_tokens:
+return False, f"Not authorized to execute command: {command}"
+
 site_name = fl_ctx.get_identity_name()
 site_auth_token = auth_tokens.get(site_name).split(":")[1]
 
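The fix above prevents an `AttributeError` when a submitted job carries no custom properties: `dict.get` with an empty-dict default, plus an early rejection when no tokens are present. The per-site lookup two lines later could be hardened the same way; a sketch assuming tokens are stored as `site_name -> "prefix:token"` (which the `split(":")[1]` in the diff implies):

```python
from typing import Optional

def extract_site_token(auth_tokens: dict, site_name: str) -> Optional[str]:
    # Return None for a missing site or a malformed value instead of raising.
    raw = auth_tokens.get(site_name)
    if not raw or ":" not in raw:
        return None
    return raw.split(":", 1)[1]
```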
4 changes: 2 additions & 2 deletions examples/hello-world/hello-tf2/README.md
@@ -26,10 +26,10 @@ Prepare the data first:
 bash ./prepare_data.sh
 ```
 
-Use nvflare simulator to run the hello-examples:
+Use nvflare simulator to run the hello-examples: (TF2 does not allow multiple processes to run on a single GPU at the same time, so set the simulator threads to 1; the "-gpu" option can be used to run multiple concurrent clients.)
 
 ```
-nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 hello-tf2/jobs/hello-tf2
+nvflare simulator -w /tmp/nvflare/ -n 2 -t 1 hello-tf2/jobs/hello-tf2
 ```
 
 ### 3. Access the logs and results
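For the multi-client, multi-GPU case the new note mentions, the simulator's `-gpu` option assigns clients to devices; a sketch assuming two visible GPUs (adjust the IDs for your machine):

```
nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 -gpu 0,1 hello-tf2/jobs/hello-tf2
```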
@@ -139,7 +139,8 @@ def evaluate(input_weights):
 running_loss += loss.item()
 if i % 2000 == 1999:  # print every 2000 mini-batches
 print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
-mlflow.log_metric("loss", running_loss / 2000, i)
+global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i
+mlflow.log_metric("loss", running_loss / 2000, global_step)
 running_loss = 0.0
 
 print(f"({client_id}) Finished Training")
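This example uses the Client API rather than a `Learner` subclass, so the round index arrives on the received `FLModel` instead of through the FL context. A minimal sketch of that loop (structure assumed from the Client API; the real script wraps it around the training code above):

```python
import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel

flare.init()
while flare.is_running():
    input_model = flare.receive()          # FLModel pushed by the server
    round_idx = input_model.current_round  # offsets the MLflow step above
    # ... run local_epochs of training on input_model.params ...
    flare.send(FLModel(params=input_model.params))  # return updated weights
```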
@@ -181,6 +181,15 @@
 "-force"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "932c4b05-c370-4513-8bd8-5c05962d5696",
+"metadata": {},
+"source": [
+">Note:\n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -470,6 +470,15 @@
 "Now let's look closer at the configurations, in particular, the client side, and make sure it matches to the new class we just created"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "379cbf3b-b058-4b5a-b3f8-aa6824f0cf2d",
+"metadata": {},
+"source": [
+">Note: \n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -63,6 +63,15 @@
 "%pip install -r requirements.txt"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "cd59e545-77d9-4cc4-9047-298ceff450e1",
+"metadata": {},
+"source": [
+"> Note: \n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "markdown",
 "id": "5f87c3da-59dc-4551-9448-b20b64a57137",
@@ -471,7 +480,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.13"
+"version": "3.8.16"
 }
 },
 "nbformat": 4,
@@ -63,6 +63,17 @@
 "%pip install -r requirements.txt"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "65da6755-609d-44f2-844f-01fb3bda1bd0",
+"metadata": {
+"tags": []
+},
+"source": [
+">Note:\n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "markdown",
 "id": "5f87c3da-59dc-4551-9448-b20b64a57137",
@@ -63,6 +63,15 @@
 "%pip install -r requirements.txt"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "50af50ec-ae21-4711-aa88-3ddacf0b5d01",
+"metadata": {},
+"source": [
+">Note:\n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "markdown",
 "id": "5f87c3da-59dc-4551-9448-b20b64a57137",
@@ -56,6 +56,15 @@
 "%pip install -r requirements.txt"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "3ed5db39-b8f6-42bd-9610-6e671f17a6ea",
+"metadata": {},
+"source": [
+">Note:\n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command."
+]
+},
 {
 "cell_type": "markdown",
 "id": "f3d5cd9a-3da9-446c-aac0-6ae84bf0ead1",
@@ -63,6 +63,15 @@
 "%pip install -r requirements.txt"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "952a7361-f9ab-4673-9e37-acadc89847da",
+"metadata": {},
+"source": [
+">Note:\n",
+"In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command.\n"
+]
+},
 {
 "cell_type": "markdown",
 "id": "5f87c3da-59dc-4551-9448-b20b64a57137",
2 changes: 1 addition & 1 deletion examples/tutorials/job_cli.ipynb
@@ -247,7 +247,7 @@
 "The above command creates a job folder at ```/tmp/nvflare/my_job``` with job template ```sag_pt```. \n",
 "You can see that a few configuration files are created. Some of the configurations are open for you to overwrite.\n",
 "\n",
-"If you have the ```tree``` command installed ( ```python -m pip install``` on linux), you can use the ```tree``` command, otherwise, you can use \"ls -al\" to look at the job_folder structure:"
+"If you have the ```tree``` command installed ( ```sudo apt install tree``` on linux), you can use the ```tree``` command, otherwise, you can use \"ls -al\" to look at the job_folder structure:"
 ]
 },
 {
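For reference, the corrected commands from the cell, as they would be run on a Debian/Ubuntu system (the job folder path comes from the notebook):

```
sudo apt install tree        # install tree (Debian/Ubuntu)
tree /tmp/nvflare/my_job     # inspect the generated job folder
ls -al /tmp/nvflare/my_job   # fallback when tree is unavailable
```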
2 changes: 2 additions & 0 deletions nvflare/app_common/executors/client_api_launcher_executor.py
@@ -30,6 +30,7 @@ def __init__(
 launch_timeout: Optional[float] = None,
 task_wait_timeout: Optional[float] = None,
 last_result_transfer_timeout: float = 300.0,
+external_execution_wait: float = 5.0,
 peer_read_timeout: Optional[float] = None,
 monitor_interval: float = 0.01,
 read_interval: float = 0.001,
@@ -81,6 +82,7 @@ def __init__(
 launch_timeout=launch_timeout,
 task_wait_timeout=task_wait_timeout,
 last_result_transfer_timeout=last_result_transfer_timeout,
+external_execution_wait=external_execution_wait,
 peer_read_timeout=peer_read_timeout,
 monitor_interval=monitor_interval,
 read_interval=read_interval,
3 changes: 3 additions & 0 deletions nvflare/app_common/executors/launcher_executor.py
@@ -42,6 +42,7 @@ def __init__(
 launch_timeout: Optional[float] = None,
 task_wait_timeout: Optional[float] = None,
 last_result_transfer_timeout: float = 300.0,
+external_execution_wait: float = 5.0,
 peer_read_timeout: Optional[float] = None,
 monitor_interval: float = 1.0,
 read_interval: float = 0.1,
@@ -95,6 +96,7 @@ def __init__(
 self._launcher_finish = False
 self._launcher_finish_time = None
 self._last_result_transfer_timeout = last_result_transfer_timeout
+self._external_execution_wait = external_execution_wait
 self._received_result = Event()
 self._job_end = False
 
@@ -245,6 +247,7 @@ def _initialize_external_execution(
 self.log_error(fl_ctx, "External execution set up failed.")
 abort_signal.trigger("External execution set up failed.")
 return False
+time.sleep(self._external_execution_wait)
 return True
 
 def _execute_launcher_method_in_thread_executor(self, method_name: str, **kwargs) -> Any:
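The new `external_execution_wait` inserts a fixed pause after external-execution setup succeeds, giving the launched process time to finish initializing before tasks start flowing. A sketch of how it might be tuned (the `launcher_id` wiring is assumed from typical NVFlare job configs, not taken from this commit):

```python
from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor

# Hypothetical wiring: give a slow-starting external trainer 10s of setup
# time before the executor begins exchanging tasks with it.
executor = ClientAPILauncherExecutor(
    launcher_id="launcher",
    external_execution_wait=10.0,
)
```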
20 changes: 13 additions & 7 deletions nvflare/app_common/utils/fl_model_utils.py
@@ -55,8 +55,9 @@ def to_shareable(fl_model: FLModel) -> Shareable:
 raise ValueError("FLModel without params and metrics is NOT supported.")
 elif fl_model.params is not None:
 if fl_model.params_type is None:
-raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")
-data_kind = params_type_to_data_kind.get(fl_model.params_type)
+fl_model.params_type = ParamsType.FULL
+
+data_kind = params_type_to_data_kind.get(fl_model.params_type.value)
 if data_kind is None:
 raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")
 
@@ -103,11 +104,15 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
 metrics = dxo.data
 else:
 params_type = data_kind_to_params_type.get(dxo.data_kind)
+params = dxo.data
 if params_type is None:
-raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
+if params is None:
+raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
+else:
+params_type = ParamsType.FULL
+
 params_type = ParamsType(params_type)
 
-params = dxo.data
 if MetaKey.INITIAL_METRICS in meta:
 metrics = meta[MetaKey.INITIAL_METRICS]
 except:
@@ -197,14 +202,15 @@ def get_configs(model: FLModel) -> Optional[dict]:
 @staticmethod
 def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel:
 if model.params_type != ParamsType.FULL:
-raise RuntimeError(
-f"params_type {model_update.params_type} of `model` not supported! Expected `ParamsType.FULL`."
-)
+raise RuntimeError(f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`.")
 
 if replace_meta:
 model.meta = model_update.meta
 else:
 model.meta.update(model_update.meta)
 
+model.metrics = model_update.metrics
+
 if model_update.params_type == ParamsType.FULL:
 model.params = model_update.params
 elif model_update.params_type == ParamsType.DIFF:
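Net effect of the `to_shareable`/`from_shareable` changes: an `FLModel` whose `params_type` was never set is now treated as `ParamsType.FULL` instead of raising. A quick round-trip sketch (the `FLModelUtils` path comes from the diff; `FLModel`'s import location is assumed):

```python
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.utils.fl_model_utils import FLModelUtils

model = FLModel(params={"w": [0.0, 1.0]})         # params_type left unset
shareable = FLModelUtils.to_shareable(model)      # now defaults to ParamsType.FULL
restored = FLModelUtils.from_shareable(shareable)
assert restored.params_type == ParamsType.FULL
```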