From bd4040474b106bacc5f915027d9c8e1aa150c149 Mon Sep 17 00:00:00 2001 From: Naman Trivedi Date: Wed, 31 Jul 2024 13:57:40 +0000 Subject: [PATCH] Raise all init errors in init instead of suppressing them until the fist invoke --- awslambdaric/bootstrap.py | 31 +++++++++----------- tests/test_bootstrap.py | 60 ++++++--------------------------------- 2 files changed, 21 insertions(+), 70 deletions(-) diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index e737b7b..60aa216 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -37,36 +37,30 @@ def _get_handler(handler): try: (modname, fname) = handler.rsplit(".", 1) except ValueError as e: - fault = FaultException( + raise FaultException( FaultException.MALFORMED_HANDLER_NAME, "Bad handler '{}': {}".format(handler, str(e)), ) - return make_fault_handler(fault) try: if modname.split(".")[0] in sys.builtin_module_names: - fault = FaultException( + raise FaultException( FaultException.BUILT_IN_MODULE_CONFLICT, "Cannot use built-in module {} as a handler module".format(modname), ) - return make_fault_handler(fault) m = importlib.import_module(modname.replace("/", ".")) except ImportError as e: - fault = FaultException( + raise FaultException( FaultException.IMPORT_MODULE_ERROR, "Unable to import module '{}': {}".format(modname, str(e)), ) - request_handler = make_fault_handler(fault) - return request_handler except SyntaxError as e: trace = [' File "%s" Line %s\n %s' % (e.filename, e.lineno, e.text)] - fault = FaultException( + raise FaultException( FaultException.USER_CODE_SYNTAX_ERROR, "Syntax error in module '{}': {}".format(modname, str(e)), trace, ) - request_handler = make_fault_handler(fault) - return request_handler try: request_handler = getattr(m, fname) @@ -76,15 +70,8 @@ def _get_handler(handler): "Handler '{}' missing on module '{}'".format(fname, modname), None, ) - request_handler = make_fault_handler(fault) - return request_handler - - -def make_fault_handler(fault): - def result(*args): raise fault - - return result + return request_handler def make_error( @@ -475,15 +462,23 @@ def run(app_root, handler, lambda_runtime_api_addr): lambda_runtime_client = LambdaRuntimeClient( lambda_runtime_api_addr, use_thread_for_polling_next ) + error_result = None try: _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink) global _GLOBAL_AWS_REQUEST_ID request_handler = _get_handler(handler) + except FaultException as e: + error_result = make_error( + e.msg, + e.exception_type, + e.trace, + ) except Exception: error_result = build_fault_result(sys.exc_info(), None) + if error_result is not None: log_error(error_result, log_sink) lambda_runtime_client.post_init_error(to_json(error_result)) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index fd56d9f..7bc2ad2 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -603,43 +603,6 @@ def raise_exception_handler(json_input, lambda_context): self.assertEqual(mock_stdout.getvalue(), error_logs) - # The order of patches matter. Using MagicMock resets sys.stdout to the default. - @patch("importlib.import_module") - @patch("sys.stdout", new_callable=StringIO) - def test_handle_event_request_fault_exception_logging_syntax_error( - self, mock_stdout, mock_import_module - ): - try: - eval("-") - except SyntaxError as e: - syntax_error = e - - mock_import_module.side_effect = syntax_error - - response_handler = bootstrap._get_handler("a.b") - - bootstrap.handle_event_request( - self.lambda_runtime, - response_handler, - "invoke_id", - self.event_body, - "application/json", - {}, - {}, - "invoked_function_arn", - 0, - bootstrap.StandardLogSink(), - ) - error_logs = ( - lambda_unhandled_exception_warning_message - + f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r" - ) - error_logs += "Traceback (most recent call last):\r" - error_logs += '  File "" Line 1\r' - error_logs += "    -\n" - - self.assertEqual(mock_stdout.getvalue(), error_logs) - class TestXrayFault(unittest.TestCase): def test_make_xray(self): @@ -717,10 +680,8 @@ def __eq__(self, other): def test_get_event_handler_bad_handler(self): handler_name = "bad_handler" - response_handler = bootstrap._get_handler(handler_name) with self.assertRaises(FaultException) as cm: - response_handler() - + response_handler = bootstrap._get_handler(handler_name) returned_exception = cm.exception self.assertEqual( self.FaultExceptionMatcher( @@ -732,9 +693,8 @@ def test_get_event_handler_bad_handler(self): def test_get_event_handler_import_error(self): handler_name = "no_module.handler" - response_handler = bootstrap._get_handler(handler_name) with self.assertRaises(FaultException) as cm: - response_handler() + response_handler = bootstrap._get_handler(handler_name) returned_exception = cm.exception self.assertEqual( self.FaultExceptionMatcher( @@ -757,10 +717,9 @@ def test_get_event_handler_syntax_error(self): filename_w_ext = os.path.basename(tmp_file.name) filename, _ = os.path.splitext(filename_w_ext) handler_name = "{}.syntax_error".format(filename) - response_handler = bootstrap._get_handler(handler_name) with self.assertRaises(FaultException) as cm: - response_handler() + response_handler = bootstrap._get_handler(handler_name) returned_exception = cm.exception self.assertEqual( self.FaultExceptionMatcher( @@ -782,9 +741,8 @@ def test_get_event_handler_missing_error(self): filename_w_ext = os.path.basename(tmp_file.name) filename, _ = os.path.splitext(filename_w_ext) handler_name = "{}.my_handler".format(filename) - response_handler = bootstrap._get_handler(handler_name) with self.assertRaises(FaultException) as cm: - response_handler() + response_handler = bootstrap._get_handler(handler_name) returned_exception = cm.exception self.assertEqual( self.FaultExceptionMatcher( @@ -801,9 +759,8 @@ def test_get_event_handler_slash(self): response_handler() def test_get_event_handler_build_in_conflict(self): - response_handler = bootstrap._get_handler("sys.hello") with self.assertRaises(FaultException) as cm: - response_handler() + response_handler = bootstrap._get_handler("sys.hello") returned_exception = cm.exception self.assertEqual( self.FaultExceptionMatcher( @@ -1452,9 +1409,8 @@ def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout): class TestBootstrapModule(unittest.TestCase): - @patch("awslambdaric.bootstrap.handle_event_request") @patch("awslambdaric.bootstrap.LambdaRuntimeClient") - def test_run(self, mock_runtime_client, mock_handle_event_request): + def test_run(self, mock_runtime_client): expected_app_root = "/tmp/test/app_root" expected_handler = "app.my_test_handler" expected_lambda_runtime_api_addr = "test_addr" @@ -1467,12 +1423,12 @@ def test_run(self, mock_runtime_client, mock_handle_event_request): MagicMock(), ] - with self.assertRaises(TypeError): + with self.assertRaises(SystemExit) as cm: bootstrap.run( expected_app_root, expected_handler, expected_lambda_runtime_api_addr ) - mock_handle_event_request.assert_called_once() + self.assertEqual(cm.exception.code, 1) @patch( "awslambdaric.bootstrap.LambdaLoggerHandler",