Skip to content

Commit

Permalink
Add persistent db module and recovery logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudnoize committed Dec 30, 2024
1 parent c280f10 commit 727bcf2
Show file tree
Hide file tree
Showing 4 changed files with 394 additions and 25 deletions.
95 changes: 79 additions & 16 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import openfl.callbacks as callbacks_module
from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.databases import TensorDB
from openfl.databases import PersistentTensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -137,6 +139,8 @@ def __init__(
self.quit_job_sent_to = []

self.tensor_db = TensorDB()
# E.L: use the db path from the configuration, if one exists
self.persistent_db = PersistentTensorDB()
# FIXME: I think next line generates an error on the second round
# if it is set to 1 for the aggregator.
self.db_store_rounds = db_store_rounds
Expand All @@ -154,8 +158,21 @@ def __init__(
# TODO: Remove. Used in deprecated interactive and native APIs
self.best_tensor_dict: dict = {}
self.last_tensor_dict: dict = {}
# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}
self.collaborator_task_weight = {} # {TaskResultKey: data_size}


if initial_tensor_dict:
# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []
# Initialize a lock for thread safety
self.lock = Lock()
self.use_delta_updates = use_delta_updates

if self._recover():
print("recovered state of aggregator")
elif initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
tensor_dict=initial_tensor_dict,
Expand All @@ -168,20 +185,6 @@ def __init__(

self.collaborator_tensor_results = {} # {TensorKey: nparray}}

# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}

self.collaborator_task_weight = {} # {TaskResultKey: data_size}

# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []

# Initialize a lock for thread safety
self.lock = Lock()

self.use_delta_updates = use_delta_updates

# Callbacks
self.callbacks = callbacks_module.CallbackList(
callbacks,
Expand All @@ -194,6 +197,38 @@ def __init__(
# https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)

def _recover(self):
    """Restore aggregator state from the persistent database, if any exists.

    Steps, in order:
      1. Stop early when the persistent task table is empty.
      2. Load the persisted tensors and cache them in ``self.tensor_db``.
      3. Restore ``self.best_model_score`` and set ``self.round_number`` to
         the committed round + 1 (the round that was still in progress).
      4. Replay every saved task result, starting at task id 1 and stopping
         at the first id for which the database returns nothing, by feeding
         it back through ``process_task_results``.

    Returns:
        bool: ``False`` when there was nothing to recover, ``True`` once the
            stored state has been restored. The caller relies on the
            truthiness of this value to decide whether the initial tensors
            still need to be loaded, so the success path must not fall
            through and return ``None``.
    """
    if self.persistent_db.is_task_table_empty():
        return False

    # NOTE(review): other code in this module logs through the module-level
    # `logger`; confirm that `self.logger` is set on the aggregator.
    self.logger.info("Recovering previous state from persistent storage")

    # Load the persisted tensors into the in-memory tensor db.
    tensor_key_dict = self.persistent_db.load_tensors()
    if len(tensor_key_dict) > 0:
        self.tensor_db.cache_tensor(tensor_key_dict)
        self.logger.info(
            "Recovery - this is the tensor_db after recovery: %s", self.tensor_db
        )

    committed_round_number, self.best_model_score = (
        self.persistent_db.get_round_and_best_score()
    )
    # The current round is the one still in process,
    # i.e. committed_round_number + 1.
    self.round_number = committed_round_number + 1
    self.logger.info(
        "Recovery - loaded round number %s and best score %s",
        self.round_number,
        self.best_model_score,
    )

    self.logger.info("Recovery - Replaying saved task results")
    task_id = 1
    while True:
        task_result = self.persistent_db.get_task_result_by_id(task_id)
        if not task_result:
            # No entry for this id: every saved result has been replayed.
            break
        collaborator_name = task_result["collaborator_name"]
        round_number = task_result["round_number"]
        task_name = task_result["task_name"]
        data_size = task_result["data_size"]
        # Tensors were stored serialized; rebuild the protobuf messages.
        named_tensors = [
            NamedTensor.FromString(tensor_string)
            for tensor_string in task_result["named_tensors"]
        ]
        self.logger.info(
            "Recovery - Replaying task results %s %s %s",
            collaborator_name,
            round_number,
            task_name,
        )
        # Replay through process_task_results (not send_local_task_results)
        # so the result is not written to the persistent db a second time.
        self.process_task_results(
            collaborator_name, round_number, task_name, data_size, named_tensors
        )
        task_id += 1

    return True

def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.
Expand Down Expand Up @@ -255,18 +290,31 @@ def _save_model(self, round_number, file_path):
for k, v in og_tensor_dict.items()
]
tensor_dict = {}
tensor_tuple_dict = {}
for tk in tensor_keys:
tk_name, _, _, _, _ = tk
tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
print(f"transaction - tk is {tk} len {len(tk)} tk_name is {tk_name}")
tensor_value = self.tensor_db.get_tensor_from_cache(tk)
tensor_dict[tk_name] = tensor_value
tensor_tuple_dict[tk] = tensor_value
if tensor_dict[tk_name] is None:
logger.info(
"Cannot save model for round %s. Continuing...",
round_number,
)
return
# E.L: the tensor_dict could also be saved here, as part of a transaction.
# The proto save could then be omitted: at the end of the experiment, write the last and best
# model tensors as proto and clean up the whole db.
if file_path == self.best_state_path:
self.best_tensor_dict = tensor_dict
if file_path == self.last_state_path:
# Transaction to persist/delete all data needed to increment the round
self.persistent_db.finalize_round(tensor_tuple_dict,self.round_number,self.best_model_score)
self.logger.info(
"Persist model and cleaned task result for round %s",
round_number,
)
self.last_tensor_dict = tensor_dict
self.model = utils.construct_model_proto(
tensor_dict, round_number, self.compression_pipeline
Expand Down Expand Up @@ -606,6 +654,19 @@ def send_local_task_results(
Returns:
None
"""
# Save task and its metadata for recovery
serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
self.persistent_db.save_task_results(collaborator_name,round_number,task_name,data_size,serialized_tensors)
self.process_task_results(collaborator_name,round_number,task_name,data_size,named_tensors)

def process_task_results(
self,
collaborator_name,
round_number,
task_name,
data_size,
named_tensors,
):
if self._time_to_quit() or collaborator_name in self.stragglers:
logger.warning(
f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
Expand Down Expand Up @@ -990,6 +1051,7 @@ def _end_of_round_check(self):

self.round_number += 1
# resetting stragglers for task for a new round
#E.L should it be saved?
self.stragglers = []
# resetting collaborators_done for next round
self.collaborators_done = []
Expand All @@ -1005,6 +1067,7 @@ def _end_of_round_check(self):
# Cleaning tensor db
self.tensor_db.clean_up(self.db_store_rounds)
# Reset straggler handling policy for the next round.
#E.L should it be saved?
self.straggler_handling_policy.reset_policy_for_round()

def _is_collaborator_done(self, collaborator_name: str, round_number: int) -> None:
Expand Down
24 changes: 15 additions & 9 deletions openfl/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,17 +409,23 @@ def get_data_for_tensorkey(self, tensor_key):
creates_model=True,
)
self.tensor_db.cache_tensor({new_model_tk: nparray})
else:
logger.info(
"Count not find previous model layer."
"Fetching latest layer from aggregator"
)
# The original model tensor should be fetched from client
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
elif "model" in tags:
# Pulling the model for the first time
on purpose
logger.info("Getting model from aggregator key %s ", tensor_key)
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
else:

# The original model tensor should be fetched from client
tensor_name, origin, round_number, report, tags = tensor_key
tags = (self.collaborator_name,) + tags
tensor_key = (tensor_name, origin, round_number, report, tags)
logger.info(
f"Couldnt not find previous model layer."
"Fetching latest layer from aggregator {tensor_key}"
)
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
Expand Down
1 change: 1 addition & 0 deletions openfl/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@


from openfl.databases.tensor_db import TensorDB
from openfl.databases.persistent_db import PersistentTensorDB
Loading

0 comments on commit 727bcf2

Please sign in to comment.