diff --git a/tests/test_CodeEntropy/test_main.py b/tests/test_CodeEntropy/test_main.py index 42b9b79..bee18f0 100644 --- a/tests/test_CodeEntropy/test_main.py +++ b/tests/test_CodeEntropy/test_main.py @@ -1,3 +1,8 @@ +import os +import shutil +import subprocess +import sys +import tempfile import unittest from unittest.mock import MagicMock, patch @@ -9,6 +14,22 @@ class TestMain(unittest.TestCase): Unit tests for the main functionality of CodeEntropy. """ + def setUp(self): + """ + Set up a temporary directory as the working directory before each test. + """ + self.test_dir = tempfile.mkdtemp(prefix="CodeEntropy_") + self._orig_dir = os.getcwd() + os.chdir(self.test_dir) + + def tearDown(self): + """ + Clean up by removing the temporary directory and restoring the original working + directory. + """ + os.chdir(self._orig_dir) + shutil.rmtree(self.test_dir) + @patch("CodeEntropy.main.sys.exit") @patch("CodeEntropy.main.RunManager") def test_main_successful_run(self, mock_RunManager, mock_exit): @@ -68,6 +89,51 @@ def test_main_exception_triggers_exit( "Fatal error during entropy calculation: Test exception", exc_info=True ) + def test_main_entry_point_runs(self): + """ + Test that the CLI entry point (main.py) runs successfully with minimal required + arguments. + """ + # Prepare input files + data_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "data") + ) + tpr_path = shutil.copy(os.path.join(data_dir, "md_A4_dna.tpr"), self.test_dir) + trr_path = shutil.copy( + os.path.join(data_dir, "md_A4_dna_xf.trr"), self.test_dir + ) + + config_path = os.path.join(self.test_dir, "config.yaml") + with open(config_path, "w") as f: + f.write("run1:\n" " selection_string: resid 1\n") + + result = subprocess.run( + [ + sys.executable, + "-m", + "CodeEntropy.main", + "--top_traj_file", + tpr_path, + trr_path, + ], + cwd=self.test_dir, + capture_output=True, + text=True, + ) + + self.assertEqual(result.returncode, 0) + + # Check for job folder and output file + job_dir = os.path.join(self.test_dir, "job001") + output_file = os.path.join(job_dir, "output_file.json") + + self.assertTrue(os.path.exists(job_dir)) + self.assertTrue(os.path.exists(output_file)) + + with open(output_file) as f: + content = f.read() + self.assertIn("DA", content) + if __name__ == "__main__": unittest.main()