diff --git a/mycpp/TEST.sh b/mycpp/TEST.sh index d271a54867..f2e645fdb4 100755 --- a/mycpp/TEST.sh +++ b/mycpp/TEST.sh @@ -232,6 +232,9 @@ test-invalid-examples() { */invalid_global.py) expected_status=2 ;; + */invalid_switch.py) + expected_status=5 + ;; esac if test $status -ne $expected_status; then diff --git a/mycpp/cppgen_pass.py b/mycpp/cppgen_pass.py index 50b9735492..6474d79788 100644 --- a/mycpp/cppgen_pass.py +++ b/mycpp/cppgen_pass.py @@ -2,6 +2,7 @@ cppgen.py - AST pass to that prints C++ code """ import io +import itertools import json # for "C escaping" from typing import overload, Union, Optional, Dict @@ -385,8 +386,8 @@ def GetCReturnType(t) -> Tuple[str, bool, Optional[str]]: def PythonStringLiteral(s: str) -> str: """ - Returns a properly quoted string. - """ + Returns a properly quoted string. + """ # MyPy does bad escaping. Decode and push through json to get something # workable in C++. return json.dumps(format_strings.DecodeMyPyString(s)) @@ -1912,6 +1913,13 @@ def _collect_cases(self, if_node, out): """ The MyPy AST has a recursive structure for if-elif-elif rather than a flat one. It's a bit confusing. + + Appends (expr, block) cases to out param, and returns the default + block, which has no expression. + + default block may be None. + + Returns False if there is no default block. """ assert isinstance(if_node, IfStmt), if_node assert len(if_node.expr) == 1, if_node.expr @@ -1920,6 +1928,11 @@ def _collect_cases(self, if_node, out): expr = if_node.expr[0] body = if_node.body[0] + if not isinstance(expr, CallExpr): + self.report_error(expr, + 'Expected call like case(x), got %s' % expr) + return + out.append((expr, body)) if if_node.else_body: @@ -1930,30 +1943,44 @@ def _collect_cases(self, if_node, out): # if 0: if isinstance(first_of_block, IfStmt): - self._collect_cases(first_of_block, out) + return self._collect_cases(first_of_block, out) else: # default case - no expression - out.append((None, if_node.else_body)) + return if_node.else_body + + return False # NO DEFAULT BLOCK - Different than None - def _write_cases(self, cases): + def _write_cases(self, switch_expr, cases, default_block): """ Write a list of (expr, block) pairs """ for expr, body in cases: - if expr is not None: - assert isinstance(expr, CallExpr), expr - for i, arg in enumerate(expr.args): - if i != 0: - self.def_write('\n') - self.def_write_ind('case ') - self.accept(arg) - self.def_write(': ') + assert expr is not None, expr + if not isinstance(expr, CallExpr): + self.report_error(expr, + 'Expected call like case(x), got %s' % expr) + return - self.accept(body) - self.def_write_ind(' break;\n') - else: - self.def_write_ind('default: ') - self.accept(body) # the whole block - # don't write 'break' + for i, arg in enumerate(expr.args): + if i != 0: + self.def_write('\n') + self.def_write_ind('case ') + self.accept(arg) + self.def_write(': ') + + self.accept(body) + self.def_write_ind(' break;\n') + + if default_block is None: + # an error occurred + return + if default_block is False: + self.report_error(switch_expr, + 'switch got no else: for default block') + return + + self.def_write_ind('default: ') + self.accept(default_block) + # don't write 'break' def _write_switch(self, expr, o): """Write a switch statement over integers.""" @@ -1969,8 +1996,8 @@ def _write_switch(self, expr, o): self.indent += 1 cases = [] - self._collect_cases(if_node, cases) - self._write_cases(cases) + default_block = self._collect_cases(if_node, cases) + self._write_cases(expr, cases, default_block) self.indent -= 1 self.def_write_ind('}\n') @@ -1989,20 +2016,57 @@ def _write_tag_switch(self, expr, o): self.indent += 1 cases = [] - self._collect_cases(if_node, cases) - self._write_cases(cases) + default_block = self._collect_cases(if_node, cases) + self._write_cases(expr, cases, default_block) self.indent -= 1 self.def_write_ind('}\n') + def _str_switch_cases(self, cases): + cases2 = [] + for expr, body in cases: + if not isinstance(expr, CallExpr): + # non-fatal check from _collect_cases + break + + args = expr.args + if len(args) != 1: + self.report_error( + expr, + 'str_switch can only have case("x"), not case("x", "y")' % + args) + break + + if not isinstance(args[0], StrExpr): + self.report_error( + expr, + 'str_switch can only be used with constant strings, got %s' + % args[0]) + break + + s = args[0].value + cases2.append((len(s), s, body)) + + # Sort by string length + cases2.sort(key=lambda pair: pair[0]) + grouped = itertools.groupby(cases2, key=lambda pair: pair[0]) + return grouped + def _write_str_switch(self, expr, o): """Write a switch statement over strings.""" assert len(expr.args) == 1, expr.args - self.def_write_ind('switch (len(') - self.accept(expr.args[0]) - self.def_write(')) {\n') + switch_var = expr.args[0] + if not isinstance(switch_var, NameExpr): + self.report_error( + expr.args[0], + 'str_switch(x) accepts only a variable name, got %s' % + switch_var) + return + self.def_write_ind('switch (len(%s)) {\n' % switch_var.name) + + # There can only be one thing under 'with str_switch' assert len(o.body.body) == 1, o.body.body if_node = o.body.body[0] assert isinstance(if_node, IfStmt), if_node @@ -2010,13 +2074,40 @@ def _write_str_switch(self, expr, o): self.indent += 1 cases = [] - self._collect_cases(if_node, cases) + default_block = self._collect_cases(if_node, cases) + + grouped_cases = self._str_switch_cases(cases) + # Warning: this consumes internal iterator + #self.log('grouped %s', list(grouped_cases)) + + for str_len, group in grouped_cases: + self.def_write_ind('case %s: {\n' % str_len) + if_num = 0 + for _, case_str, block in group: + self.indent += 1 - # TODO: - # - Every element must be a constant string - # - group by length + else_str = '' if if_num == 0 else 'else ' + self.def_write_ind('%sif (str_equals_c(%s, %s, %d)) ' % + (else_str, switch_var.name, + PythonStringLiteral(case_str), str_len)) + self.accept(block) - self.log('CASES %s', cases) + self.indent -= 1 + if_num += 1 + + self.indent += 1 + self.def_write_ind('else {\n') + self.def_write_ind(' goto str_switch_default;\n') + self.def_write_ind('}\n') + self.indent -= 1 + + self.def_write_ind('}\n') + self.def_write_ind(' break;\n') + + self.def_write('\n') + self.def_write_ind('str_switch_default:\n') + self.def_write_ind('default: ') + self.accept(default_block) self.indent -= 1 self.def_write_ind('}\n') diff --git a/mycpp/examples/invalid_switch.py b/mycpp/examples/invalid_switch.py new file mode 100644 index 0000000000..fa73dd150c --- /dev/null +++ b/mycpp/examples/invalid_switch.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python2 +""" +invalid_except.py +""" +from __future__ import print_function + +from mycpp.mylib import switch, str_switch, tagswitch + + +def NoDefault(): + # type: () -> None + + with switch(42) as case: + if case(42): + print('42') + + +def TagSwitch(): + # type: () -> None + + s = "foo" + with tagswitch(s) as case: + if 42: + print('ONE') + print('dupe') + + elif 43: + print('TWO') + + else: + print('neither') + + +def SwitchMustHaveCase(): + # type: () -> None + + i = 49 + with switch(i) as case: + if 42: + print('ONE') + print('dupe') + + elif 43: + print('TWO') + + else: + print('neither') + + +def StrSwitchNoTuple(): + # type: () -> None + + s = "foo" + with str_switch(s) as case: + # Problem: if you switch on length, do you duplicate the bogies + if case('spam', 'different len'): + print('ONE') + print('dupe') + + elif case('foo'): + print('TWO') + + else: + print('neither') + + +def StrSwitchNoInt(): + # type: () -> None + + s = "foo" + with str_switch(s) as case: + # integer not allowed + if case(42): + print('ONE') + print('dupe') + + else: + print('neither') + + +def run_tests(): + # type: () -> None + pass diff --git a/mycpp/examples/test_switch.py b/mycpp/examples/test_switch.py index 86952f46a2..89e1cd29e0 100755 --- a/mycpp/examples/test_switch.py +++ b/mycpp/examples/test_switch.py @@ -7,34 +7,38 @@ import os from mycpp.mylib import switch, str_switch, log -#from mycpp.mylib import switch, log def TestString(s): # type: (str) -> None - print(''' + #print(''' with str_switch(s) as case: # Problem: if you switch on length, do you duplicate the bogies - if case('spam', 'different len'): - print('ONE') - print('dupe') + if case('spam'): + print('== %s ==' % s) + print('SPAM') + print('yes') elif case('foo'): - print('TWO') + print('== %s ==' % s) + print('FOO') + + elif case('bar'): # same length + print('== %s ==' % s) + print('BAR') else: + print('== %s ==' % s) print('neither') - ''') + #''') + print('--') + print('') -def run_tests(): +def TestNumSwitch(): # type: () -> None - TestString('spam') - TestString('foo') - TestString('zzz') - x = 5 with switch(x) as case: if case(0): @@ -52,6 +56,17 @@ def run_tests(): print('another') +def run_tests(): + # type: () -> None + + TestString('spam') + TestString('bar') + TestString('zzz') # same length as bar + TestString('different len') + + TestNumSwitch() + + def run_benchmarks(): # type: () -> None raise NotImplementedError() diff --git a/mycpp/gc_builtins.cc b/mycpp/gc_builtins.cc index 769c113a03..fc23c523e4 100644 --- a/mycpp/gc_builtins.cc +++ b/mycpp/gc_builtins.cc @@ -419,7 +419,7 @@ bool keys_equal(Tuple2* t1, Tuple2* t2) { return are_equal(t1, t2); } -bool str_equals_c(const char* c_string, int c_len, BigStr* s) { +bool str_equals_c(BigStr* s, const char* c_string, int c_len) { // Needs SmallStr change if (len(s) == c_len) { return memcmp(s->data_, c_string, c_len) == 0; diff --git a/mycpp/gc_builtins.h b/mycpp/gc_builtins.h index e2e66a33f6..de758db84a 100644 --- a/mycpp/gc_builtins.h +++ b/mycpp/gc_builtins.h @@ -155,8 +155,8 @@ inline bool to_bool(int i) { bool str_contains(BigStr* haystack, BigStr* needle); -// Used by 'with switch(s)' -bool str_equals_c(const char* c_string, int c_len, BigStr* s); +// Used by 'with str_switch(s)' +bool str_equals_c(BigStr* s, const char* c_string, int c_len); // Only used by unit tests bool str_equals0(const char* c_string, BigStr* s);