Skip to content

Commit

Permalink
Merge branch 'main' into underscore
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Oct 30, 2024
2 parents 656ffcf + 322a132 commit b84e08c
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 111 deletions.
31 changes: 31 additions & 0 deletions src/py/flwr/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@


import logging
import sys
from logging import WARN, LogRecord
from logging.handlers import HTTPHandler
from queue import Queue
from typing import TYPE_CHECKING, Any, Optional, TextIO

# Create logger
Expand Down Expand Up @@ -259,3 +261,32 @@ def set_logger_propagation(
if not child_logger.propagate:
child_logger.log(logging.DEBUG, "Logger propagate set to False")
return child_logger


def mirror_output_to_queue(log_queue: Queue[str]) -> None:
    """Mirror stdout and stderr output to the provided queue.

    The `write` method of the stream objects currently bound to `sys.stdout`
    and `sys.stderr` is overridden on the objects themselves: every string is
    still written to (and flushed on) the underlying stream and is additionally
    put on `log_queue`. Call `restore_output` to undo the override.

    Parameters
    ----------
    log_queue : Queue[str]
        The queue receiving a copy of every string written to stdout or stderr.
    """

    def get_write_fn(stream: TextIO) -> Any:
        original_write = stream.write

        def fn(s: str) -> int:
            ret = original_write(s)
            stream.flush()
            log_queue.put(s)
            return ret

        # Tag the wrapper and remember any instance-level `write` it replaces,
        # so that `restore_output` can recognise it and put things back exactly
        fn.mirrors_output = True  # type: ignore[attr-defined]
        fn.previous_override = getattr(  # type: ignore[attr-defined]
            stream, "__dict__", {}
        ).get("write")
        return fn

    sys.stdout.write = get_write_fn(sys.stdout)  # type: ignore[method-assign]
    sys.stderr.write = get_write_fn(sys.stderr)  # type: ignore[method-assign]
    console_handler.stream = sys.stdout


def _remove_output_mirror(stream: Optional[TextIO]) -> None:
    """Remove the `write` override installed by `mirror_output_to_queue`, if any."""
    # `stream` may be `None` (`sys.__stdout__` without a console) or an object
    # without instance attributes; neither can carry a mirror
    instance_attrs = getattr(stream, "__dict__", None)
    if instance_attrs is None:
        return
    write = instance_attrs.get("write")
    if not getattr(write, "mirrors_output", False):
        return
    # Repeated `mirror_output_to_queue` calls stack wrappers: unwind all of them
    while getattr(write, "mirrors_output", False):
        write = write.previous_override
    if write is None:
        # Nothing was overridden before: fall back to the class-level `write`
        del stream.write  # type: ignore[union-attr]
    else:
        stream.write = write  # type: ignore[union-attr,method-assign]


def restore_output() -> None:
    """Restore stdout and stderr.

    This will stop mirroring output to queues.

    Rebinding `sys.stdout`/`sys.stderr` alone is not enough, because the mirror
    is installed on the stream objects themselves (which are often the very
    same objects as `sys.__stdout__`/`sys.__stderr__`). The override is
    therefore removed from every candidate stream before rebinding.
    """
    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        _remove_output_mirror(stream)
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
    console_handler.stream = sys.stdout
56 changes: 56 additions & 0 deletions src/py/flwr/common/logger_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Logger tests."""


import sys
from queue import Queue

from .logger import mirror_output_to_queue, restore_output


def test_mirror_output_to_queue() -> None:
    """Test that stdout and stderr are mirrored to the provided queue."""
    # Prepare
    log_queue: Queue[str] = Queue()

    # Execute
    mirror_output_to_queue(log_queue)
    try:
        print("Test message")
        sys.stderr.write("Error message\n")
    finally:
        # Mirroring patches process-wide streams: always undo it so that the
        # output of later tests does not keep flowing into this queue
        restore_output()

    # Assert
    # `get_nowait` fails fast with `queue.Empty` when an item is missing,
    # whereas a blocking `get` would hang the whole test run
    assert not log_queue.empty()
    assert log_queue.get_nowait() == "Test message"
    assert log_queue.get_nowait() == "\n"
    assert log_queue.get_nowait() == "Error message\n"


def test_restore_output() -> None:
    """Test that stdout and stderr are restored after calling restore_output."""
    # Prepare
    log_queue: Queue[str] = Queue()

    # Execute
    mirror_output_to_queue(log_queue)
    print("Test message before restore")
    restore_output()
    print("Test message after restore")
    sys.stderr.write("Error message after restore\n")

    # Assert
    # Only the output written before `restore_output` may be in the queue.
    # `get_nowait` raises `queue.Empty` if mirroring did not work, instead of
    # blocking forever like `get` would.
    assert log_queue.get_nowait() == "Test message before restore"
    assert log_queue.get_nowait() == "\n"
    assert log_queue.empty()
30 changes: 28 additions & 2 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

import threading
import time
from dataclasses import dataclass
from bisect import bisect_right
from dataclasses import dataclass, field
from logging import ERROR, WARNING
from typing import Optional
from uuid import UUID, uuid4
Expand All @@ -43,7 +44,7 @@


@dataclass
class RunRecord:
class RunRecord: # pylint: disable=R0902
"""The record of a specific run, including its status and timestamps."""

run: Run
Expand All @@ -52,6 +53,8 @@ class RunRecord:
starting_at: str = ""
running_at: str = ""
finished_at: str = ""
logs: list[tuple[float, str]] = field(default_factory=list)
log_lock: threading.Lock = field(default_factory=threading.Lock)


class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
Expand Down Expand Up @@ -511,3 +514,26 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None:
if run_id not in self.run_ids:
raise ValueError(f"Run {run_id} not found")
self.contexts[run_id] = context

def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Append `log_message` to the serverapp logs of the run `run_id`.

    Raises
    ------
    ValueError
        If `run_id` is not a known run.
    """
    if run_id not in self.run_ids:
        raise ValueError(f"Run {run_id} not found")
    record = self.run_ids[run_id]
    with record.log_lock:
        # Take the timestamp while holding the lock, right before appending
        entry = (now().timestamp(), log_message)
        record.logs.append(entry)

def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Get the serverapp logs for the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run whose logs are retrieved.
    after_timestamp : Optional[float]
        Only entries with a timestamp strictly greater than this value are
        returned. `None` is treated as `0.0`.

    Returns
    -------
    tuple[str, float]
        The concatenated log messages and the timestamp of the last log
        entry, or `0.0` as timestamp if no entries are returned.

    Raises
    ------
    ValueError
        If `run_id` is not a known run.
    """
    if run_id not in self.run_ids:
        raise ValueError(f"Run {run_id} not found")
    run = self.run_ids[run_id]
    if after_timestamp is None:
        after_timestamp = 0.0
    with run.log_lock:
        # Find the index where the timestamp would be inserted
        index = bisect_right(run.logs, (after_timestamp, ""))
        # The probe tuple `(after_timestamp, "")` sorts *before* entries that
        # carry exactly `after_timestamp` and a non-empty message. Skip those
        # too, otherwise a caller polling with the previously returned latest
        # timestamp would receive the same entries again.
        while index < len(run.logs) and run.logs[index][0] <= after_timestamp:
            index += 1
        latest_timestamp = run.logs[-1][0] if index < len(run.logs) else 0.0
        return "".join(log for _, log in run.logs[index:]), latest_timestamp
35 changes: 35 additions & 0 deletions src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,38 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None:
context : Context
The context to be associated with the specified `run_id`.
"""

@abc.abstractmethod
def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Add a log entry to the ServerApp logs for the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run for which to add a log entry.
    log_message : str
        The log entry to be added to the ServerApp logs.
    """

@abc.abstractmethod
def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Get the ServerApp logs for the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run for which to retrieve the ServerApp logs.
    after_timestamp : Optional[float]
        Retrieve logs after this timestamp. If set to `None`, retrieve all logs.

    Returns
    -------
    tuple[str, float]
        A tuple containing:
        - The ServerApp logs associated with the specified `run_id`.
        - The timestamp of the latest log entry in the returned logs.
          Returns `0` if no logs are returned.
    """
87 changes: 84 additions & 3 deletions src/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from unittest.mock import patch
from uuid import UUID

from flwr.common import DEFAULT_TTL, Context, RecordSet
from flwr.common import DEFAULT_TTL, Context, RecordSet, now
from flwr.common.constant import ErrorCode, Status, SubStatus
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
generate_key_pairs,
Expand Down Expand Up @@ -1038,6 +1038,87 @@ def test_set_context_invalid_run_id(self) -> None:
with self.assertRaises(ValueError):
state.set_serverapp_context(61016, context) # Invalid run_id

def test_add_serverapp_log_invalid_run_id(self) -> None:
    """Adding a serverapp log for an unknown run must raise `ValueError`."""
    # Prepare
    link_state: LinkState = self.state_factory()
    unknown_run_id = 99999

    # Execute and assert
    with self.assertRaises(ValueError):
        link_state.add_serverapp_log(unknown_run_id, "Invalid log entry")

def test_get_serverapp_log_invalid_run_id(self) -> None:
    """Retrieving serverapp logs for an unknown run must raise `ValueError`."""
    # Prepare
    link_state: LinkState = self.state_factory()
    unknown_run_id = 99999

    # Execute and assert
    with self.assertRaises(ValueError):
        link_state.get_serverapp_log(unknown_run_id, after_timestamp=None)

def test_add_and_get_serverapp_log(self) -> None:
    """Logs added after a reference timestamp are returned concatenated."""
    # Prepare
    state: LinkState = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {})
    entries = ["Log entry 1", "Log entry 2"]
    timestamp = now().timestamp()

    # Execute
    for entry in entries:
        state.add_serverapp_log(run_id, entry)
    result = state.get_serverapp_log(run_id, after_timestamp=timestamp)
    retrieved_logs, latest = result

    # Assert
    assert latest > timestamp
    assert "".join(entries) == retrieved_logs

