diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index 1b42df249ae..90799b84ed1 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -187,6 +187,7 @@ def remove_keys(self, keys: Iterable[str]): def clear_state(self): self.log_format = dict.fromkeys(self.log_record_order) + self.log_format.update(**self._build_default_keys()) @staticmethod def _build_default_keys(): diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py index 9bbc4452d98..de9de42601f 100644 --- a/tests/functional/test_logger.py +++ b/tests/functional/test_logger.py @@ -663,6 +663,44 @@ def handler(event, context): assert "my_key" not in second_log +def test_clear_state_keeps_standard_keys(lambda_context, stdout, service_name): + # GIVEN + logger = Logger(service=service_name, stream=stdout) + standard_keys = ["level", "location", "message", "timestamp", "service"] + + # WHEN clear_state is set + @logger.inject_lambda_context(clear_state=True) + def handler(event, context): + logger.info("Foo") + + # THEN all standard keys should be available as usual + handler({}, lambda_context) + handler({}, lambda_context) + + first_log, second_log = capture_multiple_logging_statements_output(stdout) + for key in standard_keys: + assert key in first_log + assert key in second_log + + +def test_clear_state_keeps_exception_keys(lambda_context, stdout, service_name): + # GIVEN + logger = Logger(service=service_name, stream=stdout) + + # WHEN clear_state is set and an exception was logged + @logger.inject_lambda_context(clear_state=True) + def handler(event, context): + try: + raise ValueError("something went wrong") + except Exception: + logger.exception("Received an exception") + + # THEN we expect a "exception_name" to be "ValueError" + handler({}, lambda_context) + log = capture_logging_output(stdout) + assert "ValueError" == log["exception_name"] + + def test_inject_lambda_context_allows_handler_with_kwargs(lambda_context, stdout, service_name): # GIVEN logger = Logger(service=service_name, stream=stdout)