diff --git a/.github/workflows/config.yml b/.github/workflows/config.yml
new file mode 100644
index 0000000000000..9fc2639be64c5
--- /dev/null
+++ b/.github/workflows/config.yml
@@ -0,0 +1,41 @@
+name: YAML Config Tests
+
+on:
+  push:
+    branches: [ "master", "devin/*" ]
+  pull_request:
+    branches: [ "master" ]
+
+jobs:
+  test-yaml-config:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y build-essential cmake git-lfs
+          git lfs install
+
+      - name: Download tiny model (stories15M)
+        run: |
+          mkdir -p models
+          # Download only the specific model file we need (19MB) to avoid disk space issues
+          wget -q "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" -O models/stories15M-q4_0.gguf
+          ls -lh models/stories15M-q4_0.gguf
+
+      - name: Build
+        id: cmake_build
+        run: |
+          cmake -B build -DLLAMA_BUILD_TESTS=ON -DLLAMA_BUILD_TOOLS=ON -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=OFF
+          cmake --build build --config Release -j $(nproc)
+
+      - name: Test YAML config functionality
+        run: |
+          cd build
+          ctest -R "test-config-yaml|test-config-yaml-cli-.*|test-config-yaml-parity" --output-on-failure --timeout 300
diff --git a/README.md b/README.md
index 17f59e988e3d1..cbdd697977dc4 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,38 @@
 llama-cli -hf ggml-org/gemma-3-1b-it-GGUF
 llama-server -hf ggml-org/gemma-3-1b-it-GGUF
 ```
+### YAML Configuration
+
+You can use YAML configuration files to set parameters instead of command-line flags:
+
+```bash
+llama-cli --config configs/minimal.yaml
+```
+
+Example `minimal.yaml`:
+```yaml
+model:
+  path: models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf
+n_ctx: 256
+sampling:
+  seed: 42
+  temp: 0.0
+prompt: "Hello from YAML"
+n_predict: 16
+simple_io: true
+```
+
+You can override YAML values with command-line flags:
+```bash
+llama-cli --config configs/minimal.yaml -n 32 --temp 0.8
+```
+
+**Precedence rules:** Command-line flags > YAML config > defaults
+
+**Path resolution:** Relative paths in YAML files are resolved relative to the YAML file's directory.
+
+**Error handling:** Unknown YAML keys will cause an error with a list of valid keys.
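+
+For example, the repository's `configs/minimal.yaml` sets `model.path: ../models/stories15M-q4_0.gguf`; since relative paths are resolved against the `configs/` directory, the model is loaded from `models/stories15M-q4_0.gguf` at the repository root regardless of the current working directory.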
+
 
 ## Description
 
 The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 0ae4d698f080c..b40b1f155e3a3 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -133,6 +133,32 @@ if (LLAMA_LLGUIDANCE)
     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
 endif ()
 
+if (LLAMA_BUILD_TOOLS)
+    # yaml-cpp for YAML config (CLI-only)
+    find_package(yaml-cpp QUIET)
+    if (NOT yaml-cpp_FOUND)
+        include(FetchContent)
+        FetchContent_Declare(yaml-cpp
+            GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
+            GIT_TAG 0.8.0)
+        set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+        set(YAML_CPP_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
+        set(YAML_CPP_BUILD_CONTRIB OFF CACHE BOOL "" FORCE)
+        FetchContent_MakeAvailable(yaml-cpp)
+        # Suppress all warnings for yaml-cpp to avoid -Werror failures
+        if(TARGET yaml-cpp)
+            target_compile_options(yaml-cpp PRIVATE -w)
+        endif()
+    endif()
+
+    target_sources(${TARGET} PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/config.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/config.h
+        )
+    target_link_libraries(${TARGET} PRIVATE yaml-cpp)
+    target_compile_definitions(${TARGET} PUBLIC LLAMA_ENABLE_CONFIG_YAML)
+endif()
+
 target_include_directories(${TARGET} PUBLIC . ../vendor)
 target_compile_features   (${TARGET} PUBLIC cxx_std_17)
 target_link_libraries     (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
diff --git a/common/arg.cpp b/common/arg.cpp
index fcee0c4470077..5d4a326b5ef75 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2,6 +2,9 @@
 
 #include "chat.h"
 #include "common.h"
+#ifdef LLAMA_ENABLE_CONFIG_YAML
+#include "config.h"
+#endif
 #include "gguf.h" // for reading GGUF splits
 #include "json-schema-to-grammar.h"
 #include "log.h"
@@ -1223,6 +1226,26 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
     const common_params params_org = ctx_arg.params; // the example can modify the default params
 
     try {
+#ifdef LLAMA_ENABLE_CONFIG_YAML
+        // pre-scan argv for --config so YAML values are applied before flag parsing;
+        // flags parsed afterwards override the YAML values
+        for (int i = 1; i < argc; ++i) {
+            if (std::string(argv[i]) == "--config") {
+                if (i + 1 >= argc) {
+                    throw std::invalid_argument("error: --config requires a file path");
+                }
+                std::string cfg_path = argv[++i];
+                if (!common_load_yaml_config(cfg_path, ctx_arg.params)) {
+                    throw std::invalid_argument("error: failed to load YAML config: " + cfg_path);
+                }
+                break;
+            }
+        }
+#else
+        for (int i = 1; i < argc; ++i) {
+            if (std::string(argv[i]) == "--config") {
+                throw std::invalid_argument("error: this build does not include YAML config support (LLAMA_BUILD_TOOLS=OFF)");
+            }
+        }
+#endif
         if (!common_params_parse_ex(argc, argv, ctx_arg)) {
             ctx_arg.params = params_org;
             return false;
@@ -1317,6 +1340,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.completion = true;
         }
     ));
+
+#ifdef LLAMA_ENABLE_CONFIG_YAML
+    add_opt(common_arg(
+        {"--config"},
+        "",
+        "Load parameters from a YAML config file; flags passed on the command line override values from the YAML file.",
+        [](common_params &, const std::string &) {
+        }
+    ));
+#endif
     add_opt(common_arg(
         {"--verbose-prompt"},
         string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
"true" : "false"), diff --git a/common/config.cpp b/common/config.cpp new file mode 100644 index 0000000000000..78a1605fa8cae --- /dev/null +++ b/common/config.cpp @@ -0,0 +1,343 @@ +#ifdef LLAMA_ENABLE_CONFIG_YAML + +#include "config.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +static std::set get_valid_keys() { + return { + "model.path", "model.url", "model.hf_repo", "model.hf_file", + "model_alias", "hf_token", "prompt", "system_prompt", "prompt_file", + "n_predict", "n_ctx", "n_batch", "n_ubatch", "n_keep", "n_chunks", + "n_parallel", "n_sequences", "grp_attn_n", "grp_attn_w", "n_print", + "rope_freq_base", "rope_freq_scale", "yarn_ext_factor", "yarn_attn_factor", + "yarn_beta_fast", "yarn_beta_slow", "yarn_orig_ctx", + "n_gpu_layers", "main_gpu", "split_mode", "pooling_type", "attention_type", + "flash_attn_type", "numa", "use_mmap", "use_mlock", "verbose_prompt", + "display_prompt", "no_kv_offload", "warmup", "check_tensors", "no_op_offload", + "no_extra_bufts", "cache_type_k", "cache_type_v", "conversation_mode", + "simple_io", "interactive", "interactive_first", "input_prefix", "input_suffix", + "logits_file", "path_prompt_cache", "antiprompt", "in_files", "kv_overrides", + "tensor_buft_overrides", "lora_adapters", "control_vectors", "image", "seed", + "sampling.seed", "sampling.n_prev", "sampling.n_probs", "sampling.min_keep", + "sampling.top_k", "sampling.top_p", "sampling.min_p", "sampling.xtc_probability", + "sampling.xtc_threshold", "sampling.typ_p", "sampling.temp", "sampling.dynatemp_range", + "sampling.dynatemp_exponent", "sampling.penalty_last_n", "sampling.penalty_repeat", + "sampling.penalty_freq", "sampling.penalty_present", "sampling.dry_multiplier", + "sampling.dry_base", "sampling.dry_allowed_length", "sampling.dry_penalty_last_n", + "sampling.mirostat", "sampling.mirostat_tau", "sampling.mirostat_eta", + "sampling.top_n_sigma", "sampling.ignore_eos", "sampling.no_perf", + "sampling.timing_per_token", "sampling.dry_sequence_breakers", "sampling.samplers", + "sampling.grammar", "sampling.grammar_lazy", "sampling.grammar_triggers", + "speculative.devices", "speculative.n_ctx", "speculative.n_max", "speculative.n_min", + "speculative.n_gpu_layers", "speculative.p_split", "speculative.p_min", + "speculative.model.path", "speculative.model.url", "speculative.model.hf_repo", + "speculative.model.hf_file", "speculative.tensor_buft_overrides", + "speculative.cpuparams", "speculative.cpuparams_batch", + "vocoder.model.path", "vocoder.model.url", "vocoder.model.hf_repo", + "vocoder.model.hf_file", "vocoder.speaker_file", "vocoder.use_guide_tokens" + }; +} + +std::string common_yaml_valid_keys_help() { + const auto keys = get_valid_keys(); + std::ostringstream ss; + bool first = true; + for (const auto & key : keys) { + if (!first) ss << ", "; + ss << key; + first = false; + } + return ss.str(); +} + +static std::string resolve_path(const std::string & path, const fs::path & yaml_dir) { + fs::path p(path); + if (p.is_absolute()) { + return path; + } + return fs::weakly_canonical(yaml_dir / p).string(); +} + +static void collect_keys(const YAML::Node & node, const std::string & prefix, std::set & found_keys) { + if (node.IsMap()) { + for (const auto & kv : node) { + std::string key = kv.first.as(); + std::string full_key = prefix.empty() ? key : prefix + "." 
+            found_keys.insert(full_key);
+            collect_keys(kv.second, full_key, found_keys);
+        }
+    }
+}
+
+static void validate_keys(const YAML::Node & root) {
+    std::set<std::string> found_keys;
+    collect_keys(root, "", found_keys);
+
+    const auto valid_keys = get_valid_keys();
+    std::vector<std::string> unknown_keys;
+
+    for (const auto & key : found_keys) {
+        if (valid_keys.find(key) == valid_keys.end()) {
+            bool is_parent = false;
+            for (const auto & valid_key : valid_keys) {
+                if (valid_key.find(key + ".") == 0) {
+                    is_parent = true;
+                    break;
+                }
+            }
+            if (!is_parent) {
+                unknown_keys.push_back(key);
+            }
+        }
+    }
+
+    if (!unknown_keys.empty()) {
+        std::ostringstream ss;
+        ss << "Unknown YAML keys: ";
+        for (size_t i = 0; i < unknown_keys.size(); ++i) {
+            if (i > 0) ss << ", ";
+            ss << unknown_keys[i];
+        }
+        ss << "; valid keys are: " << common_yaml_valid_keys_help();
+        throw std::invalid_argument(ss.str());
+    }
+}
+
+static ggml_type parse_ggml_type(const std::string & type_str) {
+    if (type_str == "f32") return GGML_TYPE_F32;
+    if (type_str == "f16") return GGML_TYPE_F16;
+    if (type_str == "bf16") return GGML_TYPE_BF16;
+    if (type_str == "q8_0") return GGML_TYPE_Q8_0;
+    if (type_str == "q4_0") return GGML_TYPE_Q4_0;
+    if (type_str == "q4_1") return GGML_TYPE_Q4_1;
+    if (type_str == "iq4_nl") return GGML_TYPE_IQ4_NL;
+    if (type_str == "q5_0") return GGML_TYPE_Q5_0;
+    if (type_str == "q5_1") return GGML_TYPE_Q5_1;
+    throw std::invalid_argument("Unknown ggml_type: " + type_str);
+}
+
+static enum llama_split_mode parse_split_mode(const std::string & mode_str) {
+    if (mode_str == "none") return LLAMA_SPLIT_MODE_NONE;
+    if (mode_str == "layer") return LLAMA_SPLIT_MODE_LAYER;
+    if (mode_str == "row") return LLAMA_SPLIT_MODE_ROW;
+    throw std::invalid_argument("Unknown split_mode: " + mode_str);
+}
+
+static enum llama_pooling_type parse_pooling_type(const std::string & type_str) {
+    if (type_str == "unspecified") return LLAMA_POOLING_TYPE_UNSPECIFIED;
+    if (type_str == "none") return LLAMA_POOLING_TYPE_NONE;
+    if (type_str == "mean") return LLAMA_POOLING_TYPE_MEAN;
+    if (type_str == "cls") return LLAMA_POOLING_TYPE_CLS;
+    if (type_str == "last") return LLAMA_POOLING_TYPE_LAST;
+    if (type_str == "rank") return LLAMA_POOLING_TYPE_RANK;
+    throw std::invalid_argument("Unknown pooling_type: " + type_str);
+}
+
+static enum llama_attention_type parse_attention_type(const std::string & type_str) {
+    if (type_str == "unspecified") return LLAMA_ATTENTION_TYPE_UNSPECIFIED;
+    if (type_str == "causal") return LLAMA_ATTENTION_TYPE_CAUSAL;
+    if (type_str == "non_causal") return LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+    throw std::invalid_argument("Unknown attention_type: " + type_str);
+}
+
+static enum llama_flash_attn_type parse_flash_attn_type(const std::string & type_str) {
+    if (type_str == "auto") return LLAMA_FLASH_ATTN_TYPE_AUTO;
+    if (type_str == "disabled") return LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    if (type_str == "enabled") return LLAMA_FLASH_ATTN_TYPE_ENABLED;
+    throw std::invalid_argument("Unknown flash_attn_type: " + type_str);
+}
+
+static ggml_numa_strategy parse_numa_strategy(const std::string & strategy_str) {
+    if (strategy_str == "disabled") return GGML_NUMA_STRATEGY_DISABLED;
+    if (strategy_str == "distribute") return GGML_NUMA_STRATEGY_DISTRIBUTE;
+    if (strategy_str == "isolate") return GGML_NUMA_STRATEGY_ISOLATE;
+    if (strategy_str == "numactl") return GGML_NUMA_STRATEGY_NUMACTL;
+    if (strategy_str == "mirror") return GGML_NUMA_STRATEGY_MIRROR;
+    throw std::invalid_argument("Unknown numa_strategy: " + strategy_str);
+}
+
+static common_conversation_mode parse_conversation_mode(const std::string & mode_str) {
+    if (mode_str == "auto") return COMMON_CONVERSATION_MODE_AUTO;
+    if (mode_str == "enabled") return COMMON_CONVERSATION_MODE_ENABLED;
+    if (mode_str == "disabled") return COMMON_CONVERSATION_MODE_DISABLED;
+    throw std::invalid_argument("Unknown conversation_mode: " + mode_str);
+}
+
+bool common_load_yaml_config(const std::string & path, common_params & params) {
+    try {
+        YAML::Node root = YAML::LoadFile(path);
+
+        validate_keys(root);
+
+        // relative paths inside the YAML are resolved against the YAML file's directory
+        fs::path yaml_dir = fs::absolute(path).parent_path();
+
+        if (root["model"]) {
+            auto model = root["model"];
+            if (model["path"]) {
+                params.model.path = resolve_path(model["path"].as<std::string>(), yaml_dir);
+            }
+            if (model["url"]) {
+                params.model.url = model["url"].as<std::string>();
+            }
+            if (model["hf_repo"]) {
+                params.model.hf_repo = model["hf_repo"].as<std::string>();
+            }
+            if (model["hf_file"]) {
+                params.model.hf_file = model["hf_file"].as<std::string>();
+            }
+        }
+
+        if (root["model_alias"]) params.model_alias = root["model_alias"].as<std::string>();
+        if (root["hf_token"]) params.hf_token = root["hf_token"].as<std::string>();
+        if (root["prompt"]) params.prompt = root["prompt"].as<std::string>();
+        if (root["system_prompt"]) params.system_prompt = root["system_prompt"].as<std::string>();
+        if (root["prompt_file"]) {
+            params.prompt_file = resolve_path(root["prompt_file"].as<std::string>(), yaml_dir);
+        }
+        if (root["n_predict"]) params.n_predict = root["n_predict"].as<int32_t>();
+        if (root["n_ctx"]) params.n_ctx = root["n_ctx"].as<int32_t>();
+        if (root["n_batch"]) params.n_batch = root["n_batch"].as<int32_t>();
+        if (root["n_ubatch"]) params.n_ubatch = root["n_ubatch"].as<int32_t>();
+        if (root["n_keep"]) params.n_keep = root["n_keep"].as<int32_t>();
+        if (root["n_chunks"]) params.n_chunks = root["n_chunks"].as<int32_t>();
+        if (root["n_parallel"]) params.n_parallel = root["n_parallel"].as<int32_t>();
+        if (root["n_sequences"]) params.n_sequences = root["n_sequences"].as<int32_t>();
+        if (root["grp_attn_n"]) params.grp_attn_n = root["grp_attn_n"].as<int32_t>();
+        if (root["grp_attn_w"]) params.grp_attn_w = root["grp_attn_w"].as<int32_t>();
+        if (root["n_print"]) params.n_print = root["n_print"].as<int32_t>();
+        if (root["rope_freq_base"]) params.rope_freq_base = root["rope_freq_base"].as<float>();
+        if (root["rope_freq_scale"]) params.rope_freq_scale = root["rope_freq_scale"].as<float>();
+        if (root["yarn_ext_factor"]) params.yarn_ext_factor = root["yarn_ext_factor"].as<float>();
+        if (root["yarn_attn_factor"]) params.yarn_attn_factor = root["yarn_attn_factor"].as<float>();
+        if (root["yarn_beta_fast"]) params.yarn_beta_fast = root["yarn_beta_fast"].as<float>();
+        if (root["yarn_beta_slow"]) params.yarn_beta_slow = root["yarn_beta_slow"].as<float>();
+        if (root["yarn_orig_ctx"]) params.yarn_orig_ctx = root["yarn_orig_ctx"].as<int32_t>();
+
+        if (root["n_gpu_layers"]) params.n_gpu_layers = root["n_gpu_layers"].as<int32_t>();
+        if (root["main_gpu"]) params.main_gpu = root["main_gpu"].as<int32_t>();
+
+        if (root["split_mode"]) {
+            params.split_mode = parse_split_mode(root["split_mode"].as<std::string>());
+        }
+        if (root["pooling_type"]) {
+            params.pooling_type = parse_pooling_type(root["pooling_type"].as<std::string>());
+        }
+        if (root["attention_type"]) {
+            params.attention_type = parse_attention_type(root["attention_type"].as<std::string>());
+        }
+        if (root["flash_attn_type"]) {
+            params.flash_attn_type = parse_flash_attn_type(root["flash_attn_type"].as<std::string>());
+        }
+        if (root["numa"]) {
+            params.numa = parse_numa_strategy(root["numa"].as<std::string>());
+        }
+        if (root["conversation_mode"]) {
+            params.conversation_mode = parse_conversation_mode(root["conversation_mode"].as<std::string>());
+        }
+
+        if (root["use_mmap"]) params.use_mmap = root["use_mmap"].as<bool>();
+        if (root["use_mlock"]) params.use_mlock = root["use_mlock"].as<bool>();
+        if (root["verbose_prompt"]) params.verbose_prompt = root["verbose_prompt"].as<bool>();
+        if (root["display_prompt"]) params.display_prompt = root["display_prompt"].as<bool>();
+        if (root["no_kv_offload"]) params.no_kv_offload = root["no_kv_offload"].as<bool>();
+        if (root["warmup"]) params.warmup = root["warmup"].as<bool>();
+        if (root["check_tensors"]) params.check_tensors = root["check_tensors"].as<bool>();
+        if (root["no_op_offload"]) params.no_op_offload = root["no_op_offload"].as<bool>();
+        if (root["no_extra_bufts"]) params.no_extra_bufts = root["no_extra_bufts"].as<bool>();
+        if (root["simple_io"]) params.simple_io = root["simple_io"].as<bool>();
+        if (root["interactive"]) params.interactive = root["interactive"].as<bool>();
+        if (root["interactive_first"]) params.interactive_first = root["interactive_first"].as<bool>();
+
+        if (root["input_prefix"]) params.input_prefix = root["input_prefix"].as<std::string>();
+        if (root["input_suffix"]) params.input_suffix = root["input_suffix"].as<std::string>();
+        if (root["logits_file"]) {
+            params.logits_file = resolve_path(root["logits_file"].as<std::string>(), yaml_dir);
+        }
+        if (root["path_prompt_cache"]) {
+            params.path_prompt_cache = resolve_path(root["path_prompt_cache"].as<std::string>(), yaml_dir);
+        }
+
+        if (root["cache_type_k"]) {
+            params.cache_type_k = parse_ggml_type(root["cache_type_k"].as<std::string>());
+        }
+        if (root["cache_type_v"]) {
+            params.cache_type_v = parse_ggml_type(root["cache_type_v"].as<std::string>());
+        }
+
+        if (root["antiprompt"]) {
+            params.antiprompt.clear();
+            for (const auto & item : root["antiprompt"]) {
+                params.antiprompt.push_back(item.as<std::string>());
+            }
+        }
+
+        if (root["in_files"]) {
+            params.in_files.clear();
+            for (const auto & item : root["in_files"]) {
+                params.in_files.push_back(resolve_path(item.as<std::string>(), yaml_dir));
+            }
+        }
+
+        if (root["image"]) {
+            params.image.clear();
+            for (const auto & item : root["image"]) {
+                params.image.push_back(resolve_path(item.as<std::string>(), yaml_dir));
+            }
+        }
+
+        if (root["seed"]) {
+            params.sampling.seed = root["seed"].as<uint32_t>();
+        }
+
+        if (root["sampling"]) {
+            auto sampling = root["sampling"];
+            if (sampling["seed"]) params.sampling.seed = sampling["seed"].as<uint32_t>();
+            if (sampling["n_prev"]) params.sampling.n_prev = sampling["n_prev"].as<int32_t>();
+            if (sampling["n_probs"]) params.sampling.n_probs = sampling["n_probs"].as<int32_t>();
+            if (sampling["min_keep"]) params.sampling.min_keep = sampling["min_keep"].as<int32_t>();
+            if (sampling["top_k"]) params.sampling.top_k = sampling["top_k"].as<int32_t>();
+            if (sampling["top_p"]) params.sampling.top_p = sampling["top_p"].as<float>();
+            if (sampling["min_p"]) params.sampling.min_p = sampling["min_p"].as<float>();
+            if (sampling["xtc_probability"]) params.sampling.xtc_probability = sampling["xtc_probability"].as<float>();
+            if (sampling["xtc_threshold"]) params.sampling.xtc_threshold = sampling["xtc_threshold"].as<float>();
+            if (sampling["typ_p"]) params.sampling.typ_p = sampling["typ_p"].as<float>();
+            if (sampling["temp"]) params.sampling.temp = sampling["temp"].as<float>();
+            if (sampling["dynatemp_range"]) params.sampling.dynatemp_range = sampling["dynatemp_range"].as<float>();
+            if (sampling["dynatemp_exponent"]) params.sampling.dynatemp_exponent = sampling["dynatemp_exponent"].as<float>();
+            if (sampling["penalty_last_n"]) params.sampling.penalty_last_n = sampling["penalty_last_n"].as<int32_t>();
+            if (sampling["penalty_repeat"]) params.sampling.penalty_repeat = sampling["penalty_repeat"].as<float>();
+            if (sampling["penalty_freq"]) params.sampling.penalty_freq = sampling["penalty_freq"].as<float>();
+            if (sampling["penalty_present"]) params.sampling.penalty_present = sampling["penalty_present"].as<float>();
+            if (sampling["dry_multiplier"]) params.sampling.dry_multiplier = sampling["dry_multiplier"].as<float>();
(sampling["dry_multiplier"]) params.sampling.dry_multiplier = sampling["dry_multiplier"].as(); + if (sampling["dry_base"]) params.sampling.dry_base = sampling["dry_base"].as(); + if (sampling["dry_allowed_length"]) params.sampling.dry_allowed_length = sampling["dry_allowed_length"].as(); + if (sampling["dry_penalty_last_n"]) params.sampling.dry_penalty_last_n = sampling["dry_penalty_last_n"].as(); + if (sampling["mirostat"]) params.sampling.mirostat = sampling["mirostat"].as(); + if (sampling["mirostat_tau"]) params.sampling.mirostat_tau = sampling["mirostat_tau"].as(); + if (sampling["mirostat_eta"]) params.sampling.mirostat_eta = sampling["mirostat_eta"].as(); + if (sampling["top_n_sigma"]) params.sampling.top_n_sigma = sampling["top_n_sigma"].as(); + if (sampling["ignore_eos"]) params.sampling.ignore_eos = sampling["ignore_eos"].as(); + if (sampling["no_perf"]) params.sampling.no_perf = sampling["no_perf"].as(); + if (sampling["timing_per_token"]) params.sampling.timing_per_token = sampling["timing_per_token"].as(); + if (sampling["grammar"]) params.sampling.grammar = sampling["grammar"].as(); + if (sampling["grammar_lazy"]) params.sampling.grammar_lazy = sampling["grammar_lazy"].as(); + } + + return true; + } catch (const YAML::Exception & e) { + throw std::invalid_argument("YAML parsing error: " + std::string(e.what())); + } catch (const std::exception & e) { + throw std::invalid_argument("Config loading error: " + std::string(e.what())); + } +} + +#endif // LLAMA_ENABLE_CONFIG_YAML diff --git a/common/config.h b/common/config.h new file mode 100644 index 0000000000000..a8bb0b16cbe23 --- /dev/null +++ b/common/config.h @@ -0,0 +1,9 @@ +#pragma once + +#include "common.h" +#include + +#ifdef LLAMA_ENABLE_CONFIG_YAML +bool common_load_yaml_config(const std::string & path, common_params & params); +std::string common_yaml_valid_keys_help(); +#endif diff --git a/configs/minimal.yaml b/configs/minimal.yaml new file mode 100644 index 0000000000000..af49055cb8473 --- /dev/null +++ b/configs/minimal.yaml @@ -0,0 +1,9 @@ +model: + path: ../models/stories15M-q4_0.gguf +n_ctx: 256 +sampling: + seed: 42 + temp: 0.0 +prompt: "Hello from YAML" +n_predict: 16 +simple_io: true diff --git a/configs/override.yaml b/configs/override.yaml new file mode 100644 index 0000000000000..e20412e9691a8 --- /dev/null +++ b/configs/override.yaml @@ -0,0 +1,9 @@ +model: + path: ../models/stories15M-q4_0.gguf +n_ctx: 256 +sampling: + seed: 42 + temp: 0.8 +prompt: "Hello from YAML override" +n_predict: 32 +simple_io: true diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 91719577564a9..87a15f4ef79e0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -184,9 +184,36 @@ llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test(test-regex-partial.cpp) +if (LLAMA_BUILD_TOOLS) + llama_build_and_test(test-config-yaml.cpp) +endif() llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2) +# YAML config integration tests +if(EXISTS ${PROJECT_SOURCE_DIR}/models/stories15M-q4_0.gguf) + llama_test_cmd( + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/llama-cli + NAME test-config-yaml-cli-only + ARGS --config ${PROJECT_SOURCE_DIR}/configs/minimal.yaml -no-cnv + ) + + llama_test_cmd( + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/llama-cli + NAME test-config-yaml-cli-overrides + ARGS --config 
+
+    # Parity test - compare YAML config vs equivalent flags
+    add_test(
+        NAME test-config-yaml-parity
+        WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
+        COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test-yaml-parity.sh
+    )
+    set_property(TEST test-config-yaml-parity PROPERTY LABELS main)
+    set_property(TEST test-config-yaml-parity PROPERTY ENVIRONMENT "PROJECT_SOURCE_DIR=${PROJECT_SOURCE_DIR}")
+endif()
+
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
     llama_build_and_test(test-arg-parser.cpp)
diff --git a/tests/test-config-yaml.cpp b/tests/test-config-yaml.cpp
new file mode 100644
index 0000000000000..d65c74dad621c
--- /dev/null
+++ b/tests/test-config-yaml.cpp
@@ -0,0 +1,125 @@
+#include "common.h"
+#include "config.h"
+
+#include <cassert>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+
+namespace fs = std::filesystem;
+
+static void test_minimal_config() {
+    common_params params;
+    fs::path temp_dir = fs::temp_directory_path() / "llama_test";
+    fs::create_directories(temp_dir);
+
+    std::string config_content = R"(
+model:
+  path: test_model.gguf
+n_ctx: 512
+sampling:
+  seed: 123
+  temp: 0.5
+prompt: "Test prompt"
+n_predict: 64
+simple_io: true
+)";
+
+    fs::path config_path = temp_dir / "test_config.yaml";
+    std::ofstream config_file(config_path);
+    config_file << config_content;
+    config_file.close();
+
+    bool result = common_load_yaml_config(config_path.string(), params);
+    assert(result);
+    (void) result;
+
+    assert(params.model.path == (temp_dir / "test_model.gguf").string());
+    assert(params.n_ctx == 512);
+    assert(params.sampling.seed == 123);
+    assert(params.sampling.temp == 0.5f);
+    assert(params.prompt == "Test prompt");
+    assert(params.n_predict == 64);
+    assert(params.simple_io == true);
+
+    fs::remove_all(temp_dir);
+
+    std::cout << "test_minimal_config: PASSED\n";
+}
+
+static void test_unknown_key_error() {
+    common_params params;
+    fs::path temp_dir = fs::temp_directory_path() / "llama_test";
+    fs::create_directories(temp_dir);
+
+    std::string config_content = R"(
+model:
+  path: test_model.gguf
+unknown_key: "should fail"
+n_ctx: 512
+)";
+
+    fs::path config_path = temp_dir / "test_config.yaml";
+    std::ofstream config_file(config_path);
+    config_file << config_content;
+    config_file.close();
+
+    bool threw_exception = false;
+    try {
+        common_load_yaml_config(config_path.string(), params);
+    } catch (const std::invalid_argument & e) {
+        threw_exception = true;
+        std::string error_msg = e.what();
+        assert(error_msg.find("Unknown YAML keys") != std::string::npos);
+        assert(error_msg.find("valid keys are") != std::string::npos);
+    }
+
+    assert(threw_exception);
+    (void) threw_exception;
+
+    fs::remove_all(temp_dir);
+
+    std::cout << "test_unknown_key_error: PASSED\n";
+}
+
+static void test_relative_path_resolution() {
+    common_params params;
+    fs::path temp_dir = fs::temp_directory_path() / "llama_test";
+    fs::path config_dir = temp_dir / "configs";
+    fs::create_directories(config_dir);
+
+    std::string config_content = R"(
+model:
+  path: ../models/test_model.gguf
+prompt_file: prompts/test.txt
+)";
+
+    fs::path config_path = config_dir / "test_config.yaml";
+    std::ofstream config_file(config_path);
+    config_file << config_content;
+    config_file.close();
+
+    bool result = common_load_yaml_config(config_path.string(), params);
+    assert(result);
+    (void) result;
+
+    fs::path expected_model = temp_dir / "models" / "test_model.gguf";
+    fs::path expected_prompt = config_dir / "prompts" / "test.txt";
"prompts" / "test.txt"; + + assert(params.model.path == expected_model.lexically_normal().string()); + assert(params.prompt_file == expected_prompt.lexically_normal().string()); + fs::remove_all(temp_dir); + + std::cout << "test_relative_path_resolution: PASSED\n"; +} + +int main() { + try { + test_minimal_config(); + test_unknown_key_error(); + test_relative_path_resolution(); + + std::cout << "All tests passed!\n"; + return 0; + } catch (const std::exception & e) { + std::cerr << "Test failed: " << e.what() << std::endl; + return 1; + } +} diff --git a/tests/test-yaml-parity.sh b/tests/test-yaml-parity.sh new file mode 100755 index 0000000000000..abd1eadf254ef --- /dev/null +++ b/tests/test-yaml-parity.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +set -e + +LLAMA_CLI="./llama-cli" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(dirname "$SCRIPT_DIR")" +CONFIG_FILE="$REPO_ROOT/configs/minimal.yaml" +MODEL_PATH="$REPO_ROOT/models/stories15M-q4_0.gguf" + +if [ ! -f "$MODEL_PATH" ]; then + echo "Model file not found: $MODEL_PATH" + exit 1 +fi + +if [ ! -f "$CONFIG_FILE" ]; then + echo "Config file not found: $CONFIG_FILE" + exit 1 +fi + +echo "Running with YAML config..." +YAML_OUTPUT=$($LLAMA_CLI --config "$CONFIG_FILE" -no-cnv 2>/dev/null | tail -n +2) + +echo "Running with equivalent flags..." +FLAGS_OUTPUT=$($LLAMA_CLI -m "$MODEL_PATH" -n 16 -s 42 -c 256 --temp 0.0 -p "Hello from YAML" --simple-io -no-cnv 2>/dev/null | tail -n +2) + +if [ "$YAML_OUTPUT" = "$FLAGS_OUTPUT" ]; then + echo "PARITY TEST PASSED: YAML and flags produce identical output" + exit 0 +else + echo "PARITY TEST FAILED: Outputs differ" + echo "YAML output:" + echo "$YAML_OUTPUT" + echo "Flags output:" + echo "$FLAGS_OUTPUT" + exit 1 +fi