Skip to content

Commit 075c324

Browse files
committed
Add e2e tests for embedding raw flag
1 parent 8284efc commit 075c324

File tree

3 files changed

+243
-0
lines changed

3 files changed

+243
-0
lines changed

.github/workflows/embeddings.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Embedding CLI build and tests
#
# Builds the llama-embedding target and runs the pytest suite under
# examples/tests. Triggered manually, on pushes to master / feature branches,
# and on pull requests, but only when the workflow itself, the embedding
# example, or the tests change.
name: Embedding CLI

on:
  workflow_dispatch:
  push:
    branches:
      - feature/*
      - master
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'
  pull_request:
    types: [opened, synchronize, reopened]
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'

jobs:
  embedding-cli-tests:
    runs-on: ubuntu-latest

    steps:
      - name: Install system deps
        run: |
          sudo apt-get update
          sudo apt-get -y install \
            build-essential \
            cmake \
            curl \
            libcurl4-openssl-dev \
            python3-pip

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install Python deps
        run: |
          # Best effort: the requirements file is optional for these tests.
          pip install -r requirements.txt || echo "No extra requirements found"
          # The test module imports numpy directly, so install it explicitly
          # instead of relying on the best-effort step above having succeeded.
          pip install pytest numpy

      - name: Build llama-embedding
        run: |
          cmake -B build \
            -DCMAKE_BUILD_TYPE=Release
          cmake --build build --target llama-embedding -j $(nproc)

      - name: Run embedding tests
        run: |
          pytest -v examples/tests

examples/tests/__init__.py

Whitespace-only changes.

examples/tests/test_embedding.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import os, json, subprocess, hashlib
2+
from pathlib import Path
3+
import numpy as np
4+
import pytest
5+
6+
# ---------------------------------------------------------------------------
7+
# Configuration constants
8+
# ---------------------------------------------------------------------------
9+
10+
# Absolute tolerance available to numeric comparisons in this module.
EPS = 1e-3

# This file lives two directories below the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]

# Location of the built CLI; Windows builds carry an ".exe" suffix.
EXE = REPO_ROOT / "build" / "bin" / ("llama-embedding.exe" if os.name == "nt" else "llama-embedding")

# Environment handed to every subprocess: a copy of the current environment
# with LLAMA_CACHE defaulting to "tmp" when the caller has not set it.
DEFAULT_ENV = dict(os.environ)
DEFAULT_ENV.setdefault("LLAMA_CACHE", "tmp")

# Passed verbatim to the CLI's --seed option (hence a string).
SEED = "42"
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# Model setup helpers
19+
# ---------------------------------------------------------------------------
20+
21+
def get_model_hf_params():
    """Return the Hugging Face repo/file pair of the default small embedding model.

    A fresh dict is built on every call, with the keys ``hf_repo`` and
    ``hf_file``.
    """
    model = dict(
        hf_repo="ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
        hf_file="embeddinggemma-300M-qat-Q4_0.gguf",
    )
    return model
27+
28+
29+
@pytest.fixture(scope="session")
def embedding_model():
    """Session fixture: run the CLI once so the model is fetched/cached up front.

    Returns the ``hf_repo`` / ``hf_file`` mapping used for the run.

    Raises:
        FileNotFoundError: if neither ``EXE`` nor the ``build/bin/Release``
            executable exists.
    """
    binary = EXE
    if not binary.exists():
        # Second location probed when the default path is absent.
        fallback = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if not fallback.exists():
            raise FileNotFoundError(f"llama-embedding binary not found under {REPO_ROOT}/build/bin")
        binary = fallback

    hf = get_model_hf_params()
    argv = [str(binary)]
    argv += ["-hfr", hf["hf_repo"]]
    argv += ["-hff", hf["hf_file"]]
    argv += ["--ctx-size", "16"]
    argv += ["--embd-output-format", "json"]
    argv += ["--no-warmup"]
    argv += ["--threads", "1"]
    argv += ["--seed", SEED]

    # A tiny prompt on stdin is enough to drive one full invocation.
    proc = subprocess.run(argv, input="ok", capture_output=True, text=True, env=DEFAULT_ENV)
    assert proc.returncode == 0, f"model download failed: {proc.stderr}"
    return hf
54+
55+
56+
# ---------------------------------------------------------------------------
57+
# Utility functions
58+
# ---------------------------------------------------------------------------
59+
60+
def run_embedding(text: str, fmt: str = "raw", params=None) -> str:
    """Run llama-embedding on ``text`` and return its stripped stdout.

    Args:
        text: Prompt written to the process's stdin.
        fmt: Value passed to ``--embd-output-format``.
        params: Mapping with ``hf_repo`` and ``hf_file`` keys. When omitted
            (or empty) the result of ``get_model_hf_params()`` is used.

    Returns:
        The process's stdout with leading/trailing whitespace removed.

    Raises:
        FileNotFoundError: if the binary exists neither at ``EXE`` nor at the
            ``build/bin/Release`` fallback location.
        AssertionError: if the process exits non-zero or prints nothing.
    """
    exe_path = EXE
    if not exe_path.exists():
        # Accept the same fallback location the embedding_model fixture
        # accepts; otherwise the fixture can succeed while every call here
        # fails with a missing binary.
        alt = REPO_ROOT / "build/bin/Release/llama-embedding.exe"
        if not alt.exists():
            raise FileNotFoundError(f"Missing binary: {exe_path}")
        exe_path = alt

    params = params or get_model_hf_params()
    cmd = [
        str(exe_path),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "2048",
        "--embd-output-format", fmt,
        "--threads", "1",
        "--seed", SEED,
    ]
    # Pin the pipe encoding to UTF-8 so non-ASCII prompts (the tests send an
    # emoji) do not depend on the platform's locale encoding.
    result = subprocess.run(
        cmd,
        input=text,
        capture_output=True,
        text=True,
        encoding="utf-8",
        env=DEFAULT_ENV,
    )
    if result.returncode:
        # Only the head of stderr is reported to keep failure output readable.
        raise AssertionError(f"embedding failed ({result.returncode}):\n{result.stderr[:400]}")
    out = result.stdout.strip()
    assert out, f"empty output for text={text!r}, fmt={fmt}"
    return out
81+
82+
83+
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b`` as a Python float."""
    inner = np.dot(a, b)
    scale = np.linalg.norm(a) * np.linalg.norm(b)
    return float(inner / scale)
85+
86+
87+
def embedding_hash(vec: np.ndarray) -> str:
    """Return a 16-hex-character SHA-256 fingerprint of the first eight entries of ``vec``.

    The digest is computed over the raw bytes of ``vec[:8]``, so it is
    sensitive to the array's dtype and to any bit-level difference.
    """
    head = vec[:8]
    digest = hashlib.sha256()
    digest.update(head.tobytes())
    return digest.hexdigest()[:16]
90+
91+
92+
# ---------------------------------------------------------------------------
93+
# Tests
94+
# ---------------------------------------------------------------------------
95+
96+
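# Filter PytestUnknownMarkWarning for the custom `slow` mark. NOTE: this adds a warning filter only; it does not register the mark (registration goes through pytest's `markers` configuration option).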
97+
pytestmark = pytest.mark.filterwarnings("ignore::pytest.PytestUnknownMarkWarning")  # module-level `pytestmark`: applied to every test in this file
98+
99+
@pytest.mark.slow
@pytest.mark.parametrize("fmt", ["raw", "json"])
@pytest.mark.parametrize("text", ["hello world", "hi 🌎", "line1\nline2\nline3"])
def test_embedding_runs_and_finite(fmt, text, embedding_model):
    """End-to-end smoke test: output parses to a reasonably sized, finite vector."""
    stdout = run_embedding(text, fmt, embedding_model)

    # "raw" is whitespace-separated numbers; "json" nests the vector under
    # data[0].embedding.
    if fmt == "raw":
        values = np.array(stdout.split(), float)
    else:
        values = np.array(json.loads(stdout)["data"][0]["embedding"], float)

    assert len(values) > 100
    assert np.all(np.isfinite(values)), f"non-finite values in {fmt} output"
    magnitude = np.linalg.norm(values)
    assert 0.1 < magnitude < 10
113+
114+
115+
def test_raw_vs_json_consistency(embedding_model):
    """Raw and JSON output formats must describe the same embedding for one text.

    The two formats are separate text serialisations, so they are compared
    numerically (same shape, near-identical direction, element-wise within
    ``EPS``) rather than bit-for-bit.
    """
    text = "hello world"
    raw = np.array(run_embedding(text, "raw", embedding_model).split(), float)
    payload = json.loads(run_embedding(text, "json", embedding_model))
    jsn = np.array(payload["data"][0]["embedding"], float)

    assert raw.shape == jsn.shape
    cos = cosine_similarity(raw, jsn)
    assert cos > 0.999, f"divergence: cos={cos:.4f}"

    # Element-wise tolerance over the whole vector. A byte-level hash of the
    # first few values would fail on any difference in printed precision and
    # would ignore the rest of the vector.
    max_diff = float(np.max(np.abs(raw - jsn)))
    assert max_diff <= EPS, f"raw/json mismatch: max abs diff={max_diff:.3e} exceeds {EPS}"
125+
126+
127+
def test_empty_input_deterministic(embedding_model):
    """Two runs on an empty prompt must give a finite, matching, sanely scaled vector."""
    # Both invocations happen before any assertion is evaluated.
    first, second = [
        np.array(run_embedding("", "raw", embedding_model).split(), float)
        for _ in range(2)
    ]

    assert np.all(np.isfinite(first))
    cos = cosine_similarity(first, second)
    assert cos > 0.9999, f"Empty input not deterministic (cos={cos:.5f})"
    magnitude = np.linalg.norm(first)
    assert 0.1 < magnitude < 10
135+
136+
137+
@pytest.mark.slow
def test_very_long_input_stress(embedding_model):
    """Stress case: a 2000-word prompt must still yield a finite embedding."""
    prompt = "lorem " * 2000
    stdout = run_embedding(prompt, "raw", embedding_model)
    vec = np.array(stdout.split(), float)

    assert len(vec) > 100
    magnitude = np.linalg.norm(vec)
    assert np.isfinite(magnitude)
144+
145+
146+
@pytest.mark.parametrize(
    "text",
    [" ", "\n\n\n", "123 456 789"],
)
def test_low_information_inputs_stable(text, embedding_model):
    """Whitespace-only and digit-only prompts must embed the same way on repeat runs."""
    runs = [
        np.array(run_embedding(text, "raw", embedding_model).split(), float)
        for _ in range(2)
    ]
    first, second = runs
    cos = cosine_similarity(first, second)
    assert cos > 0.999, f"unstable embedding for {text!r}"
156+
157+
158+
@pytest.mark.parametrize("flag", ["--no-such-flag", "--help"])
def test_invalid_or_help_flag(flag):
    """An unknown flag must be rejected with a diagnostic; ``--help`` must print usage."""
    proc = subprocess.run([str(EXE), flag], capture_output=True, text=True)

    if flag != "--no-such-flag":
        # --help: success, with "usage" somewhere in either stream.
        assert proc.returncode == 0
        combined = proc.stdout.lower() + proc.stderr.lower()
        assert "usage" in combined
        return

    # Unknown flag: non-zero exit and an error-like word on stderr.
    assert proc.returncode != 0
    assert any(k in proc.stderr.lower() for k in ("error", "invalid", "unknown"))
168+
169+
170+
@pytest.mark.parametrize("fmt", ["raw", "json"])
@pytest.mark.parametrize("text", ["deterministic test", "deterministic test again"])
def test_repeated_call_consistent(fmt, text, embedding_model):
    """Two identical invocations must produce the same embedding fingerprint."""
    outputs = [run_embedding(text, fmt, embedding_model) for _ in range(2)]

    if fmt == "json":
        vectors = [
            np.array(json.loads(stdout)["data"][0]["embedding"], float)
            for stdout in outputs
        ]
    else:
        vectors = [np.array(stdout.split(), float) for stdout in outputs]

    first, second = vectors
    assert embedding_hash(first) == embedding_hash(second)

0 commit comments

Comments
 (0)