Skip to content

Commit fba1de2

Browse files
committed
Refactor parser to be more type-correct
Also detect whether names that occur in both inputs and outputs are at the same position.
1 parent 5fbe7b0 commit fba1de2

File tree

2 files changed

+56
-29
lines changed

2 files changed

+56
-29
lines changed

Tools/cases_generator/generate_cases.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,11 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
7272
if dedent < 0:
7373
indent += " " * -dedent
7474
# TODO: Is it better to count forward or backward?
75-
for i, input in enumerate(reversed(instr.inputs or ()), 1):
75+
for i, input in enumerate(reversed(instr.inputs), 1):
7676
f.write(f"{indent} PyObject *{input} = PEEK({i});\n")
77-
for output in instr.outputs or ():
78-
f.write(f"{indent} PyObject *{output};\n")
79-
# input = ", ".join(instr.inputs)
80-
# output = ", ".join(instr.outputs)
81-
# f.write(f"{indent} // {input} -- {output}\n")
77+
for output in instr.outputs:
78+
if output not in instr.inputs:
79+
f.write(f"{indent} PyObject *{output};\n")
8280
assert instr.block is not None
8381
blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
8482
# Remove blank lines from ends
@@ -97,7 +95,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
9795
# Write the body
9896
ninputs = len(instr.inputs or ())
9997
for line in blocklines:
100-
if m := re.match(r"(\s*)ERROR_IF\(([^,]+), (\w+)\);\s*$", line):
98+
if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
10199
space, cond, label = m.groups()
102100
# ERROR_IF() must remove the inputs from the stack.
103101
# The code block is responsible for DECREF()ing them.
@@ -114,7 +112,8 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
114112
elif diff < 0:
115113
f.write(f"{indent} STACK_SHRINK({-diff});\n")
116114
for i, output in enumerate(reversed(instr.outputs or ()), 1):
117-
f.write(f"{indent} POKE({i}, {output});\n")
115+
if output not in (instr.inputs or ()):
116+
f.write(f"{indent} POKE({i}, {output});\n")
118117
assert instr.block
119118

120119
def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):

Tools/cases_generator/parser.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,28 @@ class Block(Node):
5757

5858

5959
@dataclass
60-
class InstDef(Node):
60+
class InstHeader(Node):
6161
name: str
62-
inputs: list[str] | None
63-
outputs: list[str] | None
64-
block: Block | None
62+
inputs: list[str]
63+
outputs: list[str]
64+
65+
66+
@dataclass
67+
class InstDef(Node):
68+
header: InstHeader
69+
block: Block
70+
71+
@property
72+
def name(self):
73+
return self.header.name
74+
75+
@property
76+
def inputs(self):
77+
return self.header.inputs
78+
79+
@property
80+
def outputs(self):
81+
return self.header.outputs
6582

6683

6784
@dataclass
@@ -82,30 +99,42 @@ class Parser(PLexer):
8299
def inst_def(self) -> InstDef | None:
83100
if header := self.inst_header():
84101
if block := self.block():
85-
header.block = block
86-
return header
102+
return InstDef(header, block)
87103
raise self.make_syntax_error("Expected block")
88104
return None
89105

90106
@contextual
91-
def inst_header(self):
107+
def inst_header(self) -> InstHeader | None:
92108
# inst(NAME) | inst(NAME, (inputs -- outputs))
93109
# TODO: Error out when there is something unexpected.
94-
# TODO: Make INST a keyword in the lexer.
110+
# TODO: Make INST a keyword in the lexer.
95111
if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "inst":
96112
if (self.expect(lx.LPAREN)
97113
and (tkn := self.expect(lx.IDENTIFIER))):
98114
name = tkn.text
99115
if self.expect(lx.COMMA):
100116
inp, outp = self.stack_effect()
101-
if (self.expect(lx.RPAREN)
102-
and self.peek().kind == lx.LBRACE):
103-
return InstDef(name, inp, outp, [])
117+
if self.expect(lx.RPAREN):
118+
if ((tkn := self.peek())
119+
and tkn.kind == lx.LBRACE):
120+
self.check_overlaps(inp, outp)
121+
return InstHeader(name, inp, outp)
104122
elif self.expect(lx.RPAREN):
105-
return InstDef(name, None, None, [])
123+
return InstHeader(name, [], [])
106124
return None
107125

108-
def stack_effect(self):
126+
def check_overlaps(self, inp: list[str], outp: list[str]):
127+
for i, name in enumerate(inp):
128+
try:
129+
j = outp.index(name)
130+
except ValueError:
131+
continue
132+
else:
133+
if i != j:
134+
raise self.make_syntax_error(
135+
f"Input {name!r} at pos {i} repeated in output at different pos {j}")
136+
137+
def stack_effect(self) -> tuple[list[str], list[str]]:
109138
# '(' [inputs] '--' [outputs] ')'
110139
if self.expect(lx.LPAREN):
111140
inp = self.inputs() or []
@@ -115,7 +144,7 @@ def stack_effect(self):
115144
return inp, outp
116145
raise self.make_syntax_error("Expected stack effect")
117146

118-
def inputs(self):
147+
def inputs(self) -> list[str] | None:
119148
# input (, input)*
120149
here = self.getpos()
121150
if inp := self.input():
@@ -128,7 +157,7 @@ def inputs(self):
128157
self.setpos(here)
129158
return None
130159

131-
def input(self):
160+
def input(self) -> str | None:
132161
# IDENTIFIER
133162
if (tkn := self.expect(lx.IDENTIFIER)):
134163
if self.expect(lx.LBRACKET):
@@ -148,7 +177,7 @@ def input(self):
148177
return "??"
149178
return None
150179

151-
def outputs(self):
180+
def outputs(self) -> list[str] | None:
152181
# output (, output)*
153182
here = self.getpos()
154183
if outp := self.output():
@@ -161,7 +190,7 @@ def outputs(self):
161190
self.setpos(here)
162191
return None
163192

164-
def output(self):
193+
def output(self) -> str | None:
165194
return self.input() # TODO: They're not quite the same.
166195

167196
@contextual
@@ -176,7 +205,6 @@ def super_def(self) -> Super | None:
176205
return res
177206

178207
def ops(self) -> list[str] | None:
179-
here = self.getpos()
180208
if tkn := self.expect(lx.IDENTIFIER):
181209
ops = [tkn.text]
182210
while self.expect(lx.PLUS):
@@ -197,7 +225,7 @@ def family_def(self) -> Family | None:
197225
return Family(tkn.text, members)
198226
return None
199227

200-
def members(self):
228+
def members(self) -> list[str] | None:
201229
here = self.getpos()
202230
if tkn := self.expect(lx.IDENTIFIER):
203231
near = self.getpos()
@@ -214,8 +242,8 @@ def block(self) -> Block:
214242
tokens = self.c_blob()
215243
return Block(tokens)
216244

217-
def c_blob(self):
218-
tokens = []
245+
def c_blob(self) -> list[lx.Token]:
246+
tokens: list[lx.Token] = []
219247
level = 0
220248
while tkn := self.next(raw=True):
221249
if tkn.kind in (lx.LBRACE, lx.LPAREN, lx.LBRACKET):

0 commit comments

Comments
 (0)