|
| 1 | +import os, json, subprocess, hashlib |
| 2 | +from pathlib import Path |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | + |
| 6 | +# --------------------------------------------------------------------------- |
| 7 | +# Configuration constants |
| 8 | +# --------------------------------------------------------------------------- |
| 9 | + |
# Absolute tolerance for float comparisons.
# NOTE(review): EPS is defined but not referenced anywhere in this module —
# confirm whether it is still needed.
EPS = 1e-3
# Repository root: this file is assumed to sit two directory levels below it
# (e.g. <root>/tools/tests/this_file.py) — TODO confirm layout.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Path to the llama-embedding CLI binary; Windows builds carry a ".exe" suffix.
EXE = REPO_ROOT / ("build/bin/llama-embedding.exe" if os.name == "nt" else "build/bin/llama-embedding")
# Environment passed to every subprocess call; LLAMA_CACHE falls back to a
# local "tmp" directory so downloaded models are cached between test runs.
DEFAULT_ENV = {**os.environ, "LLAMA_CACHE": os.environ.get("LLAMA_CACHE", "tmp")}
# Fixed RNG seed, kept as a string because it is passed straight to the CLI.
SEED = "42"
| 15 | + |
| 16 | + |
| 17 | +# --------------------------------------------------------------------------- |
| 18 | +# Model setup helpers |
| 19 | +# --------------------------------------------------------------------------- |
| 20 | + |
def get_model_hf_params():
    """Return the Hugging Face repo/file pair for the default test model.

    The model is a small quantized embedding model, chosen to keep
    download size and inference time low in CI.
    """
    repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF"
    filename = "embeddinggemma-300M-qat-Q4_0.gguf"
    return {"hf_repo": repo, "hf_file": filename}
| 27 | + |
| 28 | + |
@pytest.fixture(scope="session")
def embedding_model():
    """Download and cache the test model once per session.

    Resolves the llama-embedding binary (falling back to the Windows
    multi-config Release layout), runs it once on a tiny input so the
    model is pulled into LLAMA_CACHE, and returns the HF parameters.
    """
    exe_path = EXE
    if not exe_path.exists():
        release_exe = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if not release_exe.exists():
            raise FileNotFoundError(f"llama-embedding binary not found under {REPO_ROOT}/build/bin")
        exe_path = release_exe

    params = get_model_hf_params()
    warmup_cmd = [
        str(exe_path),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "16",
        "--embd-output-format", "json",
        "--no-warmup",
        "--threads", "1",
        "--seed", SEED,
    ]
    res = subprocess.run(warmup_cmd, input="ok", capture_output=True, text=True, env=DEFAULT_ENV)
    assert res.returncode == 0, f"model download failed: {res.stderr}"
    return params
| 54 | + |
| 55 | + |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | +# Utility functions |
| 58 | +# --------------------------------------------------------------------------- |
| 59 | + |
def run_embedding(text: str, fmt: str = "raw", params=None) -> str:
    """Run llama-embedding on *text* and return its stripped stdout.

    Parameters:
        text: input fed to the binary on stdin.
        fmt: value for --embd-output-format ("raw" or "json").
        params: HF repo/file dict; defaults to get_model_hf_params().

    Raises:
        FileNotFoundError: if no llama-embedding binary can be located.
        AssertionError: if the binary exits non-zero or prints nothing.
    """
    exe_path = EXE
    if not exe_path.exists():
        # Mirror the session fixture's fallback for Windows multi-config
        # builds, which place binaries under build/bin/Release — without
        # this, every test failed there even though the fixture succeeded.
        alt = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if not alt.exists():
            raise FileNotFoundError(f"Missing binary: {exe_path}")
        exe_path = alt
    params = params or get_model_hf_params()
    cmd = [
        str(exe_path),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "2048",
        "--embd-output-format", fmt,
        "--threads", "1",
        "--seed", SEED,
    ]
    result = subprocess.run(cmd, input=text, capture_output=True, text=True, env=DEFAULT_ENV)
    if result.returncode:
        raise AssertionError(f"embedding failed ({result.returncode}):\n{result.stderr[:400]}")
    out = result.stdout.strip()
    assert out, f"empty output for text={text!r}, fmt={fmt}"
    return out
| 81 | + |
| 82 | + |
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between *a* and *b* (1.0 = identical direction)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom)
| 85 | + |
| 86 | + |
def embedding_hash(vec: np.ndarray) -> str:
    """Return a short deterministic signature for regression tracking.

    Hashes the raw bytes of the first 8 components with SHA-256 and
    keeps the first 16 hex characters — compact enough for assertion
    messages while still being collision-resistant in practice.
    """
    prefix_bytes = vec[:8].tobytes()
    digest = hashlib.sha256(prefix_bytes)
    return digest.hexdigest()[:16]