def test_get_serverapp_log_after_timestamp(self) -> None:
    """Only logs added after the given timestamp are returned."""
    # Prepare
    state: LinkState = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {})
    older_entry = "Log entry 1"
    newer_entry = "Log entry 2"
    state.add_serverapp_log(run_id, older_entry)
    # Reference point taken between the two entries
    timestamp = now().timestamp()
    state.add_serverapp_log(run_id, newer_entry)

    # Execute
    result = state.get_serverapp_log(run_id, after_timestamp=timestamp)
    retrieved_logs, latest = result

    # Assert
    assert latest > timestamp
    assert older_entry not in retrieved_logs
    assert newer_entry == retrieved_logs

def test_get_serverapp_log_after_timestamp_no_logs(self) -> None:
    """An empty result is returned when no log is newer than the timestamp."""
    # Prepare
    state: LinkState = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {})
    state.add_serverapp_log(run_id, "Log entry")
    # Reference point taken after the only entry
    timestamp = now().timestamp()

    # Execute
    result = state.get_serverapp_log(run_id, after_timestamp=timestamp)
    retrieved_logs, latest = result

    # Assert
    assert latest == 0
    assert retrieved_logs == ""


def create_task_ins(
consumer_node_id: int,
Expand Down Expand Up @@ -1123,7 +1204,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 15
assert len(result) == 17


class SqliteFileBasedTest(StateTest, unittest.TestCase):
Expand All @@ -1148,7 +1229,7 @@ def test_initialize(self) -> None:
result = state.query("SELECT name FROM sqlite_schema;")

# Assert
assert len(result) == 15
assert len(result) == 17


if __name__ == "__main__":
Expand Down
50 changes: 50 additions & 0 deletions src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@
);
"""

# Schema of the `logs` table: one row of log text per
# (timestamp, run_id, node_id), with `run_id` referencing the `run` table.
# NOTE(review): the timestamp is part of the primary key, so two rows for the
# same run and node with an identical timestamp cannot both be stored.
SQL_CREATE_TABLE_LOGS = """
CREATE TABLE IF NOT EXISTS logs (
timestamp REAL,
run_id INTEGER,
node_id INTEGER,
log TEXT,
PRIMARY KEY (timestamp, run_id, node_id),
FOREIGN KEY (run_id) REFERENCES run(run_id)
);
"""

SQL_CREATE_TABLE_CONTEXT = """
CREATE TABLE IF NOT EXISTS context(
run_id INTEGER UNIQUE,
Expand Down Expand Up @@ -191,6 +202,7 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]:

