Skip to content

Commit

Permalink
Added argument splatting support (#642)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuhiko authored Nov 10, 2024
1 parent 92852d4 commit 3885d10
Show file tree
Hide file tree
Showing 20 changed files with 492 additions and 259 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ All notable changes to MiniJinja are documented here.
- `minijinja-cli` now does not convert INI files to lowercase anymore. This was
an unintended behavior. #633
- Moved up MSRV to 1.63.0 due to indexmap. #635
- Added argument splatting support (`*args` for variable args and `**kwargs`
for keyword arguments) and fixed a bug where sometimes maps and keyword
arguments were created in inverse order. #642

## 2.4.0

Expand Down
43 changes: 13 additions & 30 deletions minijinja/src/compiler/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ pub enum Expr<'a> {
Call(Spanned<Call<'a>>),
List(Spanned<List<'a>>),
Map(Spanned<Map<'a>>),
Kwargs(Spanned<Kwargs<'a>>),
}

#[cfg(feature = "internal_debug")]
Expand All @@ -165,7 +164,6 @@ impl<'a> fmt::Debug for Expr<'a> {
Expr::Call(s) => fmt::Debug::fmt(s, f),
Expr::List(s) => fmt::Debug::fmt(s, f),
Expr::Map(s) => fmt::Debug::fmt(s, f),
Expr::Kwargs(s) => fmt::Debug::fmt(s, f),
}
}
}
Expand All @@ -186,7 +184,6 @@ impl<'a> Expr<'a> {
Expr::Map(_) => "map literal",
Expr::Test(_) => "test expression",
Expr::Filter(_) => "filter expression",
Expr::Kwargs(_) => "keyword arguments",
}
}
}
Expand Down Expand Up @@ -444,7 +441,7 @@ pub struct IfExpr<'a> {
pub struct Filter<'a> {
pub name: &'a str,
pub expr: Option<Expr<'a>>,
pub args: Vec<Expr<'a>>,
pub args: Vec<CallArg<'a>>,
}

/// A test expression.
Expand All @@ -453,7 +450,7 @@ pub struct Filter<'a> {
pub struct Test<'a> {
pub name: &'a str,
pub expr: Expr<'a>,
pub args: Vec<Expr<'a>>,
pub args: Vec<CallArg<'a>>,
}

/// An attribute lookup expression.
Expand All @@ -477,7 +474,17 @@ pub struct GetItem<'a> {
#[cfg_attr(feature = "unstable_machinery_serde", derive(serde::Serialize))]
pub struct Call<'a> {
pub expr: Expr<'a>,
pub args: Vec<Expr<'a>>,
pub args: Vec<CallArg<'a>>,
}

/// A call argument helper
#[cfg_attr(feature = "internal_debug", derive(Debug))]
#[cfg_attr(feature = "unstable_machinery_serde", derive(serde::Serialize))]
pub enum CallArg<'a> {
Pos(Expr<'a>),
Kwarg(&'a str, Expr<'a>),
PosSplat(Expr<'a>),
KwargSplat(Expr<'a>),
}

/// Creates a list of values.
Expand All @@ -503,30 +510,6 @@ impl<'a> List<'a> {
}
}

/// Creates a map of kwargs
#[cfg_attr(feature = "internal_debug", derive(Debug))]
#[cfg_attr(feature = "unstable_machinery_serde", derive(serde::Serialize))]
pub struct Kwargs<'a> {
pub pairs: Vec<(&'a str, Expr<'a>)>,
}

impl<'a> Kwargs<'a> {
pub fn as_const(&self) -> Option<Value> {
if !self.pairs.iter().all(|x| matches!(x.1, Expr::Const(_))) {
return None;
}

let mut rv = value_map_with_capacity(self.pairs.len());
for (key, value) in &self.pairs {
if let Expr::Const(value) = value {
rv.insert(Value::from(*key), value.value.clone());
}
}

Some(crate::value::Kwargs::wrap(rv))
}
}

