Skip to content

Commit

Permalink
TST: Fix windows CI (#455)
Browse files Browse the repository at this point in the history
  • Loading branch information
aresnow1 authored Sep 20, 2023
1 parent f227c93 commit acc1abb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
17 changes: 0 additions & 17 deletions xinference/model/llm/ggml/tests/test_ctransformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import random
import string
from concurrent.futures import ThreadPoolExecutor

import pytest

Expand Down Expand Up @@ -138,22 +137,6 @@ async def test_ctransformers_generate(setup):
model = client.get_model(model_uid=model_uid)
assert isinstance(model, GenerateModelHandle)

# Test concurrent generate is OK.
def _check():
    # One non-streaming generate call; sanity-check the completion payload.
    result = model.generate("AI is going to", generate_config={"max_tokens": 5})
    print(result)
    assert "id" in result
    first_choice = result["choices"][0]
    assert "text" in first_choice
    assert len(first_choice["text"]) > 0

results = []
with ThreadPoolExecutor() as executor:
for _ in range(3):
r = executor.submit(_check)
results.append(r)
for r in results:
r.result()

completion = model.generate("AI is going to", generate_config={"max_tokens": 5})
print(completion)
assert "id" in completion
Expand Down
38 changes: 38 additions & 0 deletions xinference/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from concurrent.futures import ThreadPoolExecutor

import pytest

Expand All @@ -26,6 +27,7 @@
)


@pytest.mark.skipif(os.name == "nt", reason="Skip windows")
def test_client(setup):
endpoint, _ = setup
client = Client(endpoint)
Expand Down Expand Up @@ -91,6 +93,7 @@ def test_client_for_embedding(setup):
assert len(client.list_models()) == 0


@pytest.mark.skipif(os.name == "nt", reason="Skip windows")
def test_replica_model(setup):
endpoint, _ = setup
client = Client(endpoint)
Expand Down Expand Up @@ -185,6 +188,7 @@ def test_client_custom_model(setup):
assert custom_model_reg is None


@pytest.mark.skipif(os.name == "nt", reason="Skip windows")
def test_RESTful_client(setup):
endpoint, _ = setup
client = RESTfulClient(endpoint)
Expand Down Expand Up @@ -237,6 +241,40 @@ def test_RESTful_client(setup):
client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0

model_uid = client.launch_model(
model_name="tiny-llama",
model_size_in_billions=1,
model_format="ggufv2",
quantization="Q2_K",
)
assert len(client.list_models()) == 1

# Test concurrent chat is OK.
def _check(stream=False):
    """Fetch the model handle and verify one generate call, streamed or not."""
    handle = client.get_model(model_uid=model_uid)
    result = handle.generate(
        "AI is going to", generate_config={"stream": stream, "max_tokens": 5}
    )
    if not stream:
        # Non-streaming: a single completion dict comes back.
        choice = result["choices"][0]
        assert "text" in choice
        assert len(choice["text"]) > 0
        return
    # Streaming: iterate the chunk generator and check every piece.
    for chunk in result:
        choice = chunk["choices"][0]
        assert "text" in choice
        assert len(choice["text"]) > 0

for stream in [True, False]:
results = []
with ThreadPoolExecutor() as executor:
for _ in range(3):
r = executor.submit(_check, stream=stream)
results.append(r)
for r in results:
r.result()

client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0

with pytest.raises(RuntimeError):
client.terminate_model(model_uid=model_uid)

Expand Down

0 comments on commit acc1abb

Please sign in to comment.