diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index cb3bb397348..bb0d9865f2c 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -3,13 +3,13 @@ from typing import Any -def json_formatter(unserialized_value: Any): - """JSON custom serializer to cast unserialisable values to strings. +def json_formatter(unserializable_value: Any): + """JSON custom serializer to cast unserializable values to strings. Example ------- - **Serialize unserialisable value to string** + **Serialize unserializable value to string** class X: pass value = {"x": X()} @@ -18,10 +18,10 @@ class X: pass Parameters ---------- - unserialized_value: Any + unserializable_value: Any Python object unserializable by JSON """ - return str(unserialized_value) + return str(unserializable_value) class JsonFormatter(logging.Formatter): @@ -39,11 +39,12 @@ def __init__(self, **kwargs): """Return a JsonFormatter instance. The `json_default` kwarg is used to specify a formatter for otherwise - unserialisable values. It must not throw. Defaults to a function that + unserializable values. It must not throw. Defaults to a function that coerces the value to a string. Other kwargs are used to specify log field format strings. """ + self.default_json_formatter = kwargs.pop("json_default", json_formatter) datefmt = kwargs.pop("datefmt", None) super(JsonFormatter, self).__init__(datefmt=datefmt) @@ -54,7 +55,6 @@ def __init__(self, **kwargs): "location": "%(funcName)s:%(lineno)d", } self.format_dict.update(kwargs) - self.default_json_formatter = kwargs.pop("json_default", json_formatter) def update_formatter(self, **kwargs): self.format_dict.update(kwargs) @@ -64,6 +64,7 @@ def format(self, record): # noqa: A003 record_dict["asctime"] = self.formatTime(record, self.datefmt) log_dict = {} + for key, value in self.format_dict.items(): if value and key in self.reserved_keys: # converts default logging expr to its record value @@ -84,19 +85,13 @@ def format(self, record): # noqa: A003 except (json.decoder.JSONDecodeError, TypeError, ValueError): pass - if record.exc_info: + if record.exc_info and not record.exc_text: # Cache the traceback text to avoid converting it multiple times # (it's constant anyway) # from logging.Formatter:format - if not record.exc_text: # pragma: no cover - record.exc_text = self.formatException(record.exc_info) + record.exc_text = self.formatException(record.exc_info) if record.exc_text: log_dict["exception"] = record.exc_text - json_record = json.dumps(log_dict, default=self.default_json_formatter) - - if hasattr(json_record, "decode"): # pragma: no cover - json_record = json_record.decode("utf-8") - - return json_record + return json.dumps(log_dict, default=self.default_json_formatter) diff --git a/tests/functional/test_aws_lambda_logging.py b/tests/functional/test_aws_lambda_logging.py index cf4782d1d2a..86c6124cb1f 100644 --- a/tests/functional/test_aws_lambda_logging.py +++ b/tests/functional/test_aws_lambda_logging.py @@ -38,16 +38,17 @@ def test_setup_with_valid_log_levels(stdout, level): def test_logging_exception_traceback(stdout): - logger = Logger(level="DEBUG", stream=stdout, request_id="request id!", another="value") + logger = Logger(level="DEBUG", stream=stdout) try: - raise Exception("Boom") - except Exception: - logger.exception("This is a test") + raise ValueError("Boom") + except ValueError: + logger.exception("A value error occurred") log_dict = json.loads(stdout.getvalue()) check_log_dict(log_dict) + assert "ERROR" == log_dict["level"] assert "exception" in log_dict @@ -86,15 +87,32 @@ def test_with_json_message(stdout): assert msg == log_dict["message"] -def test_with_unserialisable_value_in_message(stdout): +def test_with_unserializable_value_in_message(stdout): logger = Logger(level="DEBUG", stream=stdout) - class X: + class Unserializable: pass - msg = {"x": X()} + msg = {"x": Unserializable()} logger.debug(msg) log_dict = json.loads(stdout.getvalue()) assert log_dict["message"]["x"].startswith("<") + + +def test_with_unserializable_value_in_message_custom(stdout): + class Unserializable: + pass + + # GIVEN a custom json_default + logger = Logger(level="DEBUG", stream=stdout, json_default=lambda o: f"") + + # WHEN we log a message + logger.debug({"x": Unserializable()}) + + log_dict = json.loads(stdout.getvalue()) + + # THEN json_default should not be in the log message and the custom unserializable handler should be used + assert log_dict["message"]["x"] == "" + assert "json_default" not in log_dict