Skip to content

Commit

Permalink
Merge pull request #11 from youknowone/refactor-asdl
Browse files Browse the repository at this point in the history
Refactor ast to hold data as seperated type
  • Loading branch information
youknowone authored May 7, 2023
2 parents 48920a0 + 6d73580 commit 7b8844b
Show file tree
Hide file tree
Showing 112 changed files with 22,478 additions and 19,206 deletions.
5 changes: 5 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
parser/src/python.rs linguist-generated
**/*.snap linguist-generated -merge
**/*.lalrpop text eol=LF
**/*.py text working-tree-encoding=UTF-8 eol=LF
**/*.rs text working-tree-encoding=UTF-8 eol=LF
129 changes: 99 additions & 30 deletions ast/asdl_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def visitSum(self, sum, name):
if is_simple(sum):
info.has_userdata = False
else:
for t in sum.types:
self.typeinfo[t.name] = TypeInfo(t.name)
self.add_children(t.name, t.fields)
if len(sum.types) > 1:
info.boxed = True
if sum.attributes:
Expand Down Expand Up @@ -205,16 +208,49 @@ def simple_sum(self, sum, name, depth):

def sum_with_constructors(self, sum, name, depth):
typeinfo = self.typeinfo[name]
generics, generics_applied = self.get_generics(name, "U = ()", "U")
enumname = rustname = get_rust_type(name)
# all the attributes right now are for location, so if it has attrs we
# can just wrap it in Located<>
if sum.attributes:
enumname = rustname + "Kind"

for t in sum.types:
if not t.fields:
continue
self.emit_attrs(depth)
self.typeinfo[t] = TypeInfo(t)
t_generics, t_generics_applied = self.get_generics(t.name, "U = ()", "U")
payload_name = f"{rustname}{t.name}"
self.emit(f"pub struct {payload_name}{t_generics} {{", depth)
for f in t.fields:
self.visit(f, typeinfo, "pub ", depth + 1, t.name)
self.emit("}", depth)
self.emit(
textwrap.dedent(
f"""
impl{t_generics_applied} From<{payload_name}{t_generics_applied}> for {enumname}{t_generics_applied} {{
fn from(payload: {payload_name}{t_generics_applied}) -> Self {{
{enumname}::{t.name}(payload)
}}
}}
"""
),
depth,
)

generics, generics_applied = self.get_generics(name, "U = ()", "U")
self.emit_attrs(depth)
self.emit(f"pub enum {enumname}{generics} {{", depth)
for t in sum.types:
self.visit(t, typeinfo, depth + 1)
if t.fields:
t_generics, t_generics_applied = self.get_generics(
t.name, "U = ()", "U"
)
self.emit(
f"{t.name}({rustname}{t.name}{t_generics_applied}),", depth + 1
)
else:
self.emit(f"{t.name},", depth + 1)
self.emit("}", depth)
if sum.attributes:
self.emit(
Expand All @@ -238,13 +274,18 @@ def visitField(self, field, parent, vis, depth, constructor=None):
if fieldtype and fieldtype.has_userdata:
typ = f"{typ}<U>"
# don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
if fieldtype and fieldtype.boxed and (not (parent.product or field.seq) or field.opt):
if (
fieldtype
and fieldtype.boxed
and (not (parent.product or field.seq) or field.opt)
):
typ = f"Box<{typ}>"
if field.opt or (
# When a dictionary literal contains dictionary unpacking (e.g., `{**d}`),
# the expression to be unpacked goes in `values` with a `None` at the corresponding
# position in `keys`. To handle this, the type of `keys` needs to be `Option<Vec<T>>`.
constructor == "Dict" and field.name == "keys"
constructor == "Dict"
and field.name == "keys"
):
typ = f"Option<{typ}>"
if field.seq:
Expand Down Expand Up @@ -344,14 +385,21 @@ def visitSum(self, sum, name, depth):
)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
enumname += "Kind"
rustname = enumname + "Kind"
else:
rustname = enumname
self.emit("match node {", depth + 1)
for cons in sum.types:
fields_pattern = self.make_pattern(cons.fields)
fields_pattern = self.make_pattern(
enumname, rustname, cons.name, cons.fields
)
self.emit(
f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2
f"{fields_pattern[0]} {{ {fields_pattern[1]} }} {fields_pattern[2]} => {{",
depth + 2,
)
self.gen_construction(
fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
)
self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
if is_located:
Expand Down Expand Up @@ -381,23 +429,33 @@ def visitProduct(self, product, name, depth):
)
if is_located:
self.emit("fold_located(folder, node, |folder, node| {", depth)
structname += "Data"
fields_pattern = self.make_pattern(product.fields)
self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
self.gen_construction(structname, product.fields, depth + 1)
rustname = structname + "Data"
else:
rustname = structname
fields_pattern = self.make_pattern(rustname, structname, None, product.fields)
self.emit(f"let {rustname} {{ {fields_pattern[1]} }} = node;", depth + 1)
self.gen_construction(rustname, product.fields, "", depth + 1)
if is_located:
self.emit("})", depth)
self.emit("}", depth)

