Skip to content

Commit 015351e

Browse files
committed
Add e2e tests for embedding raw flag
1 parent 8284efc commit 015351e

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed

.github/workflows/embeddings.yml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Embedding CLI build and tests
#
# Builds the llama-embedding binary and runs the pytest suite under
# examples/tests against it.
name: Embedding CLI

on:
  workflow_dispatch:
  push:
    branches:
      - feature/*
      - master
    # Only run when the workflow itself, the embedding example, or its tests change.
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'
  pull_request:
    types: [opened, synchronize, reopened]
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'

jobs:
  embedding-cli-tests:
    runs-on: ubuntu-latest

    steps:
      - name: Install system deps
        run: |
          sudo apt-get update
          sudo apt-get -y install \
            build-essential \
            cmake \
            curl \
            libcurl4-openssl-dev \
            python3-pip

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install Python deps
        run: |
          pip install -r requirements.txt || echo "No extra requirements found"
          # The requirements install above is best-effort, so install what the
          # tests import directly (pytest and numpy) unconditionally.
          pip install pytest numpy

      - name: Build llama-embedding
        run: |
          cmake -B build \
            -DCMAKE_BUILD_TYPE=Release
          cmake --build build --target llama-embedding -j $(nproc)

      - name: Run embedding tests
        run: |
          pytest -v examples/tests

examples/tests/__init__.py

Whitespace-only changes.

examples/tests/test_embedding.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import json, os, subprocess
2+
from pathlib import Path
3+
import numpy as np
4+
5+
6+
# ---------------------------------------------------------------------------
7+
# Model helpers
8+
# ---------------------------------------------------------------------------
9+
10+
def get_model_hf_params():
    """Return the Hugging Face repo and file name of the default embedding model.

    The default is a lightweight model; a fresh dict with the keys
    ``hf_repo`` and ``hf_file`` is built on every call.
    """
    return dict(
        hf_repo="ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
        hf_file="embeddinggemma-300M-qat-Q4_0.gguf",
    )
16+
17+
18+
def ensure_model_downloaded(params=None):
    """Run llama-embedding once so the model ends up cached locally.

    Args:
        params: Optional dict with ``hf_repo`` and ``hf_file`` keys. When
            falsy, the defaults from ``get_model_hf_params()`` are used.

    Returns:
        The params dict that was actually used.

    Raises:
        FileNotFoundError: If the llama-embedding binary cannot be located.
        RuntimeError: If the binary exits with a non-zero status.
    """
    root = Path(__file__).resolve().parents[2]
    binary = root / "build/bin/llama-embedding"
    # On Windows the binary may live in a Release/ subdirectory with an .exe suffix.
    if os.name == "nt" and not binary.exists():
        binary = root / "build/bin/Release/llama-embedding.exe"
    if not binary.exists():
        raise FileNotFoundError(f"llama-embedding not found at {binary}")

    if not params:
        params = get_model_hf_params()

    # Inherit the current environment, defaulting LLAMA_CACHE to "tmp" when unset.
    child_env = dict(os.environ)
    child_env["LLAMA_CACHE"] = os.environ.get("LLAMA_CACHE", "tmp")

    model_args = ["-hfr", params["hf_repo"], "-hff", params["hf_file"]]
    # Small context, no warmup, single thread: this run exists only to fetch the model.
    run_args = [
        "--ctx-size", "16",
        "--embd-output-format", "json",
        "--no-warmup",
        "--threads", "1",
    ]
    completed = subprocess.run(
        [str(binary)] + model_args + run_args,
        input="ok",
        capture_output=True,
        text=True,
        env=child_env,
    )
    if completed.returncode != 0:
        raise RuntimeError(f"Model download failed:\n{completed.stderr}")
    return params
43+
44+
45+
def run_embedding(text, fmt="raw", params=None):
    """Run the llama-embedding CLI on *text* and return its stripped stdout.

    Args:
        text: Prompt passed to the binary on stdin.
        fmt: Value forwarded to ``--embd-output-format``.
        params: Optional dict with ``hf_repo`` and ``hf_file`` keys; handed to
            ``ensure_model_downloaded``, which falls back to the default model
            when it is falsy.

    Returns:
        The process's stdout with leading/trailing whitespace removed.

    Raises:
        AssertionError: If the binary is missing or exits with a non-zero status.
    """
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    # Mirror the Windows lookup used by ensure_model_downloaded so both helpers
    # resolve the same binary.
    if os.name == "nt" and not exe.exists():
        exe = repo_root / "build/bin/Release/llama-embedding.exe"
    if not exe.exists():
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError(f"Missing binary: {exe}")

    params = ensure_model_downloaded(params)
    # Inherit the current environment, defaulting LLAMA_CACHE to "tmp" when unset.
    env = {**os.environ, "LLAMA_CACHE": os.environ.get("LLAMA_CACHE", "tmp")}
    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "2048",
        "--embd-output-format", fmt,
    ]
    out = subprocess.run(cmd, input=text, capture_output=True, text=True, env=env)
    if out.returncode:
        raise AssertionError(f"embedding failed ({out.returncode}):\n{out.stderr}")
    return out.stdout.strip()
64+
65+
66+
# ---------------------------------------------------------------------------
67+
# Tests
68+
# ---------------------------------------------------------------------------
69+
70+
def test_embedding_raw_and_json_consistency():
    """Raw and JSON output for the same prompt must yield matching vectors."""
    prompt = "hello world"

    raw_text = run_embedding(prompt, "raw")
    vec_raw = np.array(raw_text.split(), float)

    json_text = run_embedding(prompt, "json")
    payload = json.loads(json_text)
    vec_json = np.array(payload["data"][0]["embedding"])

    assert len(vec_raw) == len(vec_json)

    # Cosine similarity between the two parsed vectors.
    norm_product = np.linalg.norm(vec_raw) * np.linalg.norm(vec_json)
    cos = np.dot(vec_raw, vec_json) / norm_product
    assert cos > 0.999, f"Unexpected divergence between raw and JSON output ({cos:.4f})"
79+
80+
81+
def test_embedding_empty_input():
    """Two runs on empty input must give finite, near-identical embeddings."""
    first = np.array(run_embedding("", "raw").split(), float)
    second = np.array(run_embedding("", "raw").split(), float)

    # The vector must be non-empty and free of NaN/inf values.
    assert len(first)
    assert np.all(np.isfinite(first))

    first_norm = np.linalg.norm(first)
    assert first_norm > 0.1 and first_norm < 10

    # Cosine similarity between the two runs.
    cos = np.dot(first, second) / (first_norm * np.linalg.norm(second))
    assert cos > 0.9999, f"Empty input not deterministic (cos={cos:.4f})"
91+
92+
93+
def test_embedding_very_long_input():
    """Stress test: embed a prompt made of 2000 repetitions of one word."""
    prompt = "".join(["lorem "] * 2000)
    raw_text = run_embedding(prompt, "raw")
    vector = np.array(raw_text.split(), float)

    assert len(vector) > 100
    assert np.isfinite(np.linalg.norm(vector))
98+
99+
100+
def test_embedding_output_shape():
    """Sanity check: the embedding has more than 100 values and a norm in (0.5, 2.0)."""
    raw_text = run_embedding("hello world", "raw")
    vector = np.array(raw_text.split(), float)

    assert len(vector) > 100
    magnitude = np.linalg.norm(vector)
    assert magnitude > 0.5 and magnitude < 2.0
104+
105+
106+
def test_embedding_invalid_flag():
    """Invalid flag should produce non-zero exit and error output."""
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    # Mirror the Windows lookup used by ensure_model_downloaded so this test
    # targets the same binary as the other helpers.
    if os.name == "nt" and not exe.exists():
        exe = repo_root / "build/bin/Release/llama-embedding.exe"

    result = subprocess.run([str(exe), "--no-such-flag"], capture_output=True, text=True)
    assert result.returncode != 0, "unknown flag was accepted"

    # Lower-case once; any one of the keywords counts as an error report.
    stderr_text = result.stderr.lower()
    assert any(k in stderr_text for k in ("error", "invalid", "unknown")), (
        f"no error keyword in stderr:\n{result.stderr}"
    )

0 commit comments

Comments
 (0)