| 90 | + |
| 91 | + |
| 92 | +# --------------------------------------------------------------------------- |
| 93 | +# Tests |
| 94 | +# --------------------------------------------------------------------------- |
| 95 | + |
# Suppress PytestUnknownMarkWarning module-wide so the custom @pytest.mark.slow
# marker used below does not emit warnings when it is not registered in the
# project's pytest configuration.
# NOTE(review): this filter hides ALL unknown-mark warnings in this module;
# registering "slow" under the project's pytest `markers` setting would be
# the cleaner fix — confirm with the repo's pytest.ini/pyproject.toml.
pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnknownMarkWarning")
| 98 | + |
@pytest.mark.slow
@pytest.mark.parametrize("fmt", ["raw", "json"])
@pytest.mark.parametrize("text", ["hello world", "hi 🌎", "line1\nline2\nline3"])
def test_embedding_runs_and_finite(fmt, text, embedding_model):
    """Ensure embeddings run end-to-end and produce finite floats."""
    out = run_embedding(text, fmt, embedding_model)
    if fmt == "raw":
        floats = np.array(out.split(), float)
    else:
        floats = np.array(json.loads(out)["data"][0]["embedding"], float)
    assert len(floats) > 100
    assert np.all(np.isfinite(floats)), f"non-finite values in {fmt} output"
    norm = np.linalg.norm(floats)
    assert 0.1 < norm < 10
| 113 | + |
| 114 | + |
def test_raw_vs_json_consistency(embedding_model):
    """Compare raw vs JSON embedding output for same text."""
    text = "hello world"
    raw_out = run_embedding(text, "raw", embedding_model)
    json_out = run_embedding(text, "json", embedding_model)
    raw = np.array(raw_out.split(), float)
    jsn = np.array(json.loads(json_out)["data"][0]["embedding"], float)

    assert raw.shape == jsn.shape
    cos = cosine_similarity(raw, jsn)
    assert cos > 0.999, f"divergence: cos={cos:.4f}"
    assert embedding_hash(raw) == embedding_hash(jsn), "hash mismatch → possible nondeterminism"
| 125 | + |
| 126 | + |
def test_empty_input_deterministic(embedding_model):
    """Empty input should yield finite, deterministic vector."""
    runs = [
        np.array(run_embedding("", "raw", embedding_model).split(), float)
        for _ in range(2)
    ]
    v1, v2 = runs
    assert np.all(np.isfinite(v1))
    cos = cosine_similarity(v1, v2)
    assert cos > 0.9999, f"Empty input not deterministic (cos={cos:.5f})"
    assert 0.1 < np.linalg.norm(v1) < 10
| 135 | + |
| 136 | + |
@pytest.mark.slow
def test_very_long_input_stress(embedding_model):
    """Stress test: large input near context window."""
    text = "lorem " * 2000
    out = run_embedding(text, "raw", embedding_model)
    vec = np.array(out.split(), float)
    assert len(vec) > 100
    assert np.isfinite(np.linalg.norm(vec))
| 144 | + |
| 145 | + |
@pytest.mark.parametrize(
    "text",
    [" ", "\n\n\n", "123 456 789"],
)
def test_low_information_inputs_stable(text, embedding_model):
    """Whitespace/numeric inputs should yield stable embeddings."""
    def embed():
        return np.array(run_embedding(text, "raw", embedding_model).split(), float)

    first, second = embed(), embed()
    cos = cosine_similarity(first, second)
    assert cos > 0.999, f"unstable embedding for {text!r}"
| 156 | + |
| 157 | + |
@pytest.mark.parametrize("flag", ["--no-such-flag", "--help"])
def test_invalid_or_help_flag(flag):
    """Invalid flags should fail; help should succeed.

    Resolves the binary the same way the session fixture does: this test
    does not use the `embedding_model` fixture, so without the fallback it
    crashed with FileNotFoundError on Windows multi-config builds (and on
    any machine where nothing was built) instead of skipping cleanly.
    """
    exe_path = EXE
    if not exe_path.exists():
        alt = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if not alt.exists():
            pytest.skip(f"llama-embedding binary not found under {REPO_ROOT}/build/bin")
        exe_path = alt

    # Pass DEFAULT_ENV for consistency with every other invocation in this
    # module (keeps LLAMA_CACHE pointing at the shared cache directory).
    res = subprocess.run([str(exe_path), flag], capture_output=True, text=True, env=DEFAULT_ENV)
    if flag == "--no-such-flag":
        assert res.returncode != 0
        # Accept any common error wording so we aren't coupled to one message.
        assert any(k in res.stderr.lower() for k in ("error", "invalid", "unknown"))
    else:
        assert res.returncode == 0
        # Some CLIs print usage to stdout, others to stderr — check both.
        assert "usage" in (res.stdout.lower() + res.stderr.lower())
| 168 | + |
| 169 | + |
@pytest.mark.parametrize("fmt", ["raw", "json"])
@pytest.mark.parametrize("text", ["deterministic test", "deterministic test again"])
def test_repeated_call_consistent(fmt, text, embedding_model):
    """Same input → same hash across repeated runs."""
    def parse(output):
        # Both formats reduce to a flat float vector.
        if fmt == "json":
            return np.array(json.loads(output)["data"][0]["embedding"], float)
        return np.array(output.split(), float)

    v1 = parse(run_embedding(text, fmt, embedding_model))
    v2 = parse(run_embedding(text, fmt, embedding_model))

    assert embedding_hash(v1) == embedding_hash(v2)
0 commit comments