def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)
def make_pattern(self, rustname, pyname, fieldname, fields):
if fields:
header = f"{pyname}::{fieldname}({rustname}{fieldname}"
footer = ")"
else:
header = f"{pyname}::{fieldname}"
footer = ""

def gen_construction(self, cons_path, fields, depth):
self.emit(f"Ok({cons_path} {{", depth)
body = ",".join(rust_field(f.name) for f in fields)
return header, body, footer

def gen_construction(self, header, fields, footer, depth):
self.emit(f"Ok({header} {{", depth)
for field in fields:
name = rust_field(field.name)
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
self.emit("})", depth)
self.emit(f"}}{footer})", depth)


class FoldModuleVisitor(TypeInfoEmitVisitor):
Expand Down Expand Up @@ -514,33 +572,36 @@ def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)

def visitSum(self, sum, name, depth):
enumname = get_rust_type(name)
rustname = enumname = get_rust_type(name)
if sum.attributes:
enumname += "Kind"
rustname = enumname + "Kind"

self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
self.emit(f"impl NamedNode for ast::{rustname} {{", depth)
self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
self.emit("}", depth)
self.emit(f"impl Node for ast::{enumname} {{", depth)
self.emit(f"impl Node for ast::{rustname} {{", depth)
self.emit(
"fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1
)
self.emit("match self {", depth + 2)
for variant in sum.types:
self.constructor_to_object(variant, enumname, depth + 3)
self.constructor_to_object(variant, enumname, rustname, depth + 3)
self.emit("}", depth + 2)
self.emit("}", depth + 1)
self.emit(
"fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {",
depth + 1,
)
self.gen_sum_fromobj(sum, name, enumname, depth + 2)
self.gen_sum_fromobj(sum, name, enumname, rustname, depth + 2)
self.emit("}", depth + 1)
self.emit("}", depth)

def constructor_to_object(self, cons, enumname, depth):
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
def constructor_to_object(self, cons, enumname, rustname, depth):
self.emit(f"ast::{rustname}::{cons.name}", depth)
if cons.fields:
fields_pattern = self.make_pattern(cons.fields)
self.emit(f"( ast::{enumname}{cons.name} {{ {fields_pattern} }} )", depth)
self.emit(" => {", depth)
self.make_node(cons.name, cons.fields, depth + 1)
self.emit("}", depth)

Expand Down Expand Up @@ -586,15 +647,20 @@ def make_node(self, variant, fields, depth):
def make_pattern(self, fields):
return ",".join(rust_field(f.name) for f in fields)

def gen_sum_fromobj(self, sum, sumname, enumname, depth):
def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
if sum.attributes:
self.extract_location(sumname, depth)

self.emit("let _cls = _object.class();", depth)
self.emit("Ok(", depth)
for cons in sum.types:
self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
if cons.fields:
self.emit(f"ast::{rustname}::{cons.name} (ast::{enumname}{cons.name} {{", depth + 1)
self.gen_construction_fields(cons, sumname, depth + 1)
self.emit("})", depth + 1)
else:
self.emit(f"ast::{rustname}::{cons.name}", depth + 1)
self.emit("} else", depth)

self.emit("{", depth)
Expand All @@ -610,13 +676,16 @@ def gen_product_fromobj(self, product, prodname, structname, depth):
self.gen_construction(structname, product, prodname, depth + 1)
self.emit(")", depth)

def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
def gen_construction_fields(self, cons, name, depth):
for field in cons.fields:
self.emit(
f"{rust_field(field.name)}: {self.decode_field(field, name)},",
depth + 1,
)

def gen_construction(self, cons_path, cons, name, depth):
self.emit(f"ast::{cons_path} {{", depth)
self.gen_construction_fields(cons, name, depth + 1)
self.emit("}", depth)

def extract_location(self, typename, depth):
Expand Down
Loading

0 comments on commit 7b8844b

Please sign in to comment.