Skip to content

Commit

Permalink
Support Python 3's keyword-only arguments.
Browse files Browse the repository at this point in the history
Previously, they would parse correctly in Python 3, but any keyword-only
arguments would be quietly lost, and the user would either get
`TypeError: foo() got an unexpected keyword argument...` or the
confusing behavior of having the keyword argument overwritten with
whatever's in the context with the same name.
  • Loading branch information
eevee committed Feb 11, 2014
1 parent a574007 commit 836e5f9
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 37 deletions.
77 changes: 52 additions & 25 deletions mako/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,38 +112,65 @@ def __init__(self, code, allow_kwargs=True, **exception_kwargs):
if not allow_kwargs and self.kwargs:
raise exceptions.CompileException(
"'**%s' keyword argument not allowed here" %
self.argnames[-1], **exception_kwargs)
self.kwargnames[-1], **exception_kwargs)

def get_argument_expressions(self, include_defaults=True):
"""return the argument declarations of this FunctionDecl as a printable
list."""
def get_argument_expressions(self, as_call=False):
"""Return the argument declarations of this FunctionDecl as a printable
list.
By default the return value is appropriate for writing in a ``def``;
set `as_call` to true to build arguments to be passed to the function
instead (assuming locals with the same names as the arguments exist).
"""

namedecls = []
defaults = [d for d in self.defaults]
kwargs = self.kwargs
varargs = self.varargs
argnames = [f for f in self.argnames]
argnames.reverse()
for arg in argnames:
default = None
if kwargs:
arg = "**" + arg_stringname(arg)
kwargs = False
elif varargs:
arg = "*" + arg_stringname(arg)
varargs = False

# Build in reverse order, since defaults and slurpy args come last
argnames = self.argnames[::-1]
kwargnames = self.kwargnames[::-1]
defaults = self.defaults[::-1]
kwdefaults = self.kwdefaults[::-1]

# Named arguments
if self.kwargs:
namedecls.append("**" + kwargnames.pop(0))

for name in kwargnames:
# Keyword-only arguments must always be used by name, so even if
# this is a call, print out `foo=foo`
if as_call:
namedecls.append("%s=%s" % (name, name))
elif kwdefaults:
default = kwdefaults.pop(0)
if default is None:
# The AST always gives kwargs a default, since you can do
# `def foo(*, a=1, b, c=3)`
namedecls.append(name)
else:
namedecls.append("%s=%s" % (
name, pyparser.ExpressionGenerator(default).value()))
else:
default = len(defaults) and defaults.pop() or None
if include_defaults and default:
namedecls.insert(0, "%s=%s" %
(arg,
pyparser.ExpressionGenerator(default).value()
)
)
namedecls.append(name)

# Positional arguments
if self.varargs:
namedecls.append("*" + argnames.pop(0))

for name in argnames:
if as_call or not defaults:
namedecls.append(name)
else:
namedecls.insert(0, arg)
default = defaults.pop(0)
namedecls.append("%s=%s" % (
name, pyparser.ExpressionGenerator(default).value()))

namedecls.reverse()
return namedecls

@property
def allargnames(self):
return self.argnames + self.kwargnames

class FunctionArgs(FunctionDecl):
"""the argument portion of a function declaration"""

Expand Down
4 changes: 2 additions & 2 deletions mako/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def write_def_decl(self, node, identifiers):
"""write a locally-available callable referencing a top-level def"""
funcname = node.funcname
namedecls = node.get_argument_expressions()
nameargs = node.get_argument_expressions(include_defaults=False)
nameargs = node.get_argument_expressions(as_call=True)

if not self.in_def and (
len(self.identifiers.locally_assigned) > 0 or
Expand Down Expand Up @@ -864,7 +864,7 @@ def visitBlockTag(self, node):
if node.is_anonymous:
self.printer.writeline("%s()" % node.funcname)
else:
nameargs = node.get_argument_expressions(include_defaults=False)
nameargs = node.get_argument_expressions(as_call=True)
nameargs += ['**pageargs']
self.printer.writeline("if 'parent' not in context._data or "
"not hasattr(context._data['parent'], '%s'):"
Expand Down
1 change: 1 addition & 0 deletions mako/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

py3k = sys.version_info >= (3, 0)
py33 = sys.version_info >= (3, 3)
py2k = sys.version_info < (3,)
py26 = sys.version_info >= (2, 6)
py25 = sys.version_info >= (2, 5)
jython = sys.platform.startswith('java')
Expand Down
12 changes: 6 additions & 6 deletions mako/parsetree.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def get_argument_expressions(self, **kw):
return self.function_decl.get_argument_expressions(**kw)

def declared_identifiers(self):
return self.function_decl.argnames
return self.function_decl.allargnames

def undeclared_identifiers(self):
res = []
Expand All @@ -451,7 +451,7 @@ def undeclared_identifiers(self):
).union(
self.expression_undeclared_identifiers
).difference(
self.function_decl.argnames
self.function_decl.allargnames
)

class BlockTag(Tag):
Expand Down Expand Up @@ -502,7 +502,7 @@ def get_argument_expressions(self, **kw):
return self.body_decl.get_argument_expressions(**kw)

def declared_identifiers(self):
return self.body_decl.argnames
return self.body_decl.allargnames

def undeclared_identifiers(self):
return (self.filter_args.\
Expand All @@ -524,7 +524,7 @@ def __init__(self, keyword, attributes, **kwargs):
**self.exception_kwargs)

def declared_identifiers(self):
return self.code.declared_identifiers.union(self.body_decl.argnames)
return self.code.declared_identifiers.union(self.body_decl.allargnames)

def undeclared_identifiers(self):
return self.code.undeclared_identifiers.\
Expand Down Expand Up @@ -554,7 +554,7 @@ def __init__(self, namespace, defname, attributes, **kwargs):
**self.exception_kwargs)

