diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py
index eaac9fa6a0..26838bcf84 100644
--- a/openfl/component/aggregator/aggregator.py
+++ b/openfl/component/aggregator/aggregator.py
@@ -13,9 +13,11 @@
 import openfl.callbacks as callbacks_module
 from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
 from openfl.databases import TensorDB
+from openfl.databases import PersistentTensorDB
 from openfl.interface.aggregation_functions import WeightedAverage
 from openfl.pipelines import NoCompressionPipeline, TensorCodec
 from openfl.protocols import base_pb2, utils
+from openfl.protocols.base_pb2 import NamedTensor
 from openfl.utilities import TaskResultKey, TensorKey, change_tags
 
 logger = logging.getLogger(__name__)
@@ -137,6 +139,8 @@ def __init__(
         self.quit_job_sent_to = []
 
         self.tensor_db = TensorDB()
+        # TODO: Take the DB path from configuration if one is provided.
+        self.persistent_db = PersistentTensorDB()
         # FIXME: I think next line generates an error on the second round
         # if it is set to 1 for the aggregator.
         self.db_store_rounds = db_store_rounds
@@ -154,8 +158,21 @@ def __init__(
         # TODO: Remove. Used in deprecated interactive and native APIs
         self.best_tensor_dict: dict = {}
         self.last_tensor_dict: dict = {}
+        # these enable getting all tensors for a task
+        self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
+        self.collaborator_task_weight = {}  # {TaskResultKey: data_size}
+
-        if initial_tensor_dict:
+        # maintain a list of collaborators that have completed task and
+        # reported results in a given round
+        self.collaborators_done = []
+        # Initialize a lock for thread safety
+        self.lock = Lock()
+        self.use_delta_updates = use_delta_updates
+
+        if self._recover():
+            logger.info("Recovered state of the aggregator")
+        elif initial_tensor_dict:
             self._load_initial_tensors_from_dict(initial_tensor_dict)
             self.model = utils.construct_model_proto(
                 tensor_dict=initial_tensor_dict,
@@ -168,20 +185,6 @@ def __init__(
 
         self.collaborator_tensor_results = {}  # {TensorKey: nparray}}
 
-        # these enable getting all tensors for a task
-        self.collaborator_tasks_results = {}  # {TaskResultKey: list of TensorKeys}
-
-        self.collaborator_task_weight = {}  # {TaskResultKey: data_size}
-
-        # maintain a list of collaborators that have completed task and
-        # reported results in a given round
-        self.collaborators_done = []
-
-        # Initialize a lock for thread safety
-        self.lock = Lock()
-
-        self.use_delta_updates = use_delta_updates
-
         # Callbacks
         self.callbacks = callbacks_module.CallbackList(
             callbacks,
@@ -194,6 +197,38 @@ def __init__(
         # https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
         self.callbacks.on_experiment_begin()
         self.callbacks.on_round_begin(self.round_number)
+
+    def _recover(self):
+        """Recover the aggregator state from persistent storage.
+
+        Returns:
+            bool: True if a previous state was found and restored.
+        """
+        if self.persistent_db.is_task_table_empty():
+            return False
+        # Load the stored tensors back into the in-memory tensor DB.
+        logger.info("Recovering previous state from persistent storage")
+        tensor_key_dict = self.persistent_db.load_tensors()
+        if len(tensor_key_dict) > 0:
+            self.tensor_db.cache_tensor(tensor_key_dict)
+            logger.info("Recovery - tensor_db after recovery: %s", self.tensor_db)
+        committed_round_number, self.best_model_score = (
+            self.persistent_db.get_round_and_best_score()
+        )
+        # The round in progress is the one after the last committed round.
+        self.round_number = committed_round_number + 1
+        logger.info(
+            "Recovery - loaded round number %s and best score %s",
+            self.round_number,
+            self.best_model_score,
+        )
+        logger.info("Recovery - Replaying saved task results")
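+        # Task results are replayed in insertion order (IDs auto-increment from 1),
+        # rebuilding the in-memory aggregation state of the in-progress round.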
+        task_id = 1
+        while True:
+            task_result = self.persistent_db.get_task_result_by_id(task_id)
+            if not task_result:
+                break
+            collaborator_name = task_result["collaborator_name"]
+            round_number = task_result["round_number"]
+            task_name = task_result["task_name"]
+            data_size = task_result["data_size"]
+            serialized_tensors = task_result["named_tensors"]
+            named_tensors = [
+                NamedTensor.FromString(tensor_string)
+                for tensor_string in serialized_tensors
+            ]
+            logger.info(
+                "Recovery - Replaying task results %s %s %s",
+                collaborator_name,
+                round_number,
+                task_name,
+            )
+            self.process_task_results(
+                collaborator_name, round_number, task_name, data_size, named_tensors
+            )
+            task_id += 1
+        return True
 
     def _load_initial_tensors(self):
         """Load all of the tensors required to begin federated learning.
@@ -255,18 +290,31 @@ def _save_model(self, round_number, file_path):
             for k, v in og_tensor_dict.items()
         ]
         tensor_dict = {}
+        tensor_tuple_dict = {}
         for tk in tensor_keys:
             tk_name, _, _, _, _ = tk
-            tensor_dict[tk_name] = self.tensor_db.get_tensor_from_cache(tk)
+            logger.debug("Save model - tensor key %s, tensor name %s", tk, tk_name)
+            tensor_value = self.tensor_db.get_tensor_from_cache(tk)
+            tensor_dict[tk_name] = tensor_value
+            tensor_tuple_dict[tk] = tensor_value
             if tensor_dict[tk_name] is None:
                 logger.info(
                     "Cannot save model for round %s. Continuing...",
                     round_number,
                 )
                 return
+        # TODO: The tensor_dict could be persisted here as well, as part of the
+        # transaction. The proto save could then be omitted, writing the last and
+        # best model tensors as proto only at the end of the experiment and
+        # cleaning the whole DB afterwards.
         if file_path == self.best_state_path:
             self.best_tensor_dict = tensor_dict
         if file_path == self.last_state_path:
+            # Transaction that persists/deletes all data needed to increment the round
+            self.persistent_db.finalize_round(
+                tensor_tuple_dict, self.round_number, self.best_model_score
+            )
+            logger.info(
+                "Persisted model and cleaned up task results for round %s",
+                round_number,
+            )
             self.last_tensor_dict = tensor_dict
         self.model = utils.construct_model_proto(
             tensor_dict, round_number, self.compression_pipeline
@@ -606,6 +654,19 @@ def send_local_task_results(
         Returns:
             None
         """
+        # Save the task result and its metadata so it can be replayed on recovery.
+        serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
+        self.persistent_db.save_task_results(
+            collaborator_name, round_number, task_name, data_size, serialized_tensors
+        )
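+        # The result is persisted before it is processed; if the aggregator
+        # crashes after this point, _recover() replays the stored result.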
+        self.process_task_results(
+            collaborator_name, round_number, task_name, data_size, named_tensors
+        )
+
+    def process_task_results(
+        self,
+        collaborator_name,
+        round_number,
+        task_name,
+        data_size,
+        named_tensors,
+    ):
+        """Process a task result reported by a collaborator."""
         if self._time_to_quit() or collaborator_name in self.stragglers:
             logger.warning(
                 f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
@@ -990,6 +1051,7 @@ def _end_of_round_check(self):
         self.round_number += 1
 
         # resetting stragglers for task for a new round
+        # TODO: Should the straggler list be persisted as well?
         self.stragglers = []
         # resetting collaborators_done for next round
         self.collaborators_done = []
@@ -1005,6 +1067,7 @@ def _end_of_round_check(self):
         # Cleaning tensor db
         self.tensor_db.clean_up(self.db_store_rounds)
         # Reset straggler handling policy for the next round.
+        # TODO: Should the straggler policy state be persisted as well?
         self.straggler_handling_policy.reset_policy_for_round()
 
     def _is_collaborator_done(self, collaborator_name: str, round_number: int) -> None:
diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py
index d4fd380998..fdeb0be15e 100644
--- a/openfl/component/collaborator/collaborator.py
+++ b/openfl/component/collaborator/collaborator.py
@@ -409,17 +409,22 @@ def get_data_for_tensorkey(self, tensor_key):
                     creates_model=True,
                 )
                 self.tensor_db.cache_tensor({new_model_tk: nparray})
-            else:
-                logger.info(
-                    "Count not find previous model layer."
-                    "Fetching latest layer from aggregator"
-                )
-                # The original model tensor should be fetched from client
-                nparray = self.get_aggregated_tensor_from_aggregator(
-                    tensor_key, require_lossless=True
-                )
         elif "model" in tags:
             # Pulling the model for the first time
+            logger.info("Getting model from aggregator, key %s", tensor_key)
+            nparray = self.get_aggregated_tensor_from_aggregator(
+                tensor_key, require_lossless=True
+            )
+        else:
+            # The original model tensor should be fetched from the collaborator,
+            # so add the collaborator name to the tags.
+            tensor_name, origin, round_number, report, tags = tensor_key
+            tags = (self.collaborator_name,) + tags
+            tensor_key = (tensor_name, origin, round_number, report, tags)
+            logger.info(
+                "Could not find previous model layer. "
+                "Fetching latest layer from aggregator: %s",
+                tensor_key,
+            )
             nparray = self.get_aggregated_tensor_from_aggregator(
                 tensor_key, require_lossless=True
             )
diff --git a/openfl/databases/__init__.py b/openfl/databases/__init__.py
index 849fcde7c9..49a35039b1 100644
--- a/openfl/databases/__init__.py
+++ b/openfl/databases/__init__.py
@@ -3,3 +3,4 @@
 
 from openfl.databases.tensor_db import TensorDB
+from openfl.databases.persistent_db import PersistentTensorDB
diff --git a/openfl/databases/persistent_db.py b/openfl/databases/persistent_db.py
new file mode 100644
index 0000000000..05afe610cd
--- /dev/null
+++ b/openfl/databases/persistent_db.py
@@ -0,0 +1,299 @@
+import json
+import logging
+import pickle
+import sqlite3
+from threading import Lock
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+
+from openfl.utilities import TensorKey
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["PersistentTensorDB"]
+
+
+class PersistentTensorDB:
+    """A persistent database for tensors and task results, backed by SQLite.
+
+    Attributes:
+        conn: The SQLite connection object.
+        cursor: The SQLite cursor object.
+        mutex: A threading Lock used to ensure thread-safe operations.
+    """
+
+    def __init__(self, db_file: str = "tensordb.sqlite") -> None:
+        """Initializes a new instance of the PersistentTensorDB class."""
+        self.conn = sqlite3.connect(db_file, check_same_thread=False)
+        self.cursor = self.conn.cursor()
+        self.mutex = Lock()
+        self._create_model_tensors_table()
+        self._create_task_results_table()
+        self._create_key_value_store()
+
+    def _create_model_tensors_table(self) -> None:
+        """Create the model tensors table if it does not exist."""
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS tensors (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                tensor_name TEXT,
+                origin TEXT,
+                round INTEGER,
+                report INTEGER,
+                tags TEXT,
+                nparray BLOB
+            )
+        """)
+        self.conn.commit()
+
+    def _create_task_results_table(self) -> None:
+        """Create the task results table if it does not exist."""
+        create_table_query = """
+            CREATE TABLE IF NOT EXISTS task_results (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                collaborator_name TEXT NOT NULL,
+                round_number INTEGER NOT NULL,
+                task_name TEXT NOT NULL,
+                data_size INTEGER NOT NULL,
+                named_tensors BLOB NOT NULL
+            );
+        """
+        self.cursor.execute(create_table_query)
+        self.conn.commit()
+
+    def _create_key_value_store(self) -> None:
+        """Create the key-value store for the round number and best score."""
+        self.cursor.execute("""
+            CREATE TABLE IF NOT EXISTS key_value_store (
+                key TEXT PRIMARY KEY,
+                value REAL
+            )
+        """)
+        self.conn.commit()
+
+    def init_task_results_table(self):
+        """Reset the task results table, dropping it first if it already exists."""
+        self._init_task_results_table()
+        self.conn.commit()
+
+    def save_task_results(
+        self,
+        collaborator_name: str,
+        round_number: int,
+        task_name: str,
+        data_size: int,
+        serialized_tensors,
+    ):
+        """Save a task result to the task_results table.
+
+        Args:
+            collaborator_name (str): Collaborator name.
+            round_number (int): Round number.
+            task_name (str): Task name.
+            data_size (int): Data size.
+            serialized_tensors: List of named tensors, each already serialized
+                to protobuf binary.
+        """
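+        # The tensors arrive already protobuf-serialized; pickle the list so it
+        # can be stored as a single BLOB per task result.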
+ """ + + def __init__(self, db_file: str = "tensordb.sqlite") -> None: + """Initializes a new instance of the PersistentTensorDB class.""" + self.conn = sqlite3.connect(db_file, check_same_thread=False) + self.cursor = self.conn.cursor() + self.mutex = Lock() + self._create_model_tensors_table() + self._create_task_results_table() + self._create_key_value_Store() + + def _create_model_tensors_table(self) -> None: + """Create the database schema if it does not exist.""" + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS tensors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tensor_name TEXT, + origin TEXT, + round INTEGER, + report INTEGER, + tags TEXT, + nparray BLOB + ) + """) + self.conn.commit() + + def _create_task_results_table(self) -> None: + """Creates a table for storing task results.""" + create_table_query = """ + CREATE TABLE IF NOT EXISTS task_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + collaborator_name TEXT NOT NULL, + round_number INTEGER NOT NULL, + task_name TEXT NOT NULL, + data_size INTEGER NOT NULL, + named_tensors BLOB NOT NULL + ); + """ + self.cursor.execute(create_table_query) + self.conn.commit() + + def _create_key_value_Store(self) -> None: + self.cursor.execute(""" + CREATE TABLE IF NOT EXISTS key_value_store ( + key TEXT PRIMARY KEY, + value REAL + ) + """) + self.conn.commit() + + def init_task_results_table(self): + """ + Creates a table for storing task results. Drops the table first if it already exists. + """ + drop_table_query = "DROP TABLE IF EXISTS task_results" + self.cursor.execute(drop_table_query) + + create_table_query = """ + CREATE TABLE IF NOT EXISTS task_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + collaborator_name TEXT NOT NULL, + round_number INTEGER NOT NULL, + task_name TEXT NOT NULL, + data_size INTEGER NOT NULL, + named_tensors BLOB NOT NULL + ); + """ + self.cursor.execute(create_table_query) + self.conn.commit() + + def save_task_results( + self, + collaborator_name: str, + round_number: int, + task_name: str, + data_size: int, + serialized_tensors, + ): + """ + Saves task results to the task_results table. + + Args: + collaborator_name (str): Collaborator name. + round_number (int): Round number. + task_name (str): Task name. + data_size (int): Data size. + named_tensors: Named tensors (e.g., protobuf, stored as serialized binary). + """ + serialized_blob = pickle.dumps(serialized_tensors) + + + # Insert into the database + insert_query = """ + INSERT INTO task_results + (collaborator_name, round_number, task_name, data_size, named_tensors) + VALUES (?, ?, ?, ?, ?); + """ + self.cursor.execute( + insert_query, + (collaborator_name, round_number, task_name, data_size, serialized_blob), + ) + print(f"Saves task result for collaborator {collaborator_name} round number {round_number} taks_name {task_name} data_size {data_size}") + self.conn.commit() + + def get_task_result_by_id(self, task_result_id: int): + """ + Retrieve a task result by its ID. + + Args: + task_result_id (int): The ID of the task result to retrieve. + + Returns: + A dictionary containing the task result details, or None if not found. + """ + with self.mutex: + self.cursor.execute(""" + SELECT collaborator_name, round_number, task_name, data_size, named_tensors + FROM task_results + WHERE id = ? 
+ """, (task_result_id,)) + result = self.cursor.fetchone() + if result: + collaborator_name, round_number, task_name, data_size, serialized_blob = result + serialized_tensors = pickle.loads(serialized_blob) + return { + "collaborator_name": collaborator_name, + "round_number": round_number, + "task_name": task_name, + "data_size": data_size, + "named_tensors": serialized_tensors + } + return None + + def _serialize_array(self, array: np.ndarray) -> bytes: + """Serialize a NumPy array into bytes for storing in SQLite.""" + return array.tobytes() + + def _deserialize_array(self, blob: bytes, dtype: Optional[np.dtype] = None) -> np.ndarray: + """Deserialize bytes from SQLite into a NumPy array.""" + return np.frombuffer(blob, dtype=dtype) + + def __repr__(self) -> str: + """Returns a string representation of the PersistentTensorDB.""" + with self.mutex: + self.cursor.execute("SELECT tensor_name, origin, round, report, tags FROM tensors") + rows = self.cursor.fetchall() + return f"PersistentTensorDB contents:\n{rows}" + + def finalize_round(self,tensor_key_dict: Dict[TensorKey, np.ndarray],round_number: int, best_score: float): + with self.mutex: + try: + # Begin transaction + self.cursor.execute("BEGIN TRANSACTION") + self._persist_tensors(tensor_key_dict) + self._init_task_results_table() + self._save_round_and_best_score(round_number,best_score) + # Commit transaction + self.conn.commit() + print(f"Committed model for round {round_number}, saved {len(tensor_key_dict)} model tensors with best_score {best_score}") + except Exception as e: + # Rollback transaction in case of an error + self.conn.rollback() + raise RuntimeError(f"Failed to finalize round: {e}") + + def _persist_tensors(self, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None: + """Insert a dictionary of tensors into the SQLite database in a single transaction.""" + for tensor_key, nparray in tensor_key_dict.items(): + tensor_name, origin, fl_round, report, tags = tensor_key + serialized_array = self._serialize_array(nparray) + serialized_tags = json.dumps(tags) + self.cursor.execute(""" + INSERT INTO tensors (tensor_name, origin, round, report, tags, nparray) + VALUES (?, ?, ?, ?, ?, ?) + """, (tensor_name, origin, fl_round, int(report), serialized_tags, serialized_array)) + print(f"transaction - Saved tensor: {tensor_name}, origin: {origin}, round: {fl_round}, report: {report}, tags: {serialized_tags}") + + def _init_task_results_table(self): + """ + Creates a table for storing task results. Drops the table first if it already exists. + """ + drop_table_query = "DROP TABLE IF EXISTS task_results" + self.cursor.execute(drop_table_query) + + create_table_query = """ + CREATE TABLE IF NOT EXISTS task_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + collaborator_name TEXT NOT NULL, + round_number INTEGER NOT NULL, + task_name TEXT NOT NULL, + data_size INTEGER NOT NULL, + named_tensors BLOB NOT NULL + ); + """ + self.cursor.execute(create_table_query) + + def _save_round_and_best_score(self, round_number: int, best_score: float) -> None: + """Save the round number and best score as key-value pairs in the database.""" + # Create a table with key-value structure where values can be integer or float + # Insert or update the round_number + self.cursor.execute(""" + INSERT OR REPLACE INTO key_value_store (key, value) + VALUES (?, ?) + """, ("round_number", float(round_number))) + + # Insert or update the best_score + self.cursor.execute(""" + INSERT OR REPLACE INTO key_value_store (key, value) + VALUES (?, ?) 
+ """, ("best_score", float(best_score))) + + + + def load_tensors(self) -> Dict[TensorKey, np.ndarray]: + """Load all tensors from the SQLite database and return them as a dictionary.""" + tensor_dict = {} + with self.mutex: + self.cursor.execute("SELECT tensor_name, origin, round, report, tags, nparray FROM tensors") + rows = self.cursor.fetchall() + for row in rows: + tensor_name, origin, fl_round, report, tags, nparray = row + # Deserialize the JSON string back to a Python list + deserialized_tags = tuple(json.loads(tags)) + tensor_key = TensorKey(tensor_name, origin, fl_round, report, deserialized_tags) + tensor_dict[tensor_key] = self._deserialize_array(nparray) + return tensor_dict + + + def get_round_and_best_score(self) -> tuple[int, float]: + """Retrieve the round number and best score from the database.""" + with self.mutex: + # Fetch the round_number + self.cursor.execute(""" + SELECT value FROM key_value_store WHERE key = ? + """, ("round_number",)) + round_number = self.cursor.fetchone() + if round_number is None: + round_number = -1 + else: + round_number = int(round_number[0]) # Cast to int + + # Fetch the best_score + self.cursor.execute(""" + SELECT value FROM key_value_store WHERE key = ? + """, ("best_score",)) + best_score = self.cursor.fetchone() + if best_score is None: + best_score = 0 + else: + best_score = float(best_score[0]) # Cast to float + return round_number, best_score + + + def clean_up(self, remove_older_than: int = 1) -> None: + """Remove old entries from the database.""" + if remove_older_than < 0: + return + with self.mutex: + self.cursor.execute("SELECT MAX(round) FROM tensors") + current_round = self.cursor.fetchone()[0] + if current_round is None: + return + self.cursor.execute(""" + DELETE FROM tensors + WHERE round <= ? AND report = 0 + """, (current_round - remove_older_than,)) + self.conn.commit() + + + def close(self) -> None: + """Close the SQLite database connection.""" + self.conn.close() + + def is_task_table_empty(self) -> bool: + """Check if the task table is empty.""" + with self.mutex: + self.cursor.execute("SELECT COUNT(*) FROM task_results") + count = self.cursor.fetchone()[0] + return count == 0