diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py index acc2fa5385d923..8e55aeccc83d31 100644 --- a/Lib/test/test_pstats.py +++ b/Lib/test/test_pstats.py @@ -6,7 +6,10 @@ from enum import StrEnum, _test_simple_enum import pstats +import tempfile import cProfile +from os import remove, path +from functools import cmp_to_key class AddCallersTestCase(unittest.TestCase): """Tests for pstats.add_callers helper.""" @@ -30,12 +33,35 @@ class StatsTestCase(unittest.TestCase): def setUp(self): stats_file = support.findfile('pstats.pck') self.stats = pstats.Stats(stats_file) + to_compile = 'import os' + self.temp_storage = tempfile.mktemp() + profiled = compile(to_compile, '', 'exec') + cProfile.run(profiled, filename=self.temp_storage) + + def tearDown(self): + remove(self.temp_storage) def test_add(self): stream = StringIO() stats = pstats.Stats(stream=stream) stats.add(self.stats, self.stats) + def test_dump_and_load_works_correctly(self): + self.stats.dump_stats(filename=self.temp_storage) + tmp_stats = pstats.Stats(self.temp_storage) + self.assertEqual(self.stats.stats, tmp_stats.stats) + + def test_load_equivalent_to_init(self): + empty = pstats.Stats() + empty.load_stats(self.temp_storage) + created = pstats.Stats(self.temp_storage) + self.assertEqual(empty.stats, created.stats) + + def test_loading_wrong_types(self): + empty = pstats.Stats() + with self.assertRaises(TypeError): + empty.load_stats(42) + def test_sort_stats_int(self): valid_args = {-1: 'stdname', 0: 'calls', @@ -119,5 +145,42 @@ def test_SortKey_enum(self): self.assertEqual(SortKey.FILENAME, 'filename') self.assertNotEqual(SortKey.FILENAME, SortKey.CALLS) + +class TupleCompTestCase(unittest.TestCase): + + def test_tuple_comp_compare_is_correct(self): + comp_list = [(0, 1), (1, -1), (2, 1)] + tup = pstats.TupleComp(comp_list) + to_sort = [(1, 3, 4), (2, 3, 1), (5, 4, 3)] + desired = [(1, 3, 4), (2, 3, 1), (5, 4, 3)] + to_sort.sort(key=cmp_to_key(tup.compare)) + self.assertEqual(to_sort, desired) + + +class UtilsTestCase(unittest.TestCase): + def test_count_calls(self): + dic = { + 1: 2, 3: 5 + } + dic_null = { + 1: 0, 2: 0 + } + self.assertEqual(pstats.count_calls(dic), 7) + self.assertEqual(pstats.count_calls(dic_null), 0) + + def test_f8(self): + self.assertEqual(pstats.f8(2.3232), ' 2.323') + self.assertEqual(pstats.f8(0), ' 0.000') + + def test_func_name(self): + func = ('file', 10, 'name') + self.assertEqual(pstats.func_get_function_name(func), 'name') + + def test_strip_path(self): + func = (path.join('long', 'path'), 10, 'name') + desired = ('path', 10, 'name') + self.assertEqual(pstats.func_strip_path(func), desired) + + if __name__ == "__main__": unittest.main()