/// Creates a map of values.
#[cfg_attr(feature = "internal_debug", derive(Debug))]
#[cfg_attr(feature = "unstable_machinery_serde", derive(serde::Serialize))]
Expand Down
173 changes: 106 additions & 67 deletions minijinja/src/compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::compiler::instructions::{
use crate::compiler::tokens::Span;
use crate::output::CaptureMode;
use crate::value::ops::neg;
use crate::value::Value;
use crate::value::{Kwargs, Value, ValueMap};

#[cfg(test)]
use similar_asserts::assert_eq;
Expand Down Expand Up @@ -503,7 +503,7 @@ impl<'source> CodeGenerator<'source> {
self.add_with_span(Instruction::FastSuper, call.span());
return;
} else if name == "loop" && call.args.len() == 1 {
self.compile_expr(&call.args[0]);
self.compile_call_args(std::slice::from_ref(&call.args[0]), 0, None);
self.add(Instruction::FastRecurse);
return;
}
Expand Down Expand Up @@ -660,21 +660,17 @@ impl<'source> CodeGenerator<'source> {
if let Some(ref expr) = f.expr {
self.compile_expr(expr);
}
for arg in &f.args {
self.compile_expr(arg);
}
let arg_count = self.compile_call_args(&f.args, 1, None);
let local_id = get_local_id(&mut self.filter_local_ids, f.name);
self.add(Instruction::ApplyFilter(f.name, f.args.len() + 1, local_id));
self.add(Instruction::ApplyFilter(f.name, arg_count, local_id));
self.pop_span();
}
ast::Expr::Test(f) => {
self.push_span(f.span());
self.compile_expr(&f.expr);
for arg in &f.args {
self.compile_expr(arg);
}
let arg_count = self.compile_call_args(&f.args, 1, None);
let local_id = get_local_id(&mut self.test_local_ids, f.name);
self.add(Instruction::PerformTest(f.name, f.args.len() + 1, local_id));
self.add(Instruction::PerformTest(f.name, arg_count, local_id));
self.pop_span();
}
ast::Expr::GetAttr(g) => {
Expand Down Expand Up @@ -717,18 +713,6 @@ impl<'source> CodeGenerator<'source> {
self.add(Instruction::BuildMap(m.keys.len()));
}
}
ast::Expr::Kwargs(m) => {
if let Some(val) = m.as_const() {
self.add(Instruction::LoadConst(val));
} else {
self.set_line_from_span(m.span());
for (key, value) in &m.pairs {
self.add(Instruction::LoadConst(Value::from(*key)));
self.compile_expr(value);
}
self.add(Instruction::BuildKwargs(m.pairs.len()));
}
}
}
}

Expand All @@ -740,7 +724,7 @@ impl<'source> CodeGenerator<'source> {
self.push_span(c.span());
match c.identify_call() {
ast::CallType::Function(name) => {
let arg_count = self.compile_call_args(&c.args, caller);
let arg_count = self.compile_call_args(&c.args, 0, caller);
self.add(Instruction::CallFunction(name, arg_count));
}
#[cfg(feature = "multi_template")]
Expand All @@ -751,71 +735,126 @@ impl<'source> CodeGenerator<'source> {
}
ast::CallType::Method(expr, name) => {
self.compile_expr(expr);
let arg_count = self.compile_call_args(&c.args, caller);
self.add(Instruction::CallMethod(name, arg_count + 1));
let arg_count = self.compile_call_args(&c.args, 1, caller);
self.add(Instruction::CallMethod(name, arg_count));
}
ast::CallType::Object(expr) => {
self.compile_expr(expr);
let arg_count = self.compile_call_args(&c.args, caller);
self.add(Instruction::CallObject(arg_count + 1));
let arg_count = self.compile_call_args(&c.args, 1, caller);
self.add(Instruction::CallObject(arg_count));
}
};
self.pop_span();
}