# Create each table if not exists queries
cur.execute(SQL_CREATE_TABLE_RUN)
cur.execute(SQL_CREATE_TABLE_LOGS)
cur.execute(SQL_CREATE_TABLE_CONTEXT)
cur.execute(SQL_CREATE_TABLE_TASK_INS)
cur.execute(SQL_CREATE_TABLE_TASK_RES)
Expand Down Expand Up @@ -1015,6 +1027,44 @@ def set_serverapp_context(self, run_id: int, context: Context) -> None:
except sqlite3.IntegrityError:
raise ValueError(f"Run {run_id} not found") from None

def add_serverapp_log(self, run_id: int, log_message: str) -> None:
    """Insert `log_message` into the `logs` table for the run `run_id`.

    The row is stored with the current timestamp and `node_id` 0.

    Raises
    ------
    ValueError
        If the insert fails with `sqlite3.IntegrityError`.
    """
    # SQLite integers are signed: convert the uint64 run ID to sint64
    sint64_run_id = convert_uint64_to_sint64(run_id)

    insert_stmt = """
INSERT INTO logs (timestamp, run_id, node_id, log) VALUES (?, ?, ?, ?);
"""
    row = (now().timestamp(), sint64_run_id, 0, log_message)

    # Store log
    # NOTE(review): an `IntegrityError` is reported as "run not found", but a
    # second entry with an identical timestamp for the same run would
    # presumably violate the table's primary key as well - confirm.
    try:
        self.query(insert_stmt, row)
    except sqlite3.IntegrityError:
        raise ValueError(f"Run {run_id} not found") from None

