From cc5b862ae8656e11b4c07d6defca3830ef49f7f0 Mon Sep 17 00:00:00 2001 From: rechen Date: Fri, 9 Feb 2024 18:01:36 -0800 Subject: [PATCH] rewrite: run function frames in vm.analyze_all_defs() and vm.infer_stub(). analyze_all_defs() runs all functions it can find. infer_stub() runs only the ones needed to infer module-level types. (We eventually need to handle things like a function creating a nested function and returning it, but we can figure that out later.) PiperOrigin-RevId: 605779001 --- pytype/rewrite/vm.py | 39 +++++++++++++++-------- pytype/rewrite/vm_test.py | 65 ++++++++++++++++++++++++++++++++++----- 2 files changed, 84 insertions(+), 20 deletions(-) diff --git a/pytype/rewrite/vm.py b/pytype/rewrite/vm.py index d895fa12b..10956a8f4 100644 --- a/pytype/rewrite/vm.py +++ b/pytype/rewrite/vm.py @@ -18,25 +18,40 @@ def __init__( ): self._code = code self._initial_globals = initial_globals + self._module_frame: frame.Frame = None - def _run(self): - module_frame = frame.Frame( + def _run_module(self) -> None: + assert not self._module_frame + self._module_frame = frame.Frame( name='__main__', code=self._code, initial_locals=self._initial_globals, initial_globals=self._initial_globals, ) - module_frame.run() - return module_frame + self._module_frame.run() + + def _run_function(self, func: abstract.Function) -> frame.Frame: + assert self._module_frame + func_frame = frame.Frame( + name=func.name, + code=func.code, + initial_locals={}, + initial_globals=self._module_frame.final_locals, + ) + func_frame.run() + return func_frame def analyze_all_defs(self): - module_frame = self._run() - for func in module_frame.functions: - del func - raise NotImplementedError('Function analysis not implemented yet') + self._run_module() + functions = list(self._module_frame.functions) + while functions: + func = functions.pop(0) + func_frame = self._run_function(func) + functions.extend(func_frame.functions) def infer_stub(self): - module_frame = self._run() - for name, var in module_frame.final_locals: - del name, var - raise NotImplementedError('Pytd generation not implemented yet') + self._run_module() + for var in self._module_frame.final_locals.values(): + for value in var.values: + if isinstance(value, abstract.Function): + self._run_function(value) diff --git a/pytype/rewrite/vm_test.py b/pytype/rewrite/vm_test.py index a1bdf6e77..f542ae4ca 100644 --- a/pytype/rewrite/vm_test.py +++ b/pytype/rewrite/vm_test.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Type, TypeVar from pytype.pyc import opcodes from pytype.rewrite import abstract @@ -7,19 +7,28 @@ import unittest +_T = TypeVar('_T') + def _make_vm(src: str) -> vm_lib.VirtualMachine: return vm_lib.VirtualMachine(test_utils.parse(src), {}) +def _get(typ: Type[_T], var) -> _T: + v = var.get_atomic_value() + assert isinstance(v, typ) + return v + + class VmTest(unittest.TestCase): def test_run_module_frame(self): block = [opcodes.LOAD_CONST(0, 0, 0, None), opcodes.RETURN_VALUE(0, 0)] code = test_utils.FakeOrderedCode([block], [None]) vm = vm_lib.VirtualMachine(code.Seal(), {}) - module_frame = vm._run() - self.assertIsNotNone(module_frame) + self.assertIsNone(vm._module_frame) + vm._run_module() + self.assertIsNotNone(vm._module_frame) def test_globals(self): vm = _make_vm(""" @@ -33,18 +42,58 @@ def g(): g() f() """) - module_frame = vm._run() + vm._run_module() def get_const(var): - return cast(abstract.PythonConstant, var.get_atomic_value()).constant + return _get(abstract.PythonConstant, var).constant - x = get_const(module_frame.load_global('x')) - y = get_const(module_frame.load_global('y')) - z = get_const(module_frame.load_global('z')) + x = get_const(vm._module_frame.load_global('x')) + y = get_const(vm._module_frame.load_global('y')) + z = get_const(vm._module_frame.load_global('z')) self.assertEqual(x, 42) self.assertIsNone(y) self.assertEqual(z, 42) + def test_analyze_functions(self): + # Just make sure this doesn't crash. + vm = _make_vm(""" + def f(): + def g(): + pass + """) + vm.analyze_all_defs() + + def test_infer_stub(self): + # Just make sure this doesn't crash. + vm = _make_vm(""" + def f(): + def g(): + pass + """) + vm.infer_stub() + + def test_run_function(self): + vm = _make_vm(""" + x = None + + def f(): + global x + x = 42 + + def g(): + y = x + """) + vm._run_module() + f = _get(abstract.Function, vm._module_frame.final_locals['f']) + g = _get(abstract.Function, vm._module_frame.final_locals['g']) + f_frame = vm._run_function(f) + g_frame = vm._run_function(g) + + self.assertEqual(f_frame.load_global('x').get_atomic_value(), + abstract.PythonConstant(42)) + self.assertEqual(g_frame.load_local('y').get_atomic_value(), + abstract.PythonConstant(None)) + if __name__ == '__main__': unittest.main()