fn compile_call_args(
&mut self,
args: &[ast::Expr<'source>],
args: &[ast::CallArg<'source>],
extra_args: usize,
caller: Option<&Caller<'source>>,
) -> usize {
match caller {
// we can conditionally compile the caller part here since this will
// nicely call through for non macro builds
#[cfg(feature = "macros")]
Some(caller) => self.compile_call_args_with_caller(args, caller),
_ => {
for arg in args {
self.compile_expr(arg);
) -> Option<u16> {
let mut pending_args = extra_args;
let mut num_args_batches = 0;
let mut has_kwargs = caller.is_some();
let mut static_kwargs = caller.is_none();

for arg in args {
match arg {
ast::CallArg::Pos(expr) => {
self.compile_expr(expr);
pending_args += 1;
}
ast::CallArg::PosSplat(expr) => {
if pending_args > 0 {
self.add(Instruction::BuildList(Some(pending_args)));
pending_args = 0;
num_args_batches += 1;
}
self.compile_expr(expr);
num_args_batches += 1;
}
ast::CallArg::Kwarg(_, expr) => {
if !matches!(expr, ast::Expr::Const(_)) {
static_kwargs = false;
}
has_kwargs = true;
}
ast::CallArg::KwargSplat(_) => {
static_kwargs = false;
has_kwargs = true;
}
args.len()
}
}
}

#[cfg(feature = "macros")]
fn compile_call_args_with_caller(
&mut self,
args: &[ast::Expr<'source>],
caller: &Caller<'source>,
) -> usize {
let mut injected_caller = false;
if has_kwargs {
let mut pending_kwargs = 0;
let mut num_kwargs_batches = 0;
let mut collected_kwargs = ValueMap::new();
for arg in args {
match arg {
ast::CallArg::Kwarg(key, value) => {
if static_kwargs {
if let ast::Expr::Const(c) = value {
collected_kwargs.insert(Value::from(*key), c.value.clone());
} else {
unreachable!();
}
} else {
self.add(Instruction::LoadConst(Value::from(*key)));
self.compile_expr(value);
pending_kwargs += 1;
}
}
ast::CallArg::KwargSplat(expr) => {
if pending_kwargs > 0 {
self.add(Instruction::BuildKwargs(pending_kwargs));
num_kwargs_batches += 1;
pending_kwargs = 0;
}
self.compile_expr(expr);
num_kwargs_batches += 1;
}
ast::CallArg::Pos(_) | ast::CallArg::PosSplat(_) => {}
}
}

// try to add the caller to already existing keyword arguments.
for arg in args {
if let ast::Expr::Kwargs(ref m) = arg {
self.set_line_from_span(m.span());
for (key, value) in &m.pairs {
self.add(Instruction::LoadConst(Value::from(*key)));
self.compile_expr(value);
}
self.add(Instruction::LoadConst(Value::from("caller")));
self.compile_macro_expression(caller);
self.add(Instruction::BuildKwargs(m.pairs.len() + 1));
injected_caller = true;
if !collected_kwargs.is_empty() {
self.add(Instruction::LoadConst(Kwargs::wrap(collected_kwargs)));
} else {
self.compile_expr(arg);
// The conditions above guarantee that if we collect static kwargs
// we cannot enter this block (single kwargs batch, no caller).

#[cfg(feature = "macros")]
{
if let Some(caller) = caller {
self.add(Instruction::LoadConst(Value::from("caller")));
self.compile_macro_expression(caller);
pending_kwargs += 1
}
}
if num_kwargs_batches > 0 {
if pending_kwargs > 0 {
self.add(Instruction::BuildKwargs(pending_kwargs));
num_kwargs_batches += 1;
}
self.add(Instruction::MergeKwargs(num_kwargs_batches));
} else {
self.add(Instruction::BuildKwargs(pending_kwargs));
}
}
pending_args += 1;
}

// if there are no keyword args so far, create a new kwargs object
// and add caller to that.
if !injected_caller {
self.add(Instruction::LoadConst(Value::from("caller")));
self.compile_macro_expression(caller);
self.add(Instruction::BuildKwargs(1));
args.len() + 1
if num_args_batches > 0 {
if pending_args > 0 {
self.add(Instruction::BuildList(Some(pending_args)));
num_args_batches += 1;
}
self.add(Instruction::UnpackLists(num_args_batches));
None
} else {
args.len()
assert!(pending_args as u16 as usize == pending_args);
Some(pending_args as u16)
}
}

Expand Down
16 changes: 11 additions & 5 deletions minijinja/src/compiler/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ pub enum Instruction<'source> {
/// Builds a kwargs map of the last n pairs on the stack.
BuildKwargs(usize),

/// Merges N kwargs maps on the list into one.
MergeKwargs(usize),

/// Builds a list of the last n pairs on the stack.
BuildList(Option<usize>),

/// Unpacks a list into N stack items.
UnpackList(usize),

/// Unpacks N lists onto the stack and pushes the number of items there were unpacked.
UnpackLists(usize),

/// Add the top two values
Add,

Expand Down Expand Up @@ -122,10 +128,10 @@ pub enum Instruction<'source> {
In,

/// Apply a filter.
ApplyFilter(&'source str, usize, LocalId),
ApplyFilter(&'source str, Option<u16>, LocalId),

/// Perform a filter.
PerformTest(&'source str, usize, LocalId),
PerformTest(&'source str, Option<u16>, LocalId),

/// Emit the stack top as output
Emit,
Expand Down Expand Up @@ -175,13 +181,13 @@ pub enum Instruction<'source> {
EndCapture,

/// Calls a global function
CallFunction(&'source str, usize),
CallFunction(&'source str, Option<u16>),

/// Calls a method
CallMethod(&'source str, usize),
CallMethod(&'source str, Option<u16>),

/// Calls an object
CallObject(usize),
CallObject(Option<u16>),

/// Duplicates the top item
DupTop,
Expand Down
Loading

0 comments on commit 3885d10

Please sign in to comment.