From 67c73f898d4f7b9d25aa0738043ddfc6e3d7a2b6 Mon Sep 17 00:00:00 2001
From: YuanTingHsieh
Date: Thu, 20 Jul 2023 11:41:28 -0700
Subject: [PATCH] Add Lightning module API

---
 nvflare/lightning/__init__.py |  15 ++++
 nvflare/lightning/module.py   | 138 ++++++++++++++++++++++++++++++++++
 2 files changed, 153 insertions(+)
 create mode 100644 nvflare/lightning/__init__.py
 create mode 100644 nvflare/lightning/module.py

diff --git a/nvflare/lightning/__init__.py b/nvflare/lightning/__init__.py
new file mode 100644
index 0000000000..8d081690d5
--- /dev/null
+++ b/nvflare/lightning/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .module import LightningModule as LightningModule
diff --git a/nvflare/lightning/module.py b/nvflare/lightning/module.py
new file mode 100644
index 0000000000..a23dacb004
--- /dev/null
+++ b/nvflare/lightning/module.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import functools
+import json
+from typing import Any
+
+import pytorch_lightning as pl
+
+import nvflare.client as flare
+
+
+def unflatten(global_weights):
+    """Unflattens the params received from NVFlare into a nested dict."""
+    result = {}
+    for var_name in global_weights:
+        _var_name_split = var_name.split(".")
+        encoder_key = _var_name_split[0]
+        if encoder_key not in result:
+            result[encoder_key] = {}
+        local_var_name = ".".join(_var_name_split[1:])
+        result[encoder_key][local_var_name] = global_weights[var_name]
+    return result
+
+
+def flatten(params: dict):
+    """Flattens the params from NeMo."""
+    # Turn the nested dict into a single-level dict supported by ModelPersistor and Aggregator
+    state_dict = {}
+    for encoder_key, prompt_state_dict in params.items():
+        for k, v in prompt_state_dict.items():
+            state_dict[f"{encoder_key}.{k}"] = v.detach().cpu()
+    return state_dict
+
+
+class LightningModule(pl.LightningModule):
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.fl_model = None
+
+    def get_fl_model(self):
+        return self.clone()
+
+    def clone(self):
+        # deep-copy self, then load fl_model; a shallow copy would share parameter tensors with self
+        new_module = copy.deepcopy(self)
+        if self.fl_model is not None:
+            new_module.load_state_dict(self.fl_model)
+        return new_module
+
+    def on_train_start(self):
+        super().on_train_start()
+        print("\n *****nvflare****** on_train_start ********** \n")
+        self._fl_train_start()
+
+    def on_train_end(self):
+        super().on_train_end()
+        print("\n *****nvflare****** on_train_end ********** \n")
+        self._fl_train_end()
+
+    def _fl_init(self):
+        config_file = "nvf_lightning.json"
+        config = {
+            "exchange_path": "./",
+            "exchange_format": "pytorch",
+            "params_type": "FULL",
+        }
+        with open(config_file, "w") as f:
+            json.dump(config, f)
+        flare.init(config=config_file)
+
+    def _fl_train_start(self):
+        print("ZZZZZ calling _fl_train_start ZZZZZ")
+        model, metadata = flare.receive_model()
+        if model:
+            print("ZZZZZ receiving model ZZZZZ")
+            weights = unflatten(model)
+            self.fl_model = weights
+            self.load_state_dict(weights)
+        print("ZZZZZ ending _fl_train_start ZZZZZ")
+
+    def _fl_train_end(self):
+        print("ZZZZZ calling _fl_train_end ZZZZZ")
+        weights = flatten(self.state_dict())
+        flare.submit_model(weights)
+        print("ZZZZZ ending _fl_train_end ZZZZZ")
+
+    @staticmethod
+    def fit_start(_func):
+        """Decorator that runs _fl_init() before the wrapped method."""
+
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(self, *args, **kwargs):
+                self._fl_init()
+                return func(self, *args, **kwargs)
+
+            return wrapper
+        return decorator(_func)
+
+    @staticmethod
+    def train_start(_func):
+        """Decorator that runs _fl_train_start() before the wrapped method."""
+
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(self, *args, **kwargs):
+                self._fl_train_start()
+                return func(self, *args, **kwargs)
+
+            return wrapper
+        return decorator(_func)
+
+    @staticmethod
+    def train_end(_func):
+        """Decorator that runs _fl_train_end() after the wrapped method."""
+
+        def decorator(func):
+            @functools.wraps(func)
+            def wrapper(self, *args, **kwargs):
+                r = func(self, *args, **kwargs)
+                self._fl_train_end()
+                return r
+
+            return wrapper
+        return decorator(_func)
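
Usage note (not part of the patch): a minimal sketch of how the API above is
meant to be wired up. The model class LitNet, its single linear layer, and the
random dataset are illustrative placeholders; only nvflare.lightning.LightningModule
and its fit_start decorator come from this patch. The base class already overrides
on_train_start/on_train_end to receive the global model and submit the local
result, so a subclass mainly needs to trigger _fl_init() once, e.g. by decorating
a Lightning hook such as on_fit_start. Note that flatten()/unflatten() assume a
NeMo-style nested state dict, so the flat LitNet here only illustrates the wiring,
and flare.init() requires a running NVFlare client environment.

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from nvflare.lightning import LightningModule


    class LitNet(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        @LightningModule.fit_start  # writes the exchange config and calls flare.init()
        def on_fit_start(self):
            return super().on_fit_start()


    # Placeholder data so the sketch is self-contained.
    data = TensorDataset(torch.randn(32, 28 * 28), torch.randint(0, 10, (32,)))
    pl.Trainer(max_epochs=1).fit(LitNet(), DataLoader(data, batch_size=8))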