def get_serverapp_log(
    self, run_id: int, after_timestamp: Optional[float]
) -> tuple[str, float]:
    """Get the ServerApp logs for the specified `run_id`.

    Parameters
    ----------
    run_id : int
        The identifier of the run whose logs are retrieved.
    after_timestamp : Optional[float]
        Only rows with a timestamp strictly greater than this value are
        returned. `None` is treated as `0.0`.

    Returns
    -------
    tuple[str, float]
        The log messages concatenated in timestamp order and the timestamp
        of the last returned row, or `0.0` if no rows are returned.

    Raises
    ------
    ValueError
        If no run with `run_id` exists in the `run` table.
    """
    # Convert the uint64 value to sint64 for SQLite
    sint64_run_id = convert_uint64_to_sint64(run_id)

    # Check if the run_id exists
    query = "SELECT run_id FROM run WHERE run_id = ?;"
    if not self.query(query, (sint64_run_id,)):
        raise ValueError(f"Run {run_id} not found")

    # Retrieve logs (ServerApp logs are stored with `node_id` 0); let SQLite
    # order the rows instead of sorting the result list in Python
    if after_timestamp is None:
        after_timestamp = 0.0
    query = """
        SELECT log, timestamp FROM logs
        WHERE run_id = ? AND node_id = ? AND timestamp > ?
        ORDER BY timestamp;
    """
    rows = self.query(query, (sint64_run_id, 0, after_timestamp))
    latest_timestamp = rows[-1]["timestamp"] if rows else 0.0
    return "".join(row["log"] for row in rows), latest_timestamp

def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
"""Check if the TaskIns exists and is valid (not expired).
Expand Down
Loading

0 comments on commit b84e08c

Please sign in to comment.