Skip to content

Commit

Permalink
rewrite: run function frames in vm.analyze_all_defs() and vm.infer_st…
Browse files Browse the repository at this point in the history
…ub().

analyze_all_defs() runs all functions it can find. infer_stub() runs only the
ones needed to infer module-level types. (We eventually need to handle things
like a function creating a nested function and returning it, but we can figure
that out later.)

PiperOrigin-RevId: 605779001
  • Loading branch information
rchen152 committed Feb 13, 2024
1 parent 36bed9f commit cc5b862
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 20 deletions.
39 changes: 27 additions & 12 deletions pytype/rewrite/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,40 @@ def __init__(
):
self._code = code
self._initial_globals = initial_globals
self._module_frame: frame.Frame = None

def _run(self):
module_frame = frame.Frame(
def _run_module(self) -> None:
assert not self._module_frame
self._module_frame = frame.Frame(
name='__main__',
code=self._code,
initial_locals=self._initial_globals,
initial_globals=self._initial_globals,
)
module_frame.run()
return module_frame
self._module_frame.run()

def _run_function(self, func: abstract.Function) -> frame.Frame:
assert self._module_frame
func_frame = frame.Frame(
name=func.name,
code=func.code,
initial_locals={},
initial_globals=self._module_frame.final_locals,
)
func_frame.run()
return func_frame

def analyze_all_defs(self):
module_frame = self._run()
for func in module_frame.functions:
del func
raise NotImplementedError('Function analysis not implemented yet')
self._run_module()
functions = list(self._module_frame.functions)
while functions:
func = functions.pop(0)
func_frame = self._run_function(func)
functions.extend(func_frame.functions)

def infer_stub(self):
module_frame = self._run()
for name, var in module_frame.final_locals:
del name, var
raise NotImplementedError('Pytd generation not implemented yet')
self._run_module()
for var in self._module_frame.final_locals.values():
for value in var.values:
if isinstance(value, abstract.Function):
self._run_function(value)
65 changes: 57 additions & 8 deletions pytype/rewrite/vm_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import Type, TypeVar

from pytype.pyc import opcodes
from pytype.rewrite import abstract
Expand All @@ -7,19 +7,28 @@

import unittest

_T = TypeVar('_T')


def _make_vm(src: str) -> vm_lib.VirtualMachine:
return vm_lib.VirtualMachine(test_utils.parse(src), {})


def _get(typ: Type[_T], var) -> _T:
v = var.get_atomic_value()
assert isinstance(v, typ)
return v


class VmTest(unittest.TestCase):

def test_run_module_frame(self):
block = [opcodes.LOAD_CONST(0, 0, 0, None), opcodes.RETURN_VALUE(0, 0)]
code = test_utils.FakeOrderedCode([block], [None])
vm = vm_lib.VirtualMachine(code.Seal(), {})
module_frame = vm._run()
self.assertIsNotNone(module_frame)
self.assertIsNone(vm._module_frame)
vm._run_module()
self.assertIsNotNone(vm._module_frame)

def test_globals(self):
vm = _make_vm("""
Expand All @@ -33,18 +42,58 @@ def g():
g()
f()
""")
module_frame = vm._run()
vm._run_module()

def get_const(var):
return cast(abstract.PythonConstant, var.get_atomic_value()).constant
return _get(abstract.PythonConstant, var).constant

x = get_const(module_frame.load_global('x'))
y = get_const(module_frame.load_global('y'))
z = get_const(module_frame.load_global('z'))
x = get_const(vm._module_frame.load_global('x'))
y = get_const(vm._module_frame.load_global('y'))
z = get_const(vm._module_frame.load_global('z'))
self.assertEqual(x, 42)
self.assertIsNone(y)
self.assertEqual(z, 42)

def test_analyze_functions(self):
# Just make sure this doesn't crash.
vm = _make_vm("""
def f():
def g():
pass
""")
vm.analyze_all_defs()

def test_infer_stub(self):
# Just make sure this doesn't crash.
vm = _make_vm("""
def f():
def g():
pass
""")
vm.infer_stub()

def test_run_function(self):
vm = _make_vm("""
x = None
def f():
global x
x = 42
def g():
y = x
""")
vm._run_module()
f = _get(abstract.Function, vm._module_frame.final_locals['f'])
g = _get(abstract.Function, vm._module_frame.final_locals['g'])
f_frame = vm._run_function(f)
g_frame = vm._run_function(g)

self.assertEqual(f_frame.load_global('x').get_atomic_value(),
abstract.PythonConstant(42))
self.assertEqual(g_frame.load_local('y').get_atomic_value(),
abstract.PythonConstant(None))


if __name__ == '__main__':
unittest.main()

0 comments on commit cc5b862

Please sign in to comment.