Skip to content

Commit

Permalink
[mycpp] Implement with str_switch(s) pattern
Browse files Browse the repository at this point in the history
It first dispatches on the string length, and then the value with the
runtime function str_equals_c().
  • Loading branch information
Andy C committed Mar 11, 2024
1 parent 46f144b commit 6efd0db
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 46 deletions.
3 changes: 3 additions & 0 deletions mycpp/TEST.sh
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ test-invalid-examples() {
*/invalid_global.py)
expected_status=2
;;
*/invalid_switch.py)
expected_status=5
;;
esac

if test $status -ne $expected_status; then
Expand Down
153 changes: 122 additions & 31 deletions mycpp/cppgen_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
cppgen.py - AST pass to that prints C++ code
"""
import io
import itertools
import json # for "C escaping"

from typing import overload, Union, Optional, Dict
Expand Down Expand Up @@ -385,8 +386,8 @@ def GetCReturnType(t) -> Tuple[str, bool, Optional[str]]:

def PythonStringLiteral(s: str) -> str:
"""
Returns a properly quoted string.
"""
Returns a properly quoted string.
"""
# MyPy does bad escaping. Decode and push through json to get something
# workable in C++.
return json.dumps(format_strings.DecodeMyPyString(s))
Expand Down Expand Up @@ -1912,6 +1913,13 @@ def _collect_cases(self, if_node, out):
"""
The MyPy AST has a recursive structure for if-elif-elif rather than a
flat one. It's a bit confusing.
Appends (expr, block) cases to out param, and returns the default
block, which has no expression.
default block may be None.
Returns False if there is no default block.
"""
assert isinstance(if_node, IfStmt), if_node
assert len(if_node.expr) == 1, if_node.expr
Expand All @@ -1920,6 +1928,11 @@ def _collect_cases(self, if_node, out):
expr = if_node.expr[0]
body = if_node.body[0]

if not isinstance(expr, CallExpr):
self.report_error(expr,
'Expected call like case(x), got %s' % expr)
return

out.append((expr, body))

if if_node.else_body:
Expand All @@ -1930,30 +1943,44 @@ def _collect_cases(self, if_node, out):
# if 0:

if isinstance(first_of_block, IfStmt):
self._collect_cases(first_of_block, out)
return self._collect_cases(first_of_block, out)
else:
# default case - no expression
out.append((None, if_node.else_body))
return if_node.else_body

return False # NO DEFAULT BLOCK - Different than None

def _write_cases(self, cases):
def _write_cases(self, switch_expr, cases, default_block):
""" Write a list of (expr, block) pairs """

for expr, body in cases:
if expr is not None:
assert isinstance(expr, CallExpr), expr
for i, arg in enumerate(expr.args):
if i != 0:
self.def_write('\n')
self.def_write_ind('case ')
self.accept(arg)
self.def_write(': ')
assert expr is not None, expr
if not isinstance(expr, CallExpr):
self.report_error(expr,
'Expected call like case(x), got %s' % expr)
return

self.accept(body)
self.def_write_ind(' break;\n')
else:
self.def_write_ind('default: ')
self.accept(body) # the whole block
# don't write 'break'
for i, arg in enumerate(expr.args):
if i != 0:
self.def_write('\n')
self.def_write_ind('case ')
self.accept(arg)
self.def_write(': ')

self.accept(body)
self.def_write_ind(' break;\n')

if default_block is None:
# an error occurred
return
if default_block is False:
self.report_error(switch_expr,
'switch got no else: for default block')
return

self.def_write_ind('default: ')
self.accept(default_block)
# don't write 'break'

def _write_switch(self, expr, o):
"""Write a switch statement over integers."""
Expand All @@ -1969,8 +1996,8 @@ def _write_switch(self, expr, o):

self.indent += 1
cases = []
self._collect_cases(if_node, cases)
self._write_cases(cases)
default_block = self._collect_cases(if_node, cases)
self._write_cases(expr, cases, default_block)

self.indent -= 1
self.def_write_ind('}\n')
Expand All @@ -1989,34 +2016,98 @@ def _write_tag_switch(self, expr, o):

self.indent += 1
cases = []
self._collect_cases(if_node, cases)
self._write_cases(cases)
default_block = self._collect_cases(if_node, cases)
self._write_cases(expr, cases, default_block)

self.indent -= 1
self.def_write_ind('}\n')

def _str_switch_cases(self, cases):
cases2 = []
for expr, body in cases:
if not isinstance(expr, CallExpr):
# non-fatal check from _collect_cases
break

args = expr.args
if len(args) != 1:
self.report_error(
expr,
'str_switch can only have case("x"), not case("x", "y")' %
args)
break

if not isinstance(args[0], StrExpr):
self.report_error(
expr,
'str_switch can only be used with constant strings, got %s'
% args[0])
break

s = args[0].value
cases2.append((len(s), s, body))

# Sort by string length
cases2.sort(key=lambda pair: pair[0])
grouped = itertools.groupby(cases2, key=lambda pair: pair[0])
return grouped

