Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengjieLi28 committed Oct 31, 2024
1 parent 42d6b39 commit 14b5ca2
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions xinference/core/tests/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.


import os
import sys
import threading
import time
Expand Down Expand Up @@ -112,18 +111,11 @@ def run_internal(self):
assert result["msg"] == self._expected_res


@pytest.fixture
def enable_batch():
    """Force-enable transformers continuous batching for the duration of a test.

    Sets the XINFERENCE_TRANSFORMERS_ENABLE_BATCHING environment variable to
    "1" before the test runs and resets it to "0" afterwards.
    """
    env_key = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
    os.environ[env_key] = "1"
    yield
    os.environ[env_key] = "0"


@pytest.mark.skipif(
sys.platform == "win32",
reason="does not run on windows github CI due to its terrible runtime",
)
def test_continuous_batching(enable_batch, setup):
def test_continuous_batching(setup):
endpoint, _ = setup
url = f"{endpoint}/v1/models"
client = RESTfulClient(endpoint)
Expand All @@ -132,7 +124,7 @@ def test_continuous_batching(enable_batch, setup):
payload = {
"model_engine": "transformers",
"model_type": "LLM",
"model_name": "qwen1.5-chat",
"model_name": "qwen2.5-instruct",
"quantization": "none",
"model_format": "pytorch",
"model_size_in_billions": "0_5",
Expand All @@ -146,7 +138,7 @@ def test_continuous_batching(enable_batch, setup):
response = requests.post(url, json=payload)
response_data = response.json()
model_uid_res = response_data["model_uid"]
assert model_uid_res == "qwen1.5-chat"
assert model_uid_res == "qwen2.5-instruct"

model = client.get_model(model_uid_res)

Expand Down

0 comments on commit 14b5ca2

Please sign in to comment.