From e7ff0acdac1f1d61073a59862d39f57784dcc56d Mon Sep 17 00:00:00 2001 From: Alexei Znamensky <103110+russoz@users.noreply.github.com> Date: Fri, 15 Sep 2023 21:36:59 +1200 Subject: [PATCH] refactor test helper context class (#7266) --- tests/unit/plugins/modules/helper.py | 150 ++++++++++++++------------- 1 file changed, 78 insertions(+), 72 deletions(-) diff --git a/tests/unit/plugins/modules/helper.py b/tests/unit/plugins/modules/helper.py index b02fe886359..94ccfb4f385 100644 --- a/tests/unit/plugins/modules/helper.py +++ b/tests/unit/plugins/modules/helper.py @@ -18,6 +18,83 @@ RunCmdCall = namedtuple("RunCmdCall", ["command", "environ", "rc", "out", "err"]) +class _BaseContext(object): + def __init__(self, helper, testcase, mocker, capfd): + self.helper = helper + self.testcase = testcase + self.mocker = mocker + self.capfd = capfd + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def _run(self): + with pytest.raises(SystemExit): + self.helper.module_main() + + out, err = self.capfd.readouterr() + results = json.loads(out) + + self.check_results(results) + + def test_flags(self, flag=None): + flags = self.testcase.flags + if flag: + flags = flags.get(flag) + return flags + + def run(self): + func = self._run + + test_flags = self.test_flags() + if test_flags.get("skip"): + pytest.skip(reason=test_flags["skip"]) + if test_flags.get("xfail"): + pytest.xfail(reason=test_flags["xfail"]) + + func() + + def check_results(self, results): + print("testcase =\n%s" % str(self.testcase)) + print("results =\n%s" % results) + if 'exception' in results: + print("exception = \n%s" % results["exception"]) + + for test_result in self.testcase.output: + assert results[test_result] == self.testcase.output[test_result], \ + "'{0}': '{1}' != '{2}'".format(test_result, results[test_result], self.testcase.output[test_result]) + + +class _RunCmdContext(_BaseContext): + def __init__(self, *args, **kwargs): + super(_RunCmdContext, self).__init__(*args, **kwargs) + self.run_cmd_calls = self.testcase.run_command_calls + self.mock_run_cmd = self._make_mock_run_cmd() + + def _make_mock_run_cmd(self): + call_results = [(x.rc, x.out, x.err) for x in self.run_cmd_calls] + error_call_results = (123, + "OUT: testcase has not enough run_command calls", + "ERR: testcase has not enough run_command calls") + mock_run_command = self.mocker.patch('ansible.module_utils.basic.AnsibleModule.run_command', + side_effect=chain(call_results, repeat(error_call_results))) + return mock_run_command + + def check_results(self, results): + super(_RunCmdContext, self).check_results(results) + call_args_list = [(item[0][0], item[1]) for item in self.mock_run_cmd.call_args_list] + expected_call_args_list = [(item.command, item.environ) for item in self.run_cmd_calls] + print("call args list =\n%s" % call_args_list) + print("expected args list =\n%s" % expected_call_args_list) + + assert self.mock_run_cmd.call_count == len(self.run_cmd_calls) + if self.mock_run_cmd.call_count: + assert call_args_list == expected_call_args_list + + class Helper(object): @staticmethod def from_list(module_main, list_): @@ -73,7 +150,7 @@ def testcases_ids(self): return [item.id for item in self.testcases] def __call__(self, *args, **kwargs): - return _Context(self, *args, **kwargs) + return _RunCmdContext(self, *args, **kwargs) @property def test_module(self): @@ -92,74 +169,3 @@ def _test_module(mocker, capfd, patch_bin, testcase): testcase_context.run() return _test_module - - -class _Context(object): - def __init__(self, helper, testcase, mocker, capfd): - self.helper = helper - self.testcase = testcase - self.mocker = mocker - self.capfd = capfd - - self.run_cmd_calls = self.testcase.run_command_calls - self.mock_run_cmd = self._make_mock_run_cmd() - - def _make_mock_run_cmd(self): - call_results = [(x.rc, x.out, x.err) for x in self.run_cmd_calls] - error_call_results = (123, - "OUT: testcase has not enough run_command calls", - "ERR: testcase has not enough run_command calls") - mock_run_command = self.mocker.patch('ansible.module_utils.basic.AnsibleModule.run_command', - side_effect=chain(call_results, repeat(error_call_results))) - return mock_run_command - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - def _run(self): - with pytest.raises(SystemExit): - self.helper.module_main() - - out, err = self.capfd.readouterr() - results = json.loads(out) - - self.check_results(results) - - def test_flags(self, flag=None): - flags = self.testcase.flags - if flag: - flags = flags.get(flag) - return flags - - def run(self): - func = self._run - - test_flags = self.test_flags() - if test_flags.get("skip"): - pytest.skip(reason=test_flags["skip"]) - if test_flags.get("xfail"): - pytest.xfail(reason=test_flags["xfail"]) - - func() - - def check_results(self, results): - print("testcase =\n%s" % str(self.testcase)) - print("results =\n%s" % results) - if 'exception' in results: - print("exception = \n%s" % results["exception"]) - - for test_result in self.testcase.output: - assert results[test_result] == self.testcase.output[test_result], \ - "'{0}': '{1}' != '{2}'".format(test_result, results[test_result], self.testcase.output[test_result]) - - call_args_list = [(item[0][0], item[1]) for item in self.mock_run_cmd.call_args_list] - expected_call_args_list = [(item.command, item.environ) for item in self.run_cmd_calls] - print("call args list =\n%s" % call_args_list) - print("expected args list =\n%s" % expected_call_args_list) - - assert self.mock_run_cmd.call_count == len(self.run_cmd_calls) - if self.mock_run_cmd.call_count: - assert call_args_list == expected_call_args_list