diff --git a/src/py/flwr/cli/log.py b/src/py/flwr/cli/log.py index 6915de1e00c5..cd4079c1c131 100644 --- a/src/py/flwr/cli/log.py +++ b/src/py/flwr/cli/log.py @@ -26,18 +26,49 @@ from flwr.cli.config_utils import load_and_validate from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel from flwr.common.logger import log as logger +from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611 +from flwr.proto.exec_pb2_grpc import ExecStub CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds) -# pylint: disable=unused-argument -def stream_logs(run_id: int, channel: grpc.Channel, period: int) -> None: +def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None: """Stream logs from the beginning of a run with connection refresh.""" + start_time = time.time() + stub = ExecStub(channel) + req = StreamLogsRequest(run_id=run_id) + + for res in stub.StreamLogs(req): + print(res.log_output) + if time.time() - start_time > duration: + break -# pylint: disable=unused-argument def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None: """Print logs from the beginning of a run.""" + stub = ExecStub(channel) + req = StreamLogsRequest(run_id=run_id) + + try: + while True: + try: + # Enforce timeout for graceful exit + for res in stub.StreamLogs(req, timeout=timeout): + print(res.log_output) + except grpc.RpcError as e: + # pylint: disable=E1101 + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + break + if e.code() == grpc.StatusCode.NOT_FOUND: + logger(ERROR, "Invalid run_id `%s`, exiting", run_id) + break + if e.code() == grpc.StatusCode.CANCELLED: + break + except KeyboardInterrupt: + logger(DEBUG, "Stream interrupted by user") + finally: + channel.close() + logger(DEBUG, "Channel closed") def on_channel_state_change(channel_connectivity: str) -> None: diff --git a/src/py/flwr/cli/log_test.py b/src/py/flwr/cli/log_test.py new file mode 100644 index 000000000000..932610bea2f3 --- /dev/null +++ b/src/py/flwr/cli/log_test.py @@ -0,0 +1,78 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for Flower command line interface `log` command.""" + + +import unittest +from typing import NoReturn +from unittest.mock import Mock, call, patch + +from flwr.proto.exec_pb2 import StreamLogsResponse # pylint: disable=E0611 + +from .log import print_logs, stream_logs + + +class InterruptedStreamLogsResponse: + """Create a StreamLogsResponse object with KeyboardInterrupt.""" + + @property + def log_output(self) -> NoReturn: + """Raise KeyboardInterrupt to exit logstream test gracefully.""" + raise KeyboardInterrupt + + +class TestFlwrLog(unittest.TestCase): + """Unit tests for `flwr log` CLI functions.""" + + def setUp(self) -> None: + """Initialize mock ExecStub before each test.""" + self.expected_calls = [ + call("log_output_1"), + call("log_output_2"), + call("log_output_3"), + ] + mock_response_iterator = [ + iter( + [StreamLogsResponse(log_output=f"log_output_{i}") for i in range(1, 4)] + + [InterruptedStreamLogsResponse()] + ) + ] + self.mock_stub = Mock() + self.mock_stub.StreamLogs.side_effect = mock_response_iterator + self.patcher = patch("flwr.cli.log.ExecStub", return_value=self.mock_stub) + + self.patcher.start() + + # Create mock channel + self.mock_channel = Mock() + + def tearDown(self) -> None: + """Cleanup.""" + self.patcher.stop() + + def test_flwr_log_stream_method(self) -> None: + """Test stream_logs.""" + with patch("builtins.print") as mock_print: + with self.assertRaises(KeyboardInterrupt): + stream_logs(run_id=123, channel=self.mock_channel, duration=1) + # Assert that mock print was called with the expected arguments + mock_print.assert_has_calls(self.expected_calls) + + def test_flwr_log_print_method(self) -> None: + """Test print_logs.""" + with patch("builtins.print") as mock_print: + print_logs(run_id=123, channel=self.mock_channel, timeout=0) + # Assert that mock print was called with the expected arguments + mock_print.assert_has_calls(self.expected_calls)