From 040c5b5970f58d0bf5d524504ca716480ad3d75a Mon Sep 17 00:00:00 2001
From: System administrator
Date: Tue, 8 Oct 2024 19:10:38 +0000
Subject: [PATCH 1/3] feat: add test for positional rotary embeddings

---
 server/tests/utils/test_rotary_emb.py | 275 ++++++++++++++++++++++++++
 1 file changed, 275 insertions(+)
 create mode 100644 server/tests/utils/test_rotary_emb.py

diff --git a/server/tests/utils/test_rotary_emb.py b/server/tests/utils/test_rotary_emb.py
new file mode 100644
index 00000000000..7a3877cc468
--- /dev/null
+++ b/server/tests/utils/test_rotary_emb.py
@@ -0,0 +1,275 @@
+import pytest
+import torch
+from unittest.mock import Mock, patch
+from text_generation_server.layers.rotary import (
+    PositionRotaryEmbedding,
+    DynamicPositionRotaryEmbedding,
+    YarnPositionRotaryEmbedding,
+)
+
+
+def test_position_rotary_embedding_static_basic():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048, 
+        rope_scaling=None
+    )
+    weights = Mock(device=torch.device("cpu"))
+
+    result = PositionRotaryEmbedding.static(
+        config=config,
+        dim=64,
+        base=config.rope_theta,
+        device=weights.device,
+    )
+
+    assert isinstance(result, PositionRotaryEmbedding)
+    assert result.inv_freq.shape == (32,)  # dim // 2
+    assert result.scaling_factor is None
+
+
+def test_position_rotary_embedding_static_linear_scaling():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048
+    )
+    # scaling is not applied if type is linear (TODO: maybe revisit this)
+    config.rope_scaling = {"type": "linear", "factor": 2.0}
+    weights = Mock(device=torch.device("cpu"))
+
+    result = PositionRotaryEmbedding.static(
+        config=config,
+        dim=64,
+        base=config.rope_theta,
+        device=weights.device,
+    )
+
+    assert isinstance(result, PositionRotaryEmbedding)
+    assert result.scaling_factor is None
+
+
+def test_position_rotary_embedding_static_dynamic_scaling():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048, 
+        rope_scaling = {"type": "dynamic", "factor": 2.0}
+    )
+    weights = Mock(device=torch.device("cpu"))
+
+    result = PositionRotaryEmbedding.static(
+        config=config,
+        dim=64,
+        base=config.rope_theta,
+        device=weights.device,
+    )
+
+    assert isinstance(result, DynamicPositionRotaryEmbedding)
+    assert result.scaling_factor == 2.0
+    assert result.max_position_embeddings == 2048
+
+
+def test_position_rotary_embedding_static_yarn_scaling():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048,
+        rope_scaling = {
+            "type": "yarn",
+            "factor": 1.5,
+            "original_max_position_embeddings": 2048,
+        }
+    )
+    weights = Mock(device=torch.device("cpu"))
+
+    result = PositionRotaryEmbedding.static(
+        config=config,
+        dim=64,
+        base=config.rope_theta,
+        device=weights.device,
+    )
+
+    assert isinstance(result, YarnPositionRotaryEmbedding)
+    assert result.scaling_factor == 1.5
+    assert result.max_position_embeddings == 2048
+
+
+def test_position_rotary_embedding_static_invalid_scaling():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048, 
+        rope_scaling = {"type": "invalid", "factor": 2.0}
+    )
+    weights = Mock(device=torch.device("cpu"))
+
+    with pytest.raises(NotImplementedError):
+        PositionRotaryEmbedding.static(
+            config=config,
+            dim=64,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
+
+def test_position_rotary_embedding_static_llama3_scaling():
+    config = Mock(
+        rope_theta=10000, 
+        max_position_embeddings=2048,
+        rope_scaling = {
+            "rope_type": "llama3",
+            "factor": 2.0,
+            "low_freq_factor": 4,
+            "high_freq_factor": 32,
+            "original_max_position_embeddings": 2048,
+    })
+    weights = Mock(device=torch.device("cpu"))
+
+    result = PositionRotaryEmbedding.static(
+        config=config,
+        dim=64,
+        base=config.rope_theta,
+        device=weights.device,
+    )
+
+    assert isinstance(result, PositionRotaryEmbedding)
+    assert result.scaling_factor is None
+
+
+def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
+    config = Mock(
+        rope_theta=10000,
+        max_position_embeddings=4096,
+        rope_scaling=None,
+    )
+    weights = Mock(device=torch.device("cpu"))
+
+    with patch(
+        "text_generation_server.layers.rotary._get_rope_config"
+    ) as mock_get_rope_config:
+        mock_get_rope_config.return_value = {"type": "dynamic", "factor": 2.0}
+
+        result = PositionRotaryEmbedding.static(
+            config=config,
+            dim=64,
+            base=config.rope_theta,
+            device=weights.device,
+        )
+
+        assert isinstance(result, DynamicPositionRotaryEmbedding)
+        assert result.scaling_factor == 2.0
+        assert result.max_position_embeddings == 4096
+
+# Test the application of the rotary embedding
+
+def test_position_rotary_embedding_no_rope_config():
+    head_dim = 64
+    base = 10000
+    max_position_embeddings = 2048
+    num_heads = 16
+    batch_size = 2
+    seq_len = 128
+
+    device = "cuda"
+    dtype = torch.float16
+
+    config = Mock(
+        rope_theta=base, 
+        max_position_embeddings=max_position_embeddings,
+        rope_scaling=None
+    )
+
+    # create PositionRotaryEmbedding instance
+    rotary_emb = PositionRotaryEmbedding.static(
+        config=config, dim=head_dim, base=base, device=device
+    )
+
+    # generate position IDs
+    position_ids = torch.arange(seq_len).unsqueeze(0)
+    position_ids = position_ids.to(device).to(torch.int32).view(-1)
+
+    # get cos and sin values for the position IDs
+    cos, sin = rotary_emb.get_cos_sin(
+        position_ids=position_ids,
+        max_s=seq_len,
+        dtype=dtype,
+    )
+
+    # create query and key tensors
+    query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
+    key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
+
+    # clone to compare later
+    original_query = query.clone()
+    original_key = key.clone()
+
+    # apply rotary embedding
+    rotary_emb(query, key, cos, sin)
+
+    # copy rotated query and key and original query and key
+    q_rotated = query
+    k_rotated = key
+    query = original_query
+    key = original_key
+
+    assert (
+        q_rotated.shape == query.shape
+    ), "query shape should not change after rotation"
+    assert k_rotated.shape == key.shape, "key shape should not change after rotation"
+    assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
+    assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
+
+
+def test_position_rotary_embedding_with_dynamic_scaling():
+    head_dim = 64
+    base = 10000
+    max_position_embeddings = 2048
+    num_heads = 16
+    batch_size = 2
+    seq_len = 128
+
+    device = "cuda"
+    dtype = torch.float16
+
+    config = Mock(
+        rope_theta=base, 
+        max_position_embeddings=max_position_embeddings, 
+        rope_scaling={"type": "dynamic", "factor": 1.0}
+    )
+
+    # create PositionRotaryEmbedding instance
+    rotary_emb = PositionRotaryEmbedding.static(
+        config=config, dim=head_dim, base=base, device=device
+    )
+
+    # generate position IDs
+    position_ids = torch.arange(seq_len).unsqueeze(0)
+    position_ids = position_ids.to(device).to(torch.int32).view(-1)
+
+    # get cos and sin values for the position IDs
+    cos, sin = rotary_emb.get_cos_sin(
+        position_ids=position_ids,
+        max_s=seq_len,
+        dtype=dtype,
+    )
+
+    # create query and key tensors
+    query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
+    key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
+
+    # clone to compare later
+    original_query = query.clone()
+    original_key = key.clone()
+
+    # apply rotary embedding
+    rotary_emb(query, key, cos, sin)
+
+    # copy rotated query and key and original query and key
+    q_rotated = query
+    k_rotated = key
+    query = original_query
+    key = original_key
+
+    assert (
+        q_rotated.shape == query.shape
+    ), "query shape should not change after rotation"
+    assert k_rotated.shape == key.shape, "key shape should not change after rotation"
+    assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
+    assert not torch.allclose(k_rotated, key), "key should be modified by rotation"

From 301a18c2e5364bf9da1f7a38b4d5f1d75207a154 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 9 Oct 2024 18:35:29 +0000
Subject: [PATCH 2/3] fix: limit some tests to only run when cuda available

---
 server/tests/utils/test_rotary_emb.py | 36 ++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/server/tests/utils/test_rotary_emb.py b/server/tests/utils/test_rotary_emb.py
index 7a3877cc468..1ade3c8bc51 100644
--- a/server/tests/utils/test_rotary_emb.py
+++ b/server/tests/utils/test_rotary_emb.py
@@ -6,12 +6,13 @@
     DynamicPositionRotaryEmbedding,
     YarnPositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def test_position_rotary_embedding_static_basic():
     config = Mock(
-        rope_theta=10000, 
-        max_position_embeddings=2048, 
+        rope_theta=10000,
+        max_position_embeddings=2048,
         rope_scaling=None
     )
     weights = Mock(device=torch.device("cpu"))
@@ -30,7 +31,7 @@ def test_position_rotary_embedding_static_basic():
 
 def test_position_rotary_embedding_static_linear_scaling():
     config = Mock(
-        rope_theta=10000, 
+        rope_theta=10000,
         max_position_embeddings=2048
     )
     # scaling is not applied if type is linear (TODO: maybe revisit this)
@@ -50,8 +51,8 @@ def test_position_rotary_embedding_static_linear_scaling():
 
 def test_position_rotary_embedding_static_dynamic_scaling():
     config = Mock(
-        rope_theta=10000, 
-        max_position_embeddings=2048, 
+        rope_theta=10000,
+        max_position_embeddings=2048,
         rope_scaling = {"type": "dynamic", "factor": 2.0}
     )
     weights = Mock(device=torch.device("cpu"))
@@ -70,7 +71,7 @@ def test_position_rotary_embedding_static_dynamic_scaling():
 
 def test_position_rotary_embedding_static_yarn_scaling():
     config = Mock(
-        rope_theta=10000, 
+        rope_theta=10000,
         max_position_embeddings=2048,
         rope_scaling = {
             "type": "yarn",
@@ -94,8 +95,8 @@ def test_position_rotary_embedding_static_yarn_scaling():
 
 def test_position_rotary_embedding_static_invalid_scaling():
     config = Mock(
-        rope_theta=10000, 
-        max_position_embeddings=2048, 
+        rope_theta=10000,
+        max_position_embeddings=2048,
         rope_scaling = {"type": "invalid", "factor": 2.0}
     )
     weights = Mock(device=torch.device("cpu"))
@@ -111,7 +112,7 @@ def test_position_rotary_embedding_static_invalid_scaling():
 
 def test_position_rotary_embedding_static_llama3_scaling():
     config = Mock(
-        rope_theta=10000, 
+        rope_theta=10000,
         max_position_embeddings=2048,
         rope_scaling = {
             "rope_type": "llama3",
@@ -159,7 +160,7 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
 
 # Test the application of the rotary embedding
 
-def test_position_rotary_embedding_no_rope_config():
+def position_rotary_embedding_no_rope_config():
     head_dim = 64
     base = 10000
     max_position_embeddings = 2048
@@ -171,7 +172,7 @@ def test_position_rotary_embedding_no_rope_config():
     dtype = torch.float16
 
     config = Mock(
-        rope_theta=base, 
+        rope_theta=base,
         max_position_embeddings=max_position_embeddings,
         rope_scaling=None
     )
@@ -217,7 +218,7 @@ def test_position_rotary_embedding_no_rope_config():
     assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
 
 
-def test_position_rotary_embedding_with_dynamic_scaling():
+def position_rotary_embedding_with_dynamic_scaling():
     head_dim = 64
     base = 10000
     max_position_embeddings = 2048
@@ -229,8 +230,8 @@ def test_position_rotary_embedding_with_dynamic_scaling():
     dtype = torch.float16
 
     config = Mock(
-        rope_theta=base, 
-        max_position_embeddings=max_position_embeddings, 
+        rope_theta=base,
+        max_position_embeddings=max_position_embeddings,
         rope_scaling={"type": "dynamic", "factor": 1.0}
     )
 
@@ -273,3 +274,10 @@ def test_position_rotary_embedding_with_dynamic_scaling():
     assert k_rotated.shape == key.shape, "key shape should not change after rotation"
     assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
    assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
+
+if SYSTEM == "cuda":
+    def test_position_rotary_embedding_with_dynamic_scaling():
+        position_rotary_embedding_with_dynamic_scaling()
+
+    def test_position_rotary_embedding_no_rope_config():
+        position_rotary_embedding_no_rope_config()

From 130f9d16b59f7c2020558e5b035184cf15bf6da7 Mon Sep 17 00:00:00 2001
From: David Holtz
Date: Wed, 9 Oct 2024 18:44:41 +0000
Subject: [PATCH 3/3] fix: rerun black lint

---
 server/tests/utils/test_rotary_emb.py | 32 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/server/tests/utils/test_rotary_emb.py b/server/tests/utils/test_rotary_emb.py
index 1ade3c8bc51..399e869a236 100644
--- a/server/tests/utils/test_rotary_emb.py
+++ b/server/tests/utils/test_rotary_emb.py
@@ -10,11 +10,7 @@
 
 
 def test_position_rotary_embedding_static_basic():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048,
-        rope_scaling=None
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)
     weights = Mock(device=torch.device("cpu"))
 
     result = PositionRotaryEmbedding.static(
@@ -30,10 +26,7 @@ def test_position_rotary_embedding_static_basic():
 
 
 def test_position_rotary_embedding_static_linear_scaling():
-    config = Mock(
-        rope_theta=10000,
-        max_position_embeddings=2048
-    )
+    config = Mock(rope_theta=10000, max_position_embeddings=2048)
     # scaling is not applied if type is linear (TODO: maybe revisit this)
     config.rope_scaling = {"type": "linear", "factor": 2.0}
     weights = Mock(device=torch.device("cpu"))
@@ -53,7 +46,7 @@ def test_position_rotary_embedding_static_dynamic_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "dynamic", "factor": 2.0}
+        rope_scaling={"type": "dynamic", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))
 
@@ -73,11 +66,11 @@ def test_position_rotary_embedding_static_yarn_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "type": "yarn",
             "factor": 1.5,
             "original_max_position_embeddings": 2048,
-        }
+        },
     )
     weights = Mock(device=torch.device("cpu"))
 
@@ -97,7 +90,7 @@ def test_position_rotary_embedding_static_invalid_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {"type": "invalid", "factor": 2.0}
+        rope_scaling={"type": "invalid", "factor": 2.0},
     )
     weights = Mock(device=torch.device("cpu"))
 
@@ -114,13 +107,14 @@ def test_position_rotary_embedding_static_llama3_scaling():
     config = Mock(
         rope_theta=10000,
         max_position_embeddings=2048,
-        rope_scaling = {
+        rope_scaling={
             "rope_type": "llama3",
             "factor": 2.0,
             "low_freq_factor": 4,
             "high_freq_factor": 32,
             "original_max_position_embeddings": 2048,
-    })
+        },
+    )
     weights = Mock(device=torch.device("cpu"))
 
     result = PositionRotaryEmbedding.static(
@@ -158,8 +152,10 @@ def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
     assert result.scaling_factor == 2.0
     assert result.max_position_embeddings == 4096
 
+
 # Test the application of the rotary embedding
 
+
 def position_rotary_embedding_no_rope_config():
     head_dim = 64
     base = 10000
@@ -174,7 +170,7 @@ def position_rotary_embedding_no_rope_config():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling=None
+        rope_scaling=None,
     )
 
     # create PositionRotaryEmbedding instance
@@ -232,7 +228,7 @@ def position_rotary_embedding_with_dynamic_scaling():
     config = Mock(
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
-        rope_scaling={"type": "dynamic", "factor": 1.0}
+        rope_scaling={"type": "dynamic", "factor": 1.0},
     )
 
     # create PositionRotaryEmbedding instance
@@ -275,7 +271,9 @@ def position_rotary_embedding_with_dynamic_scaling():
     assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
     assert not torch.allclose(k_rotated, key), "key should be modified by rotation"
+
+
 if SYSTEM == "cuda":
+
     def test_position_rotary_embedding_with_dynamic_scaling():
         position_rotary_embedding_with_dynamic_scaling()
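
The tests above only assert that the rotation changed the query and key values. As a point of reference, the following self-contained sketch shows the property those assertions rely on: a rotary embedding rotates each pair of head dimensions by a position-dependent angle, so values change while per-position, per-head norms are preserved. It is illustrative only and is not taken from text_generation_server: the function name, the shapes, and the non-interleaved "rotate half" convention are assumptions, and the production kernels may use a different layout.

# Self-contained sketch of rotary position embeddings (rotate-half convention).
# Not from text_generation_server: names, shapes, and the convention itself are
# assumptions for illustration only.
import torch


def rope_reference(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate x of shape (seq_len, num_heads, head_dim) by absolute position."""
    seq_len, _, head_dim = x.shape
    # One frequency per dimension pair: theta_i = base ** (-2i / head_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[:, None, :]  # broadcast over heads
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[:, None, :]
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # "rotate half" of the head dimension
    return x * cos + rotated * sin


if __name__ == "__main__":
    q = torch.randn(128, 16, 64)
    q_rot = rope_reference(q)
    # Values change with position, but per-position, per-head norms are preserved,
    # which is the behaviour the "should be modified by rotation" asserts depend on.
    assert not torch.allclose(q, q_rot)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)

A follow-up to these tests could compare the kernel output against a reference of this kind instead of asserting only that the tensors changed.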