Skip to content

Commit

Permalink
Implement Closures (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
SohamKukreti authored Oct 11, 2023
1 parent da9662e commit 2b6f2c0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/interpreted/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,20 @@


class Scope:
def __init__(self):
def __init__(self, parent=None) -> None:
self.data = {}
self.parent = parent
self.set("print", Print())
self.set("len", Len())
self.set("int", Int())
self.set("float", Float())
self.set("deque", DequeConstructor())

def get(self, name) -> Any:
return getattr(self, name, NOT_SET)
return self.data.get(name, NOT_SET)

def set(self, name, value) -> None:
setattr(self, name, value)
self.data[name] = value


class InterpreterError(Exception):
Expand Down Expand Up @@ -170,8 +172,14 @@ def __init__(self, value: Object) -> None:


class UserFunction(Function):
def __init__(self, definition: FunctionDef, current_globals: Scope) -> None:
def __init__(
self,
definition: FunctionDef,
parent_scope: Scope,
current_globals: Scope,
) -> None:
self.definition = definition
self.parent_scope = parent_scope
self.current_globals = current_globals

def as_string(self) -> str:
Expand All @@ -183,10 +191,10 @@ def arg_count(self) -> int:
def call(self, interpreter: Interpreter, args: list[Object]) -> Object:
super().ensure_args(args)

parent_scope = interpreter.scope
current_scope = interpreter.scope
parent_globals = interpreter.globals

function_scope = Scope()
function_scope = Scope(parent=self.parent_scope)
interpreter.globals = self.current_globals
interpreter.scope = function_scope

Expand All @@ -201,7 +209,7 @@ def call(self, interpreter: Interpreter, args: list[Object]) -> Object:
return ret.value

finally:
interpreter.scope = parent_scope
interpreter.scope = current_scope
interpreter.globals = parent_globals

return Value(None)
Expand Down Expand Up @@ -424,7 +432,7 @@ def visit_Import(self, node: Import) -> None:
self.scope = parent_scope
self.globals = parent_globals

module_obj = Module(members=vars(module_scope))
module_obj = Module(members=module_scope.data)

self.scope.set(name, module_obj)

Expand All @@ -451,7 +459,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> None:
for alias in node.names:
name = alias.name
if name == "*":
for member, value in vars(module_scope).items():
for member, value in module_scope.data.items():
self.scope.set(member, value)
return

Expand All @@ -462,7 +470,8 @@ def visit_ImportFrom(self, node: ImportFrom) -> None:
self.scope.set(name, member)

def visit_FunctionDef(self, node: FunctionDef) -> None:
function = UserFunction(node, self.globals)
parent_scope = self.scope
function = UserFunction(node, parent_scope, self.globals)
self.scope.set(node.name, function)

def visit_Assign(self, node: Assign) -> None:
Expand Down Expand Up @@ -734,12 +743,17 @@ def visit_Attribute(self, node: Attribute) -> Object:
def visit_Name(self, node: Name) -> Value:
name = node.id

value = self.scope.get(name)
current_scope = self.scope
while current_scope is not None:
value = current_scope.get(name)
if value is NOT_SET:
current_scope = current_scope.parent
else:
return value

value = self.globals.get(name)
if value is NOT_SET:
value = self.globals.get(name)
if value is NOT_SET:
raise InterpreterError(f"{name!r} is not defined")
raise InterpreterError(f"{name!r} is not defined")

return value

Expand Down
20 changes: 20 additions & 0 deletions tests/interpreted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ def foo(x):
""",
"a\nbc\nab\nabc\nb\n",
),
(
"""\
x = 5
def bar():
x = 10
def baz():
def foo():
print(x)
return foo
return baz
foo = bar()()
foo()
""",
"10\n",
),
),
)
def test_interpret(source, output) -> None:
Expand Down

0 comments on commit 2b6f2c0

Please sign in to comment.