def _write_str_switch(self, expr, o):
"""Write a switch statement over strings."""
assert len(expr.args) == 1, expr.args

self.def_write_ind('switch (len(')
self.accept(expr.args[0])
self.def_write(')) {\n')
switch_var = expr.args[0]
if not isinstance(switch_var, NameExpr):
self.report_error(
expr.args[0],
'str_switch(x) accepts only a variable name, got %s' %
switch_var)
return

self.def_write_ind('switch (len(%s)) {\n' % switch_var.name)

# There can only be one thing under 'with str_switch'
assert len(o.body.body) == 1, o.body.body
if_node = o.body.body[0]
assert isinstance(if_node, IfStmt), if_node

self.indent += 1

cases = []
self._collect_cases(if_node, cases)
default_block = self._collect_cases(if_node, cases)

grouped_cases = self._str_switch_cases(cases)
# Warning: this consumes internal iterator
#self.log('grouped %s', list(grouped_cases))

for str_len, group in grouped_cases:
self.def_write_ind('case %s: {\n' % str_len)
if_num = 0
for _, case_str, block in group:
self.indent += 1

# TODO:
# - Every element must be a constant string
# - group by length
else_str = '' if if_num == 0 else 'else '
self.def_write_ind('%sif (str_equals_c(%s, %s, %d)) ' %
(else_str, switch_var.name,
PythonStringLiteral(case_str), str_len))
self.accept(block)

self.log('CASES %s', cases)
self.indent -= 1
if_num += 1

self.indent += 1
self.def_write_ind('else {\n')
self.def_write_ind(' goto str_switch_default;\n')
self.def_write_ind('}\n')
self.indent -= 1

self.def_write_ind('}\n')
self.def_write_ind(' break;\n')

self.def_write('\n')
self.def_write_ind('str_switch_default:\n')
self.def_write_ind('default: ')
self.accept(default_block)

self.indent -= 1
self.def_write_ind('}\n')
Expand Down
83 changes: 83 additions & 0 deletions mycpp/examples/invalid_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python2
"""
invalid_except.py
"""
from __future__ import print_function

from mycpp.mylib import switch, str_switch, tagswitch


def NoDefault():
# type: () -> None

with switch(42) as case:
if case(42):
print('42')


def TagSwitch():
# type: () -> None

s = "foo"
with tagswitch(s) as case:
if 42:
print('ONE')
print('dupe')

elif 43:
print('TWO')

else:
print('neither')


def SwitchMustHaveCase():
# type: () -> None

i = 49
with switch(i) as case:
if 42:
print('ONE')
print('dupe')

elif 43:
print('TWO')

else:
print('neither')


def StrSwitchNoTuple():
# type: () -> None

s = "foo"
with str_switch(s) as case:
# Problem: if you switch on length, do you duplicate the bogies
if case('spam', 'different len'):
print('ONE')
print('dupe')

elif case('foo'):
print('TWO')

else:
print('neither')


def StrSwitchNoInt():
# type: () -> None

s = "foo"
with str_switch(s) as case:
# integer not allowed
if case(42):
print('ONE')
print('dupe')

else:
print('neither')


def run_tests():
# type: () -> None
pass
39 changes: 27 additions & 12 deletions mycpp/examples/test_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,38 @@
import os

from mycpp.mylib import switch, str_switch, log
#from mycpp.mylib import switch, log


def TestString(s):
# type: (str) -> None

print('''
#print('''
with str_switch(s) as case:
# Problem: if you switch on length, do you duplicate the bogies
if case('spam', 'different len'):
print('ONE')
print('dupe')
if case('spam'):
print('== %s ==' % s)
print('SPAM')
print('yes')

elif case('foo'):
print('TWO')
print('== %s ==' % s)
print('FOO')

elif case('bar'): # same length
print('== %s ==' % s)
print('BAR')

else:
print('== %s ==' % s)
print('neither')
''')
#''')
print('--')
print('')


def run_tests():
def TestNumSwitch():
# type: () -> None

TestString('spam')
TestString('foo')
TestString('zzz')

x = 5
with switch(x) as case:
if case(0):
Expand All @@ -52,6 +56,17 @@ def run_tests():
print('another')


def run_tests():
# type: () -> None

TestString('spam')
TestString('bar')
TestString('zzz') # same length as bar
TestString('different len')

TestNumSwitch()


def run_benchmarks():
# type: () -> None
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion mycpp/gc_builtins.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ bool keys_equal(Tuple2<BigStr*, int>* t1, Tuple2<BigStr*, int>* t2) {
return are_equal(t1, t2);
}

bool str_equals_c(const char* c_string, int c_len, BigStr* s) {
bool str_equals_c(BigStr* s, const char* c_string, int c_len) {
// Needs SmallStr change
if (len(s) == c_len) {
return memcmp(s->data_, c_string, c_len) == 0;
Expand Down
Loading

0 comments on commit 6efd0db

Please sign in to comment.