Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ jobs:
Release,
Debug,
]
sanitizer: [
none,
address,
thread,
undefined,
]
runs-on: ${{ matrix.setup.os }}
name: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}
name: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}-sanitizer-${{ matrix.sanitizer }}
timeout-minutes: 30

steps:
Expand All @@ -58,7 +64,7 @@ jobs:
- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.11
with:
key: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}
key: ${{ matrix.setup.os }}-${{ matrix.setup.build }}-${{ matrix.type }}-sanitizer-${{ matrix.sanitizer }}

- name: Set up CMake
uses: lukka/get-cmake@latest
Expand All @@ -75,7 +81,7 @@ jobs:
- name: Configure CMake
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: cmake -B ${{github.workspace}}/build ${{ matrix.setup.defines }} -DCMAKE_BUILD_TYPE=${{ matrix.type }}
run: cmake -B ${{github.workspace}}/build ${{ matrix.setup.defines }} -DCMAKE_BUILD_TYPE=${{ matrix.type }} -DMINJA_SANITIZER=${{ matrix.sanitizer }}

- name: Build
run: cmake --build ${{github.workspace}}/build --config ${{ matrix.type }} --parallel
Expand Down
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ option(MINJA_EXAMPLE_ENABLED "minja: Build with example"
option(MINJA_FUZZTEST_ENABLED "minja: fuzztests enabled" MINJA_FUZZTEST_ENABLED_DEFAULT)
option(MINJA_FUZZTEST_FUZZING_MODE "minja: run fuzztests (if enabled) in fuzzing mode" OFF)
option(MINJA_USE_VENV "minja: use Python venv for build" MINJA_USE_VENV_DEFAULT)
# Sanitizer selection.
# MINJA_SANITIZER is a cache string defaulting to "none"; MINJA_SANITIZERS lists
# the values offered for it through the cache entry's STRINGS property.
set(MINJA_SANITIZERS thread address undefined none)
set(MINJA_SANITIZER none CACHE STRING "minja: sanitizer to use")
set_property(CACHE MINJA_SANITIZER PROPERTY STRINGS ${MINJA_SANITIZERS})

# For non-MSVC toolchains, any value other than the exact string "none" is
# forwarded verbatim as -fsanitize=<value> to both the compile and link steps.
# NOTE: the value is not checked against MINJA_SANITIZERS here, and STREQUAL is
# case-sensitive, so callers must pass the lowercase names listed above.
if (NOT MSVC AND NOT MINJA_SANITIZER STREQUAL "none")
message(STATUS "Using -fsanitize=${MINJA_SANITIZER}")
add_compile_options("-fsanitize=${MINJA_SANITIZER}")
link_libraries ("-fsanitize=${MINJA_SANITIZER}")
endif()

set(CMAKE_CXX_STANDARD 17)

Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,26 @@ Main limitations (non-exhaustive list):
./scripts/fuzzing_tests.sh
```

- Sanitizer tests:

```bash
# Build and run the test suite once per sanitizer inside a throwaway container.
# Sanitizer names must be lowercase: each one is passed as -DMINJA_SANITIZER=<name>,
# which the CMake build forwards verbatim as -fsanitize=<name> (case-sensitive).
# Each sanitizer gets its own host build directory, mounted at /src/build.
for sanitizer in address thread undefined ; do
  docker run --rm \
    -v "$PWD":/src:ro \
    -v "$PWD/build-sanitizer-${sanitizer}":/src/build \
    -w /src \
    "$(echo "
      FROM ghcr.io/astral-sh/uv:debian-slim
      RUN apt-get update && apt-get install -y build-essential libcurl4-openssl-dev cmake clang-tidy
    " | docker build . -q -f - )" \
    bash -c "
      cmake -B build -DCMAKE_BUILD_TYPE=Debug -DMINJA_SANITIZER=${sanitizer} && \
      cmake --build build -j --config Debug && \
      ctest --test-dir build -j -C Debug --output-on-failure
    "
done
```

- If your model's template doesn't run fine, please consider the following before [opening a bug](https://github.com/googlestaging/minja/issues/new):

- Is the template using any unsupported filter / test / method / global function, and which one(s)?
Expand Down
4 changes: 2 additions & 2 deletions include/minja/chat-template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,12 @@ class chat_template {
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
}), {}, false);
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_str_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':") || contains(out, ">argument_needle<");
out = try_raw_render(json::array({
dummy_user_msg,
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
}), {}, false);
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
auto tool_call_renders_obj_arguments = contains(out, "<parameter=argument_needle>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':") || contains(out, ">argument_needle<");

caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
Expand Down
63 changes: 37 additions & 26 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) {
}

/* Values that behave roughly like in Python. */
class Value : public std::enable_shared_from_this<Value> {
class Value {
public:
using CallableType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
using FilterType = std::function<Value(const std::shared_ptr<Context> &, ArgumentsValue &)>;
Expand Down Expand Up @@ -158,12 +158,14 @@ class Value : public std::enable_shared_from_this<Value> {
Value(const json & v) {
if (v.is_object()) {
auto object = std::make_shared<ObjectType>();
object->reserve(v.size());
for (auto it = v.begin(); it != v.end(); ++it) {
(*object)[it.key()] = it.value();
object->emplace_back(it.key(), Value(it.value()));
}
object_ = std::move(object);
} else if (v.is_array()) {
auto array = std::make_shared<ArrayType>();
array->reserve(v.size());
for (const auto& item : v) {
array->push_back(Value(item));
}
Expand Down Expand Up @@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
return out.str();
}

class Context : public std::enable_shared_from_this<Context> {
class Context {
protected:
Value values_;
std::shared_ptr<Context> parent_;
Expand Down Expand Up @@ -850,12 +852,12 @@ struct LoopControlTemplateToken : public TemplateToken {

struct CallTemplateToken : public TemplateToken {
std::shared_ptr<Expression> expr;
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e)
: TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {}
};

struct EndCallTemplateToken : public TemplateToken {
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post)
: TemplateToken(Type::EndCall, loc, pre, post) {}
};

Expand Down Expand Up @@ -1060,11 +1062,18 @@ class MacroNode : public TemplateNode {
}
}
}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([this, macro_context](const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto execution_context = Context::make(Value::object(), macro_context);

// Use init-capture to avoid dangling 'this' pointer and circular references
auto callable = Value::callable([weak_context = std::weak_ptr<Context>(context),
name = name, params = params, body = body,
named_param_positions = named_param_positions]
(const std::shared_ptr<Context> & call_context, ArgumentsValue & args) {
auto context_locked = weak_context.lock();
if (!context_locked) throw std::runtime_error("Macro context no longer valid");
auto execution_context = Context::make(Value::object(), context_locked);

if (call_context->contains("caller")) {
execution_context->set("caller", call_context->get("caller"));
Expand All @@ -1075,7 +1084,7 @@ class MacroNode : public TemplateNode {
auto & arg = args.args[i];
if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name());
param_set[i] = true;
auto & param_name = params[i].first;
const auto & param_name = params[i].first;
execution_context->set(param_name, arg);
}
for (auto & [arg_name, value] : args.kwargs) {
Expand All @@ -1094,7 +1103,7 @@ class MacroNode : public TemplateNode {
}
return body->render(execution_context);
});
macro_context->set(name->get_name(), callable);
context->set(name->get_name(), callable);
}
};

Expand Down Expand Up @@ -1264,7 +1273,7 @@ class SubscriptExpr : public Expression {
}
return result;

} else if (target_value.is_array()) {
} else if (target_value.is_array()) {
auto result = Value::array();
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
result.push_back(target_value.at(i));
Expand Down Expand Up @@ -1313,7 +1322,7 @@ static bool in(const Value & value, const Value & container) {
return (((container.is_array() || container.is_object()) && container.contains(value)) ||
(value.is_string() && container.is_string() &&
container.to_str().find(value.to_str()) != std::string::npos));
};
}

class BinaryOpExpr : public Expression {
public:
Expand Down Expand Up @@ -1640,13 +1649,17 @@ class CallNode : public TemplateNode {
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("CallNode.expr is null");
if (!body) throw std::runtime_error("CallNode.body is null");

auto caller = Value::callable([this, context](const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
return Value(body->render(context));

// Use init-capture to avoid dangling 'this' pointer and circular references
auto caller = Value::callable([weak_context = std::weak_ptr<Context>(context), body=body]
(const std::shared_ptr<Context> &, ArgumentsValue &) -> Value {
auto context_locked = weak_context.lock();
if (!context_locked) throw std::runtime_error("Caller context no longer valid");
return Value(body->render(context_locked));
});

context->set("caller", caller);

auto call_expr = dynamic_cast<CallExpr*>(expr.get());
if (!call_expr) {
throw std::runtime_error("Invalid call block syntax - expected function call");
Expand All @@ -1657,7 +1670,7 @@ class CallNode : public TemplateNode {
throw std::runtime_error("Call target must be callable: " + function.dump());
}
ArgumentsValue args = call_expr->args.evaluate(context);

Value result = function.call(context, args);
out << result.to_str();
}
Expand Down Expand Up @@ -2192,7 +2205,7 @@ class Parser {

auto value = parseValue();

while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
while (it != end && consumeSpaces() && peekSymbols({ "[", ".", "(" })) {
if (!consumeToken("[").empty()) {
std::shared_ptr<Expression> index;
auto slice_loc = get_location();
Expand All @@ -2215,7 +2228,7 @@ class Parser {
}
}
}

if ((has_first_colon || has_second_colon)) {
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
} else {
Expand All @@ -2237,15 +2250,13 @@ class Parser {
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
}
} else if (peekSymbols({ "(" })) {
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(get_location(), std::move(value), std::move(callParams));
}
consumeSpaces();
}

if (peekSymbols({ "(" })) {
auto location = get_location();
auto callParams = parseCallArgs();
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
}
return value;
}

Expand Down Expand Up @@ -2725,7 +2736,7 @@ inline std::shared_ptr<Context> Context::builtins() {
globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr<Context> &, Value & args) -> Value {
throw std::runtime_error(args.at("message").get<std::string>());
}));
globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr<Context> &, Value & args) {
globals.set("tojson", simple_function("tojson", { "value", "indent", "ensure_ascii" }, [](const std::shared_ptr<Context> &, Value & args) {
return Value(args.at("value").dump(args.get<int64_t>("indent", -1), /* to_json= */ true));
}));
globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr<Context> &, Value & args) {
Expand Down
11 changes: 8 additions & 3 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def strftime_now(format):
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
return now.strftime(format)

def tojson(value, indent=None, ensure_ascii=False, sort_keys=False):
    """Serialize `value` to a JSON string via `json.dumps`.

    Args:
        value: Any object `json.dumps` can serialize.
        indent: Forwarded to `json.dumps`; None yields the compact single-line form.
        ensure_ascii: Forwarded to `json.dumps`; defaults to False here, so
            non-ASCII characters are emitted as-is rather than \\u-escaped.
        sort_keys: Forwarded to `json.dumps`; defaults to False, keeping the
            mapping's own key order.

    Returns:
        The JSON text as a `str`.
    """
    dump_options = dict(indent=indent, ensure_ascii=ensure_ascii, sort_keys=sort_keys)
    return json.dumps(value, **dump_options)

def join_cmake_path(parent, child):
'''
Expand Down Expand Up @@ -119,8 +121,11 @@ def __init__(self, template, env=None, filters=None, global_functions=None):
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
extensions=[jinja2.ext.loopcontrols],
)
# https://jinja.palletsprojects.com/en/stable/api/#policies
env.policies["json.dumps_function"] = tojson
env.filters['tojson'] = tojson
if filters:
for name, func in filters.items():
env.filters[name] = func
Expand Down Expand Up @@ -192,12 +197,12 @@ def make_tool_call(tool_name, arguments):
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", json.dumps(dummy_args_obj))]),
])
tool_call_renders_str_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out
tool_call_renders_str_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out or ">argument_needle<" in out
out = self.try_raw_render([
dummy_user_msg,
make_tool_calls_msg([make_tool_call("ipython", dummy_args_obj)]),
])
tool_call_renders_obj_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out
tool_call_renders_obj_arguments = "<parameter=argument_needle>" in out or '"argument_needle":' in out or "'argument_needle':" in out or ">argument_needle<" in out

caps.supports_tool_calls = tool_call_renders_str_arguments or tool_call_renders_obj_arguments
caps.requires_object_arguments = not tool_call_renders_str_arguments and tool_call_renders_obj_arguments
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ set(MODEL_IDS
Qwen/Qwen3-235B-A22B-Thinking-2507
Qwen/Qwen3-Coder-30B-A3B-Instruct
Qwen/QwQ-32B
zai-org/GLM-4.6

# Broken, TODO:
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8
Expand Down
12 changes: 12 additions & 0 deletions tests/test-capabilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,15 @@ TEST(CapabilitiesTest, CommandRPlusToolUse) {
// EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_typed_content);
}

// Checks the capability flags reported by get_caps() for the GLM-4.6 chat
// template loaded from tests/zai-org-GLM-4.6.jinja.
TEST(CapabilitiesTest, GLM46) {
auto caps = get_caps("tests/zai-org-GLM-4.6.jinja");
EXPECT_TRUE(caps.supports_system_role);
EXPECT_TRUE(caps.supports_tools);
EXPECT_TRUE(caps.supports_tool_calls);
EXPECT_TRUE(caps.supports_tool_responses);
EXPECT_TRUE(caps.supports_parallel_tool_calls);
// Expected to be true for this template: tool-call arguments must be supplied
// as an object rather than a pre-serialized string.
EXPECT_TRUE(caps.requires_object_arguments);
// NOTE(review): expectation left disabled; the reason is not visible in this test.
// EXPECT_TRUE(caps.requires_non_null_content);
EXPECT_FALSE(caps.requires_typed_content);
}
2 changes: 1 addition & 1 deletion tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <gmock/gmock-matchers.h>

#include <fstream>
#include <iostream>
#include <string>

static std::string render_python(const std::string & template_str, const json & bindings, const minja::Options & options) {
Expand Down Expand Up @@ -373,6 +372,7 @@ TEST(SyntaxTest, SimpleCases) {
{}, {}
)
);
EXPECT_EQ("False", render("{{ trim(' a ').endswith(' ') }}", {} , {})); // Test parsing of expression (chaining of identifier, function call, method call)
}
EXPECT_EQ(
"[0, 1, 2][0, 2]",
Expand Down
Loading