From 5c567c74afb48358b8dfead4f35497d5ef5c1dfe Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 3 Dec 2024 14:21:25 -0500 Subject: [PATCH] update decomposer --- nvflare/app_opt/pt/decomposers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nvflare/app_opt/pt/decomposers.py b/nvflare/app_opt/pt/decomposers.py index 0696f3cd0c..a7f071f33d 100644 --- a/nvflare/app_opt/pt/decomposers.py +++ b/nvflare/app_opt/pt/decomposers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ def __init__(self, tensor): super().__init__() self.register_buffer("saved_tensor", tensor) + class TensorDecomposer(fobs.Decomposer): def supported_type(self): return torch.Tensor @@ -38,7 +39,6 @@ def decompose(self, target: torch.Tensor, manager: DatumManager = None) -> Any: return self._numpy_serialize(target) def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: - if isinstance(data, dict): if data["dtype"] == "torch.bfloat16": return self._jit_deserialize(data) @@ -52,8 +52,7 @@ def recompose(self, data: Any, manager: DatumManager = None) -> torch.Tensor: @staticmethod def _numpy_serialize(tensor: torch.Tensor) -> dict: stream = BytesIO() - - # torch.save uses Pickle so converting Tensor to ndarray first + # supported ScalarType, use numpy to avoid Pickle array = tensor.detach().cpu().numpy() np.save(stream, array, allow_pickle=False) return { @@ -69,8 +68,9 @@ def _numpy_deserialize(data: Any) -> torch.Tensor: @staticmethod def _jit_serialize(tensor: torch.Tensor) -> dict: - module = SerializationModule(tensor) stream = BytesIO() + # unsupported ScalarType by numpy, use torch.jit to avoid Pickle + module = SerializationModule(tensor) torch.jit.save(torch.jit.script(module), stream) return { "buffer": stream.getvalue(),