def declared_identifiers(self):
return self.code.declared_identifiers.union(self.body_decl.argnames)
return self.code.declared_identifiers.union(self.body_decl.allargnames)

def undeclared_identifiers(self):
return self.code.undeclared_identifiers.\
Expand Down Expand Up @@ -589,6 +589,6 @@ def __init__(self, keyword, attributes, **kwargs):
**self.exception_kwargs)

def declared_identifiers(self):
return self.body_decl.argnames
return self.body_decl.allargnames


14 changes: 13 additions & 1 deletion mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,25 @@ def __init__(self, listener, **exception_kwargs):

def visit_FunctionDef(self, node):
self.listener.funcname = node.name

argnames = [arg_id(arg) for arg in node.args.args]
if node.args.vararg:
argnames.append(arg_stringname(node.args.vararg))

if compat.py2k:
# kw-only args don't exist in Python 2
kwargnames = []
else:
kwargnames = [arg_id(arg) for arg in node.args.kwonlyargs]
if node.args.kwarg:
argnames.append(arg_stringname(node.args.kwarg))
kwargnames.append(arg_stringname(node.args.kwarg))
self.listener.argnames = argnames
self.listener.defaults = node.args.defaults # ast
self.listener.kwargnames = kwargnames
if compat.py2k:
self.listener.kwdefaults = []
else:
self.listener.kwdefaults = node.args.kw_defaults
self.listener.varargs = node.args.vararg
self.listener.kwargs = node.args.kwarg

Expand Down
3 changes: 3 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def maybe(*args, **kw):
return function_named(maybe, fn_name)
return decorate

def requires_python_3(fn):
return skip_if(lambda: not py3k, "Requires Python 3.xx")(fn)

def requires_python_2(fn):
return skip_if(lambda: py3k, "Requires Python 2.xx")(fn)

Expand Down
19 changes: 17 additions & 2 deletions test/test_ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from mako import ast, exceptions, pyparser, util, compat
from test import eq_, requires_python_2
from test import eq_, requires_python_2, requires_python_3

exception_kwargs = {
'source': '',
Expand Down Expand Up @@ -263,14 +263,29 @@ def test_function_decl(self):
eq_(parsed.funcname, 'foo')
eq_(parsed.argnames,
['a', 'b', 'c', 'd', 'e', 'f'])
eq_(parsed.kwargnames,
[])

def test_function_decl_2(self):
"""test getting the arguments from a function"""
code = "def foo(a, b, c=None, *args, **kwargs):pass"
parsed = ast.FunctionDecl(code, **exception_kwargs)
eq_(parsed.funcname, 'foo')
eq_(parsed.argnames,
['a', 'b', 'c', 'args', 'kwargs'])
['a', 'b', 'c', 'args'])
eq_(parsed.kwargnames,
['kwargs'])

@requires_python_3
def test_function_decl_3(self):
"""test getting the arguments from a fancy py3k function"""
code = "def foo(a, b, *c, d, e, **f):pass"
parsed = ast.FunctionDecl(code, **exception_kwargs)
eq_(parsed.funcname, 'foo')
eq_(parsed.argnames,
['a', 'b', 'c'])
eq_(parsed.kwargnames,
['d', 'e', 'f'])

def test_expr_generate(self):
"""test the round trip of expressions to AST back to python source"""
Expand Down
15 changes: 14 additions & 1 deletion test/test_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from mako import lookup
from test import TemplateTest
from test.util import flatten_result, result_lines
from test import eq_, assert_raises
from test import eq_, assert_raises, requires_python_3
from mako import compat

class DefTest(TemplateTest):
Expand Down Expand Up @@ -45,6 +45,19 @@ def test_def_args(self):
"""hello mycomp hi, 5, 6"""
)

@requires_python_3
def test_def_py3k_args(self):
template = Template("""
<%def name="kwonly(one, two, *three, four, five=5, **six)">
look at all these args: ${one} ${two} ${three[0]} ${four} ${five} ${six['seven']}
</%def>
${kwonly('one', 'two', 'three', four='four', seven='seven')}""")
eq_(
template.render(one=1, two=2, three=(3,), six=6).strip(),
"""look at all these args: one two three four 5 seven"""
)

def test_inter_def(self):
"""test defs calling each other"""
template = Template("""
Expand Down

0 comments on commit 836e5f9

Please sign in to comment.