Skip to content

Commit 03e378b

Browse files
committed
fix ut
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent c0b5fa1 commit 03e378b

File tree

2 files changed

+8
-22
lines changed

2 files changed

+8
-22
lines changed

requirements-dev.txt

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,4 @@ types-psutil
1313
pytest-cov
1414
regex
1515
sentence_transformers
16-
# Use daliy release ray for pp, which includes https://github.com/ray-project/ray/commit/eae786c8674f6b1c29c3bda6fdce80d0ab4afcd8
17-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp39-cp39-manylinux2014_aarch64.whl ; python_version == "3.9" and platform_system == "Linux" and platform_machine == "aarch64"
18-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_aarch64.whl ; python_version == "3.10" and platform_system == "Linux" and platform_machine == "aarch64"
19-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp311-cp311-manylinux2014_aarch64.whl ; python_version == "3.11" and platform_system == "Linux" and platform_machine == "aarch64"
20-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp312-cp312-manylinux2014_aarch64.whl ; python_version == "3.12" and platform_system == "Linux" and platform_machine == "aarch64"
21-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl ; python_version == "3.9" and platform_system == "Linux" and platform_machine == "x86_64"
22-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl ; python_version == "3.10" and platform_system == "Linux" and platform_machine == "x86_64"
23-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp311-cp311-manylinux2014_x86_64.whl ; python_version == "3.11" and platform_system == "Linux" and platform_machine == "x86_64"
24-
ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp312-cp312-manylinux2014_x86_64.whl ; python_version == "3.12" and platform_system == "Linux" and platform_machine == "x86_64"
16+
ray>=2.47.1

tests/ut/attention/test_attention_v1.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -400,19 +400,13 @@ def test_forward_head_size_192(self, mock_vanilla_prefill,
400400
layer = self.layer_no_quant
401401
mock_vanilla_prefill.return_value = MagicMock()
402402

403-
def mock_tensor(data, device=None, **kwargs):
404-
if device == "npu":
405-
return metadata.attn_mask
406-
return torch.tensor(data, **kwargs)
407-
408-
with patch("torch.tensor", side_effect=mock_tensor):
409-
output = self.impl_192.forward(layer,
410-
query,
411-
key,
412-
value,
413-
kv_cache,
414-
metadata,
415-
trace_flag=False)
403+
output = self.impl_192.forward(layer,
404+
query,
405+
key,
406+
value,
407+
kv_cache,
408+
metadata,
409+
trace_flag=False)
416410

417411
mock_vanilla_prefill.assert_called_once()
418412
assert output.shape == (10, 8 * 192)

0 commit comments

Comments
 (0)