Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

imports implementation #25

Merged
merged 7 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions src/interpreted/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import Any
from unittest import mock

import sys

from interpreted import nodes
from interpreted.nodes import (
Assign,
Expand All @@ -24,13 +26,25 @@
Subscript,
UnaryOp,
While,
alias,
Import,
ImportFrom,
)
from interpreted.parser import parse

NOT_SET = object()


class Scope:
def __init__(self):
self.set("print", Print())
self.set("len", Len())
self.set("int", Int())
self.set("float", Float())
self.set("deque", DequeConstructor())

self.scope_lookup: {str: Scope} = {}

def get(self, name) -> Any:
return getattr(self, name, NOT_SET)

Expand Down Expand Up @@ -154,8 +168,9 @@ def __init__(self, value: Object) -> None:


class UserFunction(Function):
def __init__(self, definition: FunctionDef) -> None:
def __init__(self, definition: FunctionDef, current_globals: Scope) -> None:
self.definition = definition
self.current_globals = current_globals

def as_string(self) -> str:
return f"<function {self.definition.name!r}>"
Expand All @@ -167,8 +182,10 @@ def call(self, interpreter: Interpreter, args: list[Object]) -> Object:
super().ensure_args(args)

parent_scope = interpreter.scope
parent_globals = interpreter.globals

function_scope = Scope()
interpreter.globals = self.current_globals
interpreter.scope = function_scope

for param, arg in zip(self.definition.params, args):
Expand All @@ -183,6 +200,7 @@ def call(self, interpreter: Interpreter, args: list[Object]) -> Object:

finally:
interpreter.scope = parent_scope
interpreter.globals = parent_globals

return Value(None)

Expand Down Expand Up @@ -370,12 +388,6 @@ def is_truthy(obj: Object) -> bool:
class Interpreter:
def __init__(self) -> None:
self.globals = Scope()
self.globals.set("print", Print())
self.globals.set("len", Len())
self.globals.set("int", Int())
self.globals.set("float", Float())
self.globals.set("deque", DequeConstructor())

self.scope = self.globals

def visit(self, node: Node) -> Object | None:
Expand All @@ -387,8 +399,63 @@ def visit_Module(self, node: Module) -> None:
for stmt in node.body:
self.visit(stmt)

def visit_Import(self, node: Import) -> None:
for alias in node.names:
name = alias.name
if alias.asname:
name = alias.asname

contents = ""
with open(f"{alias.name}.py", "r") as f:
contents = f.read()
module = parse(contents)

parent_scope = self.scope
parent_globals = self.globals

module_scope = Scope()
self.scope = module_scope
self.globals = module_scope

self.visit(module)

self.scope = parent_scope
self.globals = parent_globals

self.scope.set(name, module_scope)

def visit_ImportFrom(self, node: ImportFrom) -> None:
module_name = node.module

contents = ""
with open(f"{module_name}.py", "r") as f:
contents = f.read()
module = parse(contents)

parent_scope = self.scope
parent_globals = self.globals

module_scope = Scope()
self.scope = module_scope
self.globals = module_scope

self.visit(module)

self.scope = parent_scope
self.globals = parent_globals

for alias in node.names:
name = alias.name
if alias.asname:
name = alias.asname

body = module_scope.get(alias.name)
self.scope.set(name, body)
if type(body) is not Value:
self.scope.scope_lookup[name] = module_scope

def visit_FunctionDef(self, node: FunctionDef) -> None:
function = UserFunction(node)
function = UserFunction(node, self.globals)
self.scope.set(node.name, function)

def visit_Assign(self, node: Assign) -> None:
Expand Down Expand Up @@ -595,6 +662,12 @@ def visit_Call(self, node: Call) -> Object:
raise InterpreterError(f"{object_type!r} object is not callable")

arguments = [self.visit(arg) for arg in node.args]

if function.as_string in self.scope.scope_lookup:
module_scope = self.scope.scope_lookup[function.as_string]
print(vars(self.scope))
return function.call(self, arguments, module_scope)

s-m33r marked this conversation as resolved.
Show resolved Hide resolved
return function.call(self, arguments)

def visit_Subscript(self, node: Subscript) -> Object:
Expand Down Expand Up @@ -646,6 +719,11 @@ def visit_Attribute(self, node: Attribute) -> Object:
obj = self.visit(node.value)
assert obj is not None

if type(obj) is Scope:
scoped_result = obj.get(attribute_name)

return scoped_result
s-m33r marked this conversation as resolved.
Show resolved Hide resolved

if attribute_name in obj.attributes:
return obj.attributes[attribute_name]

Expand All @@ -660,6 +738,9 @@ def visit_Name(self, node: Name) -> Value:
name = node.id

value = self.scope.get(name)
if value in self.scope.scope_lookup:
return value

if value is NOT_SET:
value = self.globals.get(name)
if value is NOT_SET:
Expand Down
49 changes: 49 additions & 0 deletions tests/interpreted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,52 @@ def test_file_not_found() -> None:
assert process.stdout == b""
assert process.stderr == b"\x1b[31mError:\x1b[m Unable to open file: 'foo.py'\n"
assert process.returncode == 1


def test_imports(tmp_path) -> None:
math_content = """\
PI = 3.14

def add(a, b):
return a + b

def mul(a, b):
return a * b

def area(r):
return PI * r * r
"""
utils_content = """\
import calc as math

def cos(x):
return "bru what"
"""
main_content = """\
from utils import math, cos

print(math.area(2))
print(math.add(2,3))
print(math.mul(3,4))
print(cos(30))
"""

main = tmp_path / "main.py"
main.write_text(dedent(main_content))

utils = tmp_path / "utils.py"
utils.write_text(dedent(utils_content))

math = tmp_path / "calc.py"
math.write_text(dedent(math_content))

process = subprocess.run(
["interpreted", main.as_posix()],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=str(tmp_path),
)

assert process.stderr == b""
assert process.stdout.decode() == "12.56\n5\n12\nbru what\n"
assert process.returncode == 0
Loading