From f6cb8ccbb9201be8de516bf62867cf4c390803f3 Mon Sep 17 00:00:00 2001 From: Holger Roth <6304754+holgerroth@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:29:59 -0400 Subject: [PATCH] Update pt params converter (#2989) * update pt params converter * use exclude_vars * print warning * add return value --- .../app_common/abstract/params_converter.py | 2 ++ nvflare/app_opt/pt/params_converter.py | 34 ++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/nvflare/app_common/abstract/params_converter.py b/nvflare/app_common/abstract/params_converter.py index 1ae611a836..31ed54b9fa 100644 --- a/nvflare/app_common/abstract/params_converter.py +++ b/nvflare/app_common/abstract/params_converter.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from abc import ABC, abstractmethod from typing import Any, List @@ -23,6 +24,7 @@ class ParamsConverter(ABC): def __init__(self, supported_tasks: List[str] = None): self.supported_tasks = supported_tasks + self.logger = logging.getLogger(self.__class__.__name__) def process(self, task_name: str, shareable: Shareable, fl_ctx: FLContext) -> Shareable: if not self.supported_tasks or task_name in self.supported_tasks: diff --git a/nvflare/app_opt/pt/params_converter.py b/nvflare/app_opt/pt/params_converter.py index d2ab2bd086..503da8e284 100644 --- a/nvflare/app_opt/pt/params_converter.py +++ b/nvflare/app_opt/pt/params_converter.py @@ -23,16 +23,42 @@ class NumpyToPTParamsConverter(ParamsConverter): def convert(self, params: Dict, fl_ctx) -> Dict: tensor_shapes = fl_ctx.get_prop("tensor_shapes") + exclude_vars = fl_ctx.get_prop("exclude_vars") + + return_params = {} if tensor_shapes: - return { + return_params = { k: torch.as_tensor(np.reshape(v, tensor_shapes[k])) if k in tensor_shapes else torch.as_tensor(v) for k, v in params.items() } else: - return {k: torch.as_tensor(v) for k, v in params.items()} + return_params = {k: torch.as_tensor(v) for k, v in params.items()} + + if exclude_vars: + for k, v in exclude_vars.items(): + return_params[k] = v + + return return_params class PTToNumpyParamsConverter(ParamsConverter): def convert(self, params: Dict, fl_ctx) -> Dict: - fl_ctx.set_prop("tensor_shapes", {k: v.shape for k, v in params.items()}) - return {k: v.cpu().numpy() for k, v in params.items()} + return_tensors = {} + tensor_shapes = {} + exclude_vars = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + return_tensors[k] = v.cpu().numpy() + tensor_shapes[k] = v.shape + else: + exclude_vars[k] = v + + if tensor_shapes: + fl_ctx.set_prop("tensor_shapes", tensor_shapes) + if exclude_vars: + fl_ctx.set_prop("exclude_vars", exclude_vars) + self.logger.warning( + f"{len(exclude_vars)} vars excluded as they were non-tensor type: " f"{list(exclude_vars.keys())}" + ) + + return return_tensors