From ac404be2dc54bc7329c25d2742859a455b17c0e8 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen <son@huggingface.co>
Date: Thu, 28 Nov 2024 14:40:22 +0100
Subject: [PATCH 1/3] server : add split model test

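Download the first shard of a 3-way split stories15M q8_0 GGUF from
ggml-org/models via --hf-repo/--hf-file and verify that /completion
returns a sensible response.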
---
 examples/server/tests/unit/test_basic.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py
index 84db5ca1ca192..d82d54a5a6f47 100644
--- a/examples/server/tests/unit/test_basic.py
+++ b/examples/server/tests/unit/test_basic.py
@@ -32,3 +32,18 @@ def test_server_models():
     assert res.status_code == 200
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
+
+def test_load_split_model():
+    global server
+    server.model_hf_repo = "ggml-org/models"
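+    # only the first shard is referenced; the server loads the remaining splits automatically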
+    server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
+    server.model_alias = "tinyllama-split"
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 16,
+        "prompt": "Hello",
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    assert match_regex("(little|girl)+", res.body["content"])

From 8aaf69a3eead143856e41b7addf823b873a2185f Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen <son@huggingface.co>
Date: Thu, 28 Nov 2024 15:15:02 +0100
Subject: [PATCH 2/3] server : add speculative decoding tests

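Use the stories15M MoE preset as the main model and a q4_0 tinyllama,
downloaded on demand, as the draft model. The tests check that greedy
output is unchanged with and without a draft model and across several
--draft-min/--draft-max combinations, and that parallel requests work.
Also make ServerProcess.stop() safe to call more than once.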
---
 .../server/tests/unit/test_speculative.py     | 109 ++++++++++++++++++
 examples/server/tests/utils.py                |  13 +-
 2 files changed, 121 insertions(+), 1 deletion(-)
 create mode 100644 examples/server/tests/unit/test_speculative.py

diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py
new file mode 100644
index 0000000000000..982d6abb45f5f
--- /dev/null
+++ b/examples/server/tests/unit/test_speculative.py
@@ -0,0 +1,109 @@
+import pytest
+from utils import *
+
+# We use an F16 MoE GGUF as the main model and a q4_0 GGUF as the draft model
+
+server = ServerPreset.stories15m_moe()
+
+MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
+
+def create_server():
+    global server
+    server = ServerPreset.stories15m_moe()
+    # download draft model file if needed
+    file_name = MODEL_DRAFT_FILE_URL.split('/')[-1]
+    model_draft_file = f'../../../{file_name}'
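+    # cache the file at the repo root so repeated test runs can skip the download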
+    if not os.path.exists(model_draft_file):
+        print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
+        with open(model_draft_file, 'wb') as f:
+            f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
+        print(f"Done downloading draft model file")
+    # set default values
+    server.model_draft = model_draft_file
+    server.draft_min = 4
+    server.draft_max = 8
+
+
+@pytest.fixture(scope="module", autouse=True)
+def fixture_create_server():
+    return create_server()
+
+
+def test_with_and_without_draft():
+    global server
+    server.model_draft = None  # disable draft model
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "temperature": 0.0,
+        "top_k": 1,
+    })
+    assert res.status_code == 200
+    content_no_draft = res.body["content"]
+    server.stop()
+
+    # create new server with draft model
+    create_server()
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "temperature": 0.0,
+        "top_k": 1,
+    })
+    assert res.status_code == 200
+    content_draft = res.body["content"]
+
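+    # with greedy sampling, the draft model must not change the generated output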
+    assert content_no_draft == content_draft
+
+
+def test_different_draft_min_draft_max():
+    global server
+    test_values = [
+        (1, 2),
+        (1, 4),
+        (4, 8),
+        (4, 12),
+        (8, 16),
+    ]
+    last_content = None
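+    # greedy sampling: the output should not depend on the draft parameters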
+    for draft_min, draft_max in test_values:
+        server.stop()
+        server.draft_min = draft_min
+        server.draft_max = draft_max
+        server.start()
+        res = server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+            "temperature": 0.0,
+            "top_k": 1,
+        })
+        assert res.status_code == 200
+        if last_content is not None:
+            assert last_content == res.body["content"]
+        last_content = res.body["content"]
+
+
+@pytest.mark.parametrize("n_slots,n_requests", [
+    (1, 2),
+    (2, 2),
+])
+def test_multi_requests_parallel(n_slots: int, n_requests: int):
+    global server
+    server.n_slots = n_slots
+    server.start()
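+    # build identical completion requests and fire them concurrently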
+    tasks = []
+    for _ in range(n_requests):
+        tasks.append((server.make_request, ("POST", "/completion", {
+            "prompt": "I believe the meaning of life is",
+            "temperature": 0.0,
+            "top_k": 1,
+        })))
+    results = parallel_function_calls(tasks)
+    for res in results:
+        assert res.status_code == 200
+        assert match_regex("(wise|kind|owl|answer)+", res.body["content"])
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index e31743c505d8e..b48ec3e78e1d8 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -47,6 +47,7 @@ class ServerProcess:
     model_alias: str | None = None
     model_url: str | None = None
     model_file: str | None = None
+    model_draft: str | None = None
     n_threads: int | None = None
     n_gpu_layer: int | None = None
     n_batch: int | None = None
@@ -69,6 +70,8 @@ class ServerProcess:
     response_format: str | None = None
     lora_files: List[str] | None = None
     disable_ctx_shift: int | None = False
+    draft_min: int | None = None
+    draft_max: int | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -103,6 +106,8 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.extend(["--model", self.model_file])
         if self.model_url:
             server_args.extend(["--model-url", self.model_url])
+        if self.model_draft:
+            server_args.extend(["--model-draft", self.model_draft])
         if self.model_hf_repo:
             server_args.extend(["--hf-repo", self.model_hf_repo])
         if self.model_hf_file:
@@ -148,6 +153,10 @@ def start(self, timeout_seconds: int = 10) -> None:
             server_args.extend(["--no-context-shift"])
         if self.api_key:
             server_args.extend(["--api-key", self.api_key])
+        if self.draft_max:
+            server_args.extend(["--draft-max", self.draft_max])
+        if self.draft_min:
+            server_args.extend(["--draft-min", self.draft_min])
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
@@ -200,7 +209,9 @@ def server_log(in_stream, out_stream):
         raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
 
     def stop(self) -> None:
-        server_instances.remove(self)
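+        # stop() may be called more than once per test; make the removal idempotent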
+        if self in server_instances:
+            server_instances.remove(self)
         if self.process:
             print(f"Stopping server with pid={self.process.pid}")
             self.process.kill()

From 879c5ebd25d486c1810796364904436bb194c627 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen <son@huggingface.co>
Date: Thu, 28 Nov 2024 17:07:51 +0100
Subject: [PATCH 3/3] server : add tests for invalid request payloads

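Send malformed request bodies to /chat/completions, /infill and
/rerank and assert that the server responds with an error.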
---
 .../server/tests/unit/test_chat_completion.py | 21 ++++++++++++++++++
 examples/server/tests/unit/test_infill.py     | 22 +++++++++++++++++++
 examples/server/tests/unit/test_rerank.py     | 17 ++++++++++++++
 3 files changed, 60 insertions(+)

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index d7aeb288d45cc..1048d6fcaf500 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -127,3 +127,24 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
         assert res.status_code != 200
         assert "error" in res.body
 
+
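+# malformed "messages" payloads; each one must be rejected by the server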
+@pytest.mark.parametrize("messages", [
+    None,
+    "string",
+    [123],
+    [{}],
+    [{"role": 123}],
+    [{"role": "system", "content": 123}],
+    # [{"content": "hello"}], # TODO: should not be a valid case
+    [{"role": "system", "content": "test"}, {}],
+])
+def test_invalid_chat_completion_req(messages):
+    global server
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "messages": messages,
+    })
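+    # rejection may happen during request validation (400) or when applying the chat template (500)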
+    assert res.status_code in (400, 500)
+    assert "error" in res.body
diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py
index 38ce6c42954ed..6a6d40a1cbc8b 100644
--- a/examples/server/tests/unit/test_infill.py
+++ b/examples/server/tests/unit/test_infill.py
@@ -8,6 +8,7 @@ def create_server():
     global server
     server = ServerPreset.tinyllama_infill()
 
+
 def test_infill_without_input_extra():
     global server
     server.start()
@@ -19,6 +20,7 @@ def test_infill_without_input_extra():
     assert res.status_code == 200
     assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
 
+
 def test_infill_with_input_extra():
     global server
     server.start()
@@ -33,3 +35,23 @@ def test_infill_with_input_extra():
     })
     assert res.status_code == 200
     assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+
+
+@pytest.mark.parametrize("input_extra", [
+    {},
+    {"filename": "ok"},
+    {"filename": 123},
+    {"filename": 123, "text": "abc"},
+    {"filename": 123, "text": 456},
+])
+def test_invalid_input_extra_req(input_extra):
+    global server
+    server.start()
+    res = server.make_request("POST", "/infill", data={
+        "prompt": "Complete this",
+        "input_extra": [input_extra],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 400
+    assert "error" in res.body
diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py
index 3a49fd3ac6bdf..189bc4c962329 100644
--- a/examples/server/tests/unit/test_rerank.py
+++ b/examples/server/tests/unit/test_rerank.py
@@ -36,3 +36,20 @@ def test_rerank():
     assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
     assert most_relevant["index"] == 2
     assert least_relevant["index"] == 3
+
+
+@pytest.mark.parametrize("documents", [
+    [],
+    None,
+    123,
+    [1, 2, 3],
+])
+def test_invalid_rerank_req(documents):
+    global server
+    server.start()
+    res = server.make_request("POST", "/rerank", data={
+        "query": "Machine learning is",
+        "documents": documents,
+    })
+    assert res.status_code == 400
+    assert "error" in res.body