From 7fc4261d9d8565d95d14849634fa1db8861ef254 Mon Sep 17 00:00:00 2001 From: Javier Date: Tue, 29 Oct 2024 20:58:56 +0000 Subject: [PATCH 1/3] refactor(framework) Remove `RunTracker` and logs thread from `Executor` (#4391) Co-authored-by: Heng Pan --- src/py/flwr/superexec/deployment.py | 6 +- src/py/flwr/superexec/exec_servicer.py | 77 ++------------------- src/py/flwr/superexec/exec_servicer_test.py | 22 +----- src/py/flwr/superexec/executor.py | 7 +- src/py/flwr/superexec/simulation.py | 11 ++- 5 files changed, 17 insertions(+), 106 deletions(-) diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index f13cd2f8ea20..dce25a0f3e69 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -28,7 +28,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate import LinkState, LinkStateFactory -from .executor import Executor, RunTracker +from .executor import Executor class DeploymentEngine(Executor): @@ -141,7 +141,7 @@ def start_run( fab_file: bytes, override_config: UserConfig, federation_config: UserConfig, - ) -> Optional[RunTracker]: + ) -> Optional[int]: """Start run using the Flower Deployment Engine.""" try: @@ -151,7 +151,7 @@ def start_run( ) log(INFO, "Created run %s", str(run_id)) - return None + return run_id # pylint: disable-next=broad-except except Exception as e: log(ERROR, "Could not start run: %s", str(e)) diff --git a/src/py/flwr/superexec/exec_servicer.py b/src/py/flwr/superexec/exec_servicer.py index 14c1a3548047..a8d11b51f4d5 100644 --- a/src/py/flwr/superexec/exec_servicer.py +++ b/src/py/flwr/superexec/exec_servicer.py @@ -15,10 +15,6 @@ """SuperExec API servicer.""" -import select -import sys -import threading -import time from collections.abc import Generator from logging import ERROR, INFO from typing import Any @@ -37,7 +33,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate import LinkStateFactory -from .executor import Executor, RunTracker +from .executor import Executor SELECT_TIMEOUT = 1 # Timeout for selecting ready-to-read file descriptors (in seconds) @@ -55,7 +51,6 @@ def __init__( self.ffs_factory = ffs_factory self.executor = executor self.executor.initialize(linkstate_factory, ffs_factory) - self.runs: dict[int, RunTracker] = {} def StartRun( self, request: StartRunRequest, context: grpc.ServicerContext @@ -63,25 +58,17 @@ def StartRun( """Create run ID.""" log(INFO, "ExecServicer.StartRun") - run = self.executor.start_run( + run_id = self.executor.start_run( request.fab.content, user_config_from_proto(request.override_config), user_config_from_proto(request.federation_config), ) - if run is None: + if run_id is None: log(ERROR, "Executor failed to start run") return StartRunResponse() - self.runs[run.run_id] = run - - # Start a background thread to capture the log output - capture_thread = threading.Thread( - target=_capture_logs, args=(run,), daemon=True - ) - capture_thread.start() - - return StartRunResponse(run_id=run.run_id) + return StartRunResponse(run_id=run_id) def StreamLogs( # pylint: disable=C0103 self, request: StreamLogsRequest, context: grpc.ServicerContext @@ -89,58 +76,4 @@ def StreamLogs( # pylint: disable=C0103 """Get logs.""" log(INFO, "ExecServicer.StreamLogs") - # Exit if `run_id` not found - if request.run_id not in self.runs: - context.abort(grpc.StatusCode.NOT_FOUND, "Run ID not found") - - last_sent_index = 0 - while context.is_active(): - # Yield n'th row of logs, if n'th row < len(logs) - logs = self.runs[request.run_id].logs - for i in range(last_sent_index, len(logs)): - yield StreamLogsResponse(log_output=logs[i]) - last_sent_index = len(logs) - - # Wait for and continue to yield more log responses only if the - # run isn't completed yet. If the run is finished, the entire log - # is returned at this point and the server ends the stream. - if self.runs[request.run_id].proc.poll() is not None: - log(INFO, "All logs for run ID `%s` returned", request.run_id) - context.set_code(grpc.StatusCode.OK) - context.cancel() - - time.sleep(1.0) # Sleep briefly to avoid busy waiting - - -def _capture_logs( - run: RunTracker, -) -> None: - while True: - # Explicitly check if Popen.poll() is None. Required for `pytest`. - if run.proc.poll() is None: - # Select streams only when ready to read - ready_to_read, _, _ = select.select( - [run.proc.stdout, run.proc.stderr], - [], - [], - SELECT_TIMEOUT, - ) - # Read from std* and append to RunTracker.logs - for stream in ready_to_read: - # Flush stdout to view output in real time - readline = stream.readline() - sys.stdout.write(readline) - sys.stdout.flush() - # Append to logs - line = readline.rstrip() - if line: - run.logs.append(f"{line}") - - # Close std* to prevent blocking - elif run.proc.poll() is not None: - log(INFO, "Subprocess finished, exiting log capture") - if run.proc.stdout: - run.proc.stdout.close() - if run.proc.stderr: - run.proc.stderr.close() - break + yield StreamLogsResponse() diff --git a/src/py/flwr/superexec/exec_servicer_test.py b/src/py/flwr/superexec/exec_servicer_test.py index 6044895de3cf..3b50200d22f2 100644 --- a/src/py/flwr/superexec/exec_servicer_test.py +++ b/src/py/flwr/superexec/exec_servicer_test.py @@ -20,7 +20,7 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611 -from .exec_servicer import ExecServicer, _capture_logs +from .exec_servicer import ExecServicer def test_start_run() -> None: @@ -36,7 +36,7 @@ def test_start_run() -> None: run_res.proc = proc executor = MagicMock() - executor.start_run = lambda _, __, ___: run_res + executor.start_run = lambda _, __, ___: run_res.run_id context_mock = MagicMock() @@ -48,22 +48,4 @@ def test_start_run() -> None: # Execute response = servicer.StartRun(request, context_mock) - assert response.run_id == 10 - - -def test_capture_logs() -> None: - """Test capture_logs function.""" - run_res = Mock() - run_res.logs = [] - with subprocess.Popen( - ["echo", "success"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) as proc: - run_res.proc = proc - _capture_logs(run_res) - - assert len(run_res.logs) == 1 - assert run_res.logs[0] == "success" diff --git a/src/py/flwr/superexec/executor.py b/src/py/flwr/superexec/executor.py index a36e1dec0fd2..fd87b0d742be 100644 --- a/src/py/flwr/superexec/executor.py +++ b/src/py/flwr/superexec/executor.py @@ -72,7 +72,7 @@ def start_run( fab_file: bytes, override_config: UserConfig, federation_config: UserConfig, - ) -> Optional[RunTracker]: + ) -> Optional[int]: """Start a run using the given Flower FAB ID and version. This method creates a new run on the SuperLink, returns its run_id @@ -89,7 +89,6 @@ def start_run( Returns ------- - run_id : Optional[RunTracker] - The run_id and the associated process of the run created by the SuperLink, - or `None` if it fails. + run_id : Optional[int] + The run_id of the run created by the SuperLink, or `None` if it fails. """ diff --git a/src/py/flwr/superexec/simulation.py b/src/py/flwr/superexec/simulation.py index 83ea0d0681a1..3941b0c98bc6 100644 --- a/src/py/flwr/superexec/simulation.py +++ b/src/py/flwr/superexec/simulation.py @@ -33,7 +33,7 @@ from flwr.server.superlink.linkstate import LinkStateFactory from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes -from .executor import Executor, RunTracker +from .executor import Executor def _user_config_to_str(user_config: UserConfig) -> str: @@ -125,7 +125,7 @@ def start_run( fab_file: bytes, override_config: UserConfig, federation_config: UserConfig, - ) -> Optional[RunTracker]: + ) -> Optional[int]: """Start run using the Flower Simulation Engine.""" if self.num_supernodes is None: raise ValueError( @@ -199,17 +199,14 @@ def start_run( command.extend(["--run-config", f"{override_config_str}"]) # Start Simulation - proc = subprocess.Popen( # pylint: disable=consider-using-with + _ = subprocess.Popen( # pylint: disable=consider-using-with command, text=True, ) log(INFO, "Started run %s", str(run_id)) - return RunTracker( - run_id=run_id, - proc=proc, - ) + return run_id # pylint: disable-next=broad-except except Exception as e: From fa787e675b064efbbf9475775ce8fb93ac5381ed Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 29 Oct 2024 21:08:48 +0000 Subject: [PATCH 2/3] feat(framework) Add `add_serverapp_log` and `get_serverapp_log` to `LinkState` (#4389) --- .../linkstate/in_memory_linkstate.py | 30 ++++++- .../server/superlink/linkstate/linkstate.py | 35 ++++++++ .../superlink/linkstate/linkstate_test.py | 87 ++++++++++++++++++- .../superlink/linkstate/sqlite_linkstate.py | 50 +++++++++++ 4 files changed, 197 insertions(+), 5 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index c616dafa8951..0830c26fc49c 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -17,7 +17,8 @@ import threading import time -from dataclasses import dataclass +from bisect import bisect_right +from dataclasses import dataclass, field from logging import ERROR, WARNING from typing import Optional from uuid import UUID, uuid4 @@ -43,7 +44,7 @@ @dataclass -class RunRecord: +class RunRecord: # pylint: disable=R0902 """The record of a specific run, including its status and timestamps.""" run: Run @@ -52,6 +53,8 @@ class RunRecord: starting_at: str = "" running_at: str = "" finished_at: str = "" + logs: list[tuple[float, str]] = field(default_factory=list) + log_lock: threading.Lock = field(default_factory=threading.Lock) class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904 @@ -511,3 +514,26 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None: if run_id not in self.run_ids: raise ValueError(f"Run {run_id} not found") self.contexts[run_id] = context + + def add_serverapp_log(self, run_id: int, log_message: str) -> None: + """Add a log entry to the serverapp logs for the specified `run_id`.""" + if run_id not in self.run_ids: + raise ValueError(f"Run {run_id} not found") + run = self.run_ids[run_id] + with run.log_lock: + run.logs.append((now().timestamp(), log_message)) + + def get_serverapp_log( + self, run_id: int, after_timestamp: Optional[float] + ) -> tuple[str, float]: + """Get the serverapp logs for the specified `run_id`.""" + if run_id not in self.run_ids: + raise ValueError(f"Run {run_id} not found") + run = self.run_ids[run_id] + if after_timestamp is None: + after_timestamp = 0.0 + with run.log_lock: + # Find the index where the timestamp would be inserted + index = bisect_right(run.logs, (after_timestamp, "")) + latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0 + return "".join(log for _, log in run.logs[index:]), latest_timestamp diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 0ca9180f4e39..a40ea0fc139c 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -299,3 +299,38 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None: context : Context The context to be associated with the specified `run_id`. """ + + @abc.abstractmethod + def add_serverapp_log(self, run_id: int, log_message: str) -> None: + """Add a log entry to the ServerApp logs for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run for which to add a log entry. + log_message : str + The log entry to be added to the ServerApp logs. + """ + + @abc.abstractmethod + def get_serverapp_log( + self, run_id: int, after_timestamp: Optional[float] + ) -> tuple[str, float]: + """Get the ServerApp logs for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run for which to retrieve the ServerApp logs. + + after_timestamp : Optional[float] + Retrieve logs after this timestamp. If set to `None`, retrieve all logs. + + Returns + ------- + tuple[str, float] + A tuple containing: + - The ServerApp logs associated with the specified `run_id`. + - The timestamp of the latest log entry in the returned logs. + Returns `0` if no logs are returned. + """ diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index d29358a24825..1fc21bf02a2a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -23,7 +23,7 @@ from unittest.mock import patch from uuid import UUID -from flwr.common import DEFAULT_TTL, Context, RecordSet +from flwr.common import DEFAULT_TTL, Context, RecordSet, now from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, @@ -1038,6 +1038,87 @@ def test_set_context_invalid_run_id(self) -> None: with self.assertRaises(ValueError): state.set_serverapp_context(61016, context) # Invalid run_id + def test_add_serverapp_log_invalid_run_id(self) -> None: + """Test adding serverapp log with invalid run_id.""" + # Prepare + state: LinkState = self.state_factory() + invalid_run_id = 99999 + log_entry = "Invalid log entry" + + # Execute and assert + with self.assertRaises(ValueError): + state.add_serverapp_log(invalid_run_id, log_entry) + + def test_get_serverapp_log_invalid_run_id(self) -> None: + """Test retrieving serverapp log with invalid run_id.""" + # Prepare + state: LinkState = self.state_factory() + invalid_run_id = 99999 + + # Execute and assert + with self.assertRaises(ValueError): + state.get_serverapp_log(invalid_run_id, after_timestamp=None) + + def test_add_and_get_serverapp_log(self) -> None: + """Test adding and retrieving serverapp logs.""" + # Prepare + state: LinkState = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + log_entry_1 = "Log entry 1" + log_entry_2 = "Log entry 2" + timestamp = now().timestamp() + + # Execute + state.add_serverapp_log(run_id, log_entry_1) + state.add_serverapp_log(run_id, log_entry_2) + retrieved_logs, latest = state.get_serverapp_log( + run_id, after_timestamp=timestamp + ) + + # Assert + assert latest > timestamp + assert log_entry_1 + log_entry_2 == retrieved_logs + + def test_get_serverapp_log_after_timestamp(self) -> None: + """Test retrieving serverapp logs after a specific timestamp.""" + # Prepare + state: LinkState = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + log_entry_1 = "Log entry 1" + log_entry_2 = "Log entry 2" + state.add_serverapp_log(run_id, log_entry_1) + timestamp = now().timestamp() + state.add_serverapp_log(run_id, log_entry_2) + + # Execute + retrieved_logs, latest = state.get_serverapp_log( + run_id, after_timestamp=timestamp + ) + + # Assert + assert latest > timestamp + assert log_entry_1 not in retrieved_logs + assert log_entry_2 == retrieved_logs + + def test_get_serverapp_log_after_timestamp_no_logs(self) -> None: + """Test retrieving serverapp logs after a specific timestamp but no logs are + found.""" + # Prepare + state: LinkState = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + log_entry = "Log entry" + state.add_serverapp_log(run_id, log_entry) + timestamp = now().timestamp() + + # Execute + retrieved_logs, latest = state.get_serverapp_log( + run_id, after_timestamp=timestamp + ) + + # Assert + assert latest == 0 + assert retrieved_logs == "" + def create_task_ins( consumer_node_id: int, @@ -1123,7 +1204,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 15 + assert len(result) == 17 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -1148,7 +1229,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 15 + assert len(result) == 17 if __name__ == "__main__": diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index 89d00528fa56..a931e747fac1 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -99,6 +99,17 @@ ); """ +SQL_CREATE_TABLE_LOGS = """ +CREATE TABLE IF NOT EXISTS logs ( + timestamp REAL, + run_id INTEGER, + node_id INTEGER, + log TEXT, + PRIMARY KEY (timestamp, run_id, node_id), + FOREIGN KEY (run_id) REFERENCES run(run_id) +); +""" + SQL_CREATE_TABLE_CONTEXT = """ CREATE TABLE IF NOT EXISTS context( run_id INTEGER UNIQUE, @@ -191,6 +202,7 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]: # Create each table if not exists queries cur.execute(SQL_CREATE_TABLE_RUN) + cur.execute(SQL_CREATE_TABLE_LOGS) cur.execute(SQL_CREATE_TABLE_CONTEXT) cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) @@ -1015,6 +1027,44 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None: except sqlite3.IntegrityError: raise ValueError(f"Run {run_id} not found") from None + def add_serverapp_log(self, run_id: int, log_message: str) -> None: + """Add a log entry to the ServerApp logs for the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + + # Store log + try: + query = """ + INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?); + """ + self.query(query, (now().timestamp(), sint64_run_id, 0, log_message)) + except sqlite3.IntegrityError: + raise ValueError(f"Run {run_id} not found") from None + + def get_serverapp_log( + self, run_id: int, after_timestamp: Optional[float] + ) -> tuple[str, float]: + """Get the ServerApp logs for the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + + # Check if the run_id exists + query = "SELECT run_id FROM run WHERE run_id = ?;" + if not self.query(query, (sint64_run_id,)): + raise ValueError(f"Run {run_id} not found") + + # Retrieve logs + if after_timestamp is None: + after_timestamp = 0.0 + query = """ + SELECT log, timestamp FROM logs + WHERE run_id = ? AND node_id = ? AND timestamp > ?; + """ + rows = self.query(query, (sint64_run_id, 0, after_timestamp)) + rows.sort(key=lambda x: x["timestamp"]) + latest_timestamp = rows[-1]["timestamp"] if rows else 0.0 + return "".join(row["log"] for row in rows), latest_timestamp + def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]: """Check if the TaskIns exists and is valid (not expired). From 322a132924b6353bfb58b6c6029c419ea057fbf1 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 29 Oct 2024 21:51:52 +0000 Subject: [PATCH 3/3] feat(framework) Add utility functions for capturing stdout/stderr (#4392) --- src/py/flwr/common/logger.py | 31 +++++++++++++++++ src/py/flwr/common/logger_test.py | 56 +++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 src/py/flwr/common/logger_test.py diff --git a/src/py/flwr/common/logger.py b/src/py/flwr/common/logger.py index 3a058abac9c6..58edce76c718 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -16,8 +16,10 @@ import logging +import sys from logging import WARN, LogRecord from logging.handlers import HTTPHandler +from queue import Queue from typing import TYPE_CHECKING, Any, Optional, TextIO # Create logger @@ -259,3 +261,32 @@ def set_logger_propagation( if not child_logger.propagate: child_logger.log(logging.DEBUG, "Logger propagate set to False") return child_logger + + +def mirror_output_to_queue(log_queue: Queue[str]) -> None: + """Mirror stdout and stderr output to the provided queue.""" + + def get_write_fn(stream: TextIO) -> Any: + original_write = stream.write + + def fn(s: str) -> int: + ret = original_write(s) + stream.flush() + log_queue.put(s) + return ret + + return fn + + sys.stdout.write = get_write_fn(sys.stdout) # type: ignore[method-assign] + sys.stderr.write = get_write_fn(sys.stderr) # type: ignore[method-assign] + console_handler.stream = sys.stdout + + +def restore_output() -> None: + """Restore stdout and stderr. + + This will stop mirroring output to queues. + """ + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + console_handler.stream = sys.stdout diff --git a/src/py/flwr/common/logger_test.py b/src/py/flwr/common/logger_test.py new file mode 100644 index 000000000000..a566d5b35e1a --- /dev/null +++ b/src/py/flwr/common/logger_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower Logger tests.""" + + +import sys +from queue import Queue + +from .logger import mirror_output_to_queue, restore_output + + +def test_mirror_output_to_queue() -> None: + """Test that stdout and stderr are mirrored to the provided queue.""" + # Prepare + log_queue: Queue[str] = Queue() + + # Execute + mirror_output_to_queue(log_queue) + print("Test message") + sys.stderr.write("Error message\n") + + # Assert + assert not log_queue.empty() + assert log_queue.get() == "Test message" + assert log_queue.get() == "\n" + assert log_queue.get() == "Error message\n" + + +def test_restore_output() -> None: + """Test that stdout and stderr are restored after calling restore_output.""" + # Prepare + log_queue: Queue[str] = Queue() + + # Execute + mirror_output_to_queue(log_queue) + print("Test message before restore") + restore_output() + print("Test message after restore") + sys.stderr.write("Error message after restore\n") + + # Assert + assert log_queue.get() == "Test message before restore" + assert log_queue.get() == "\n" + assert log_queue.empty()