Skip to content

Commit

Permalink
Merge pull request #14 from google/crlf
Browse files Browse the repository at this point in the history
Fix CRLF handling on Windows & lstrip_blocks + trim_blocks behaviour (testing against jinja2)
  • Loading branch information
ochafik authored Dec 26, 2024
2 parents 916c181 + 1300d88 commit 202aa2f
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 73 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
# ubuntu-22.04,
ubuntu-latest,
# windows-2019,
# windows-latest,
windows-latest,
]
type: [
Release,
Expand Down Expand Up @@ -65,5 +65,4 @@ jobs:
run: cmake --build ${{github.workspace}}/build --config ${{ matrix.type }} --parallel

- name: Test
working-directory: ${{github.workspace}}/build
run: ctest --test-dir tests --output-on-failure --verbose -C ${{ matrix.type }}
run: ctest --test-dir build --output-on-failure --verbose -C ${{ matrix.type }}
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ foreach(example
raw
)
add_executable(${example} ${example}.cpp)
target_compile_features(${example} PUBLIC cxx_std_17)
target_link_libraries(${example} PRIVATE nlohmann_json::nlohmann_json)

endforeach()
88 changes: 60 additions & 28 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
#include <unordered_set>
#include <json.hpp>

// Platform line terminator: "\r\n" when building for Windows (_WIN32), "\n"
// otherwise. Because it expands to a string literal, it can be concatenated
// with an adjacent literal (e.g. `":" ENDL`) as well as streamed on its own.
#ifdef _WIN32
#define ENDL "\r\n"
#else
#define ENDL "\n"
#endif

using json = nlohmann::ordered_json;

namespace minja {
Expand All @@ -32,6 +38,15 @@ struct Options {

struct ArgumentsValue;

/**
 * Converts Windows line endings to Unix line endings.
 *
 * On Windows builds (_WIN32 defined) every "\r\n" pair in `s` is replaced by a
 * single "\n"; a '\r' that is not followed by '\n' is left untouched. On all
 * other platforms the input is returned unchanged.
 *
 * @param s Text to normalize.
 * @return A copy of `s` with the replacement described above applied.
 */
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
  // A fixed two-character pattern does not need std::regex: a linear scan
  // yields the same result without compiling or running a regex.
  std::string result;
  result.reserve(s.size());
  std::string::size_type pos = 0;
  for (;;) {
    const auto crlf = s.find("\r\n", pos);
    if (crlf == std::string::npos) break;
    result.append(s, pos, crlf - pos);
    result += '\n';
    pos = crlf + 2;  // skip past the "\r\n" pair just replaced
  }
  result.append(s, pos, std::string::npos);
  return result;
#else
  return s;
#endif
}

/* Values that behave roughly like in Python. */
class Value : public std::enable_shared_from_this<Value> {
public:
Expand Down Expand Up @@ -76,7 +91,7 @@ class Value : public std::enable_shared_from_this<Value> {
void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const {
auto print_indent = [&](int level) {
if (indent > 0) {
out << "\n";
out << ENDL;
for (int i = 0, n = level * indent; i < n; ++i) out << ' ';
}
};
Expand Down Expand Up @@ -547,11 +562,11 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
auto max_line = std::count(start, end, '\n') + 1;
auto col = pos - std::string(start, it).rfind('\n');
std::ostringstream out;
out << " at row " << line << ", column " << col << ":\n";
if (line > 1) out << get_line(line - 1) << "\n";
out << get_line(line) << "\n";
out << std::string(col - 1, ' ') << "^" << "\n";
if (line < max_line) out << get_line(line + 1) << "\n";
out << " at row " << line << ", column " << col << ":" ENDL;
if (line > 1) out << get_line(line - 1) << ENDL;
out << get_line(line) << ENDL;
out << std::string(col - 1, ' ') << "^" << ENDL;
if (line < max_line) out << get_line(line + 1) << ENDL;

return out.str();
}
Expand Down Expand Up @@ -786,7 +801,7 @@ class TemplateNode {
std::string render(const std::shared_ptr<Context> & context) const {
std::ostringstream out;
render(out, context);
return out.str();
return normalize_newlines(out.str());
}
};

Expand Down Expand Up @@ -1214,8 +1229,8 @@ class BinaryOpExpr : public Expression {
if (!l.to_bool()) return Value(false);
return right->evaluate(context).to_bool();
} else if (op == Op::Or) {
if (l.to_bool()) return Value(true);
return right->evaluate(context).to_bool();
if (l.to_bool()) return l;
return right->evaluate(context);
}

auto r = right->evaluate(context);
Expand Down Expand Up @@ -1292,6 +1307,10 @@ struct ArgumentsExpression {
/**
 * Returns a copy of `s` with leading and trailing whitespace removed.
 *
 * Whitespace here is the six ASCII characters space, '\t', '\n', '\r', '\f'
 * and '\v'. Interior whitespace, including newlines, is preserved.
 *
 * Implemented with plain string searches instead of a `^\s+|\s+$` regex: it
 * is cheaper and does not depend on how a given std::regex implementation
 * treats the `^` / `$` anchors around embedded line breaks.
 *
 * @param s Text to strip.
 * @return The stripped text; empty if `s` is empty or all whitespace.
 */
static std::string strip(const std::string & s) {
  constexpr const char * whitespace = " \t\n\r\f\v";
  const auto first = s.find_first_not_of(whitespace);
  if (first == std::string::npos) return "";  // empty or whitespace-only
  const auto last = s.find_last_not_of(whitespace);
  return s.substr(first, last - first + 1);
}

static std::string html_escape(const std::string & s) {
Expand All @@ -1302,7 +1321,7 @@ static std::string html_escape(const std::string & s) {
case '&': result += "&amp;"; break;
case '<': result += "&lt;"; break;
case '>': result += "&gt;"; break;
case '"': result += "&quot;"; break;
case '"': result += "&#34;"; break;
case '\'': result += "&apos;"; break;
default: result += c; break;
}
Expand Down Expand Up @@ -2101,13 +2120,14 @@ class Parser {
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex text_regex(R"([\s\S\n\r]*?($|(?=\{\{|\{%|\{#)))");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");

TemplateTokenVector tokens;
std::vector<std::string> group;
std::string text;
std::smatch match;

try {
while (it != end) {
Expand Down Expand Up @@ -2228,10 +2248,15 @@ class Parser {
} else {
throw std::runtime_error("Unexpected block: " + keyword);
}
} else if (!(text = consumeToken(text_regex, SpaceHandling::Keep)).empty()) {
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
auto text_end = it + match.position();
text = std::string(it, text_end);
it = text_end;
tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
} else {
if (it != end) throw std::runtime_error("Unexpected character");
text = std::string(it, end);
it = end;
tokens.push_back(std::make_unique<TextTemplateToken>(location, SpaceHandling::Keep, SpaceHandling::Keep, text));
}
}
return tokens;
Expand Down Expand Up @@ -2280,24 +2305,31 @@ class Parser {
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;

auto text = text_token->text;
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
static std::regex leading_line(R"(^[ \t]*\r?\n)");
text = std::regex_replace(text, leading_line, "");
}
if (post_space == SpaceHandling::Strip) {
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
text = std::regex_replace(text, trailing_space_regex, "");
} else if (options.lstrip_blocks && it != end) {
static std::regex trailing_last_line_space_regex(R"((\r?\n)[ \t]*$)");
text = std::regex_replace(text, trailing_last_line_space_regex, "$1");
auto i = text.size();
while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--;
if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) {
text.resize(i);
}
}
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
if (text.length() > 0 && text[0] == '\n') {
text.erase(0, 1);
}
}

if (it == end && !options.keep_trailing_newline) {
static std::regex r(R"(\r?\n$)");
text = std::regex_replace(text, r, ""); // Strip one trailing newline
auto i = text.size();
if (i > 0 && text[i - 1] == '\n') {
i--;
if (i > 0 && text[i - 1] == '\r') i--;
text.resize(i);
}
}
children.emplace_back(std::make_shared<TextNode>(token->location, text));
} else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
Expand Down Expand Up @@ -2357,7 +2389,7 @@ class Parser {
public:

static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
Parser parser(std::make_shared<std::string>(template_str), options);
Parser parser(std::make_shared<std::string>(normalize_newlines(template_str)), options);
auto tokens = parser.tokenize();
TemplateTokenIterator begin = tokens.begin();
auto it = begin;
Expand Down Expand Up @@ -2627,11 +2659,11 @@ inline std::shared_ptr<Context> Context::builtins() {
while (std::getline(iss, line, '\n')) {
auto needs_indent = !is_first || first;
if (is_first) is_first = false;
else out += "\n";
else out += ENDL;
if (needs_indent) out += indent;
out += line;
}
if (!text.empty() && text.back() == '\n') out += "\n";
if (!text.empty() && text.back() == '\n') out += ENDL;
return out;
}));
globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
Expand Down
21 changes: 21 additions & 0 deletions scripts/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
#
# SPDX-License-Identifier: MIT
"""Render a Jinja2 template described by a JSON file and write the result.

Usage: render.py INPUT_JSON OUTPUT_FILE

INPUT_JSON must contain an object with three keys:
  * 'options'  - keyword arguments forwarded to jinja2.Environment
  * 'template' - the template source string
  * 'bindings' - the context passed to the template's render()
"""
import sys
import json
from jinja2 import Environment
import jinja2.ext
from pathlib import Path

if len(sys.argv) < 3:
    # Fail with a usage hint rather than an opaque unpacking ValueError.
    sys.exit(f"usage: {sys.argv[0]} INPUT_JSON OUTPUT_FILE")

input_file, output_file = sys.argv[1:3]
# Read as UTF-8 explicitly: the platform default encoding (e.g. a legacy code
# page on Windows) would mis-decode non-ASCII template text.
data = json.loads(Path(input_file).read_text(encoding="utf-8"))
# print(json.dumps(data, indent=2), file=sys.stderr)

env = Environment(**data['options'], extensions=[jinja2.ext.loopcontrols])
tmpl = env.from_string(data['template'])
output = tmpl.render(data['bindings'])
# NOTE(review): assumes whoever reads OUTPUT_FILE expects UTF-8 - confirm
# against the caller.
Path(output_file).write_text(output, encoding="utf-8")
17 changes: 15 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@
# SPDX-License-Identifier: MIT

add_executable(test-syntax test-syntax.cpp)
target_compile_features(test-syntax PUBLIC cxx_std_17)
target_link_libraries(test-syntax PRIVATE
nlohmann_json::nlohmann_json
gtest_main
gmock
)
gtest_discover_tests(test-syntax)

add_test(NAME test-syntax-jinja2 COMMAND test-syntax)
set_tests_properties(test-syntax-jinja2 PROPERTIES ENVIRONMENT "USE_JINJA2=1;PYTHON_EXECUTABLE=${Python_EXECUTABLE};PYTHONPATH=${CMAKE_SOURCE_DIR}")


add_executable(test-chat-template test-chat-template.cpp)
target_compile_features(test-chat-template PUBLIC cxx_std_17)
target_link_libraries(test-chat-template PRIVATE nlohmann_json::nlohmann_json)

set(MODEL_IDS
Expand Down Expand Up @@ -68,15 +74,21 @@ set(MODEL_IDS
TheBloke/FusionNet_34Bx2_MoE-AWQ

# Broken, TODO:
# fireworks-ai/llama-3-firefunction-v2
# fireworks-ai/llama-3-firefunction-v2 # https://github.com/google/minja/issues/7
# ai21labs/AI21-Jamba-1.5-Large # https://github.com/google/minja/issues/8

# Can't find template(s), TODO:
# ai21labs/Jamba-v0.1
# apple/OpenELM-1_1B-Instruct
# dreamgen/WizardLM-2-7B
# xai-org/grok-1
)

if(WIN32)
list(REMOVE_ITEM MODEL_IDS
bofenghuang/vigogne-2-70b-chat
)
endif()

# Create one test case for each {template, context} combination
file(GLOB CONTEXT_FILES "${CMAKE_SOURCE_DIR}/tests/contexts/*.json")
execute_process(
Expand Down Expand Up @@ -109,6 +121,7 @@ if (MINJA_FUZZTEST_ENABLED)
fuzztest_setup_fuzzing_flags()
endif()
add_executable(test-fuzz test-fuzz.cpp)
target_compile_features(test-fuzz PUBLIC cxx_std_17)
target_include_directories(test-fuzz PRIVATE ${fuzztest_BINARY_DIR})
target_link_libraries(test-fuzz PRIVATE nlohmann_json::nlohmann_json)
link_fuzztest(test-fuzz)
Expand Down
2 changes: 1 addition & 1 deletion tests/test-chat-template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ static std::string read_file(const std::string &path) {
std::string out;
out.resize(static_cast<size_t>(size));
fs.read(&out[0], static_cast<std::streamsize>(size));
return out;
return minja::normalize_newlines(out);
}

int main(int argc, char *argv[]) {
Expand Down
Loading

0 comments on commit 202aa2f

Please sign in to comment.