feat: add test for positional rotary embeddings #2623

Open · wants to merge 3 commits into base: main
281 changes: 281 additions & 0 deletions server/tests/utils/test_rotary_emb.py
@@ -0,0 +1,281 @@
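"""Tests for PositionRotaryEmbedding.static and its scaling variants.

Covers the rope_scaling configurations (none, linear, dynamic, yarn, llama3)
and, on CUDA systems, the in-place application of the rotary embedding to
query/key tensors.
"""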
import pytest
import torch
from unittest.mock import Mock, patch
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
DynamicPositionRotaryEmbedding,
YarnPositionRotaryEmbedding,
)
from text_generation_server.utils.import_utils import SYSTEM


def test_position_rotary_embedding_static_basic():
config = Mock(rope_theta=10000, max_position_embeddings=2048, rope_scaling=None)
weights = Mock(device=torch.device("cpu"))

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, PositionRotaryEmbedding)
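    # inv_freq holds one inverse frequency per dimension pair; under the
    # standard RoPE parameterization this is 1 / base**(2i / dim) for
    # i in range(dim // 2)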
assert result.inv_freq.shape == (32,) # dim // 2
assert result.scaling_factor is None


def test_position_rotary_embedding_static_linear_scaling():
config = Mock(rope_theta=10000, max_position_embeddings=2048)
# scaling is not applied if type is linear (TODO: maybe revisit this)
config.rope_scaling = {"type": "linear", "factor": 2.0}
weights = Mock(device=torch.device("cpu"))

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, PositionRotaryEmbedding)
assert result.scaling_factor is None


def test_position_rotary_embedding_static_dynamic_scaling():
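    # "dynamic" (NTK-aware) scaling adjusts the rope base at runtime once the
    # sequence grows past max_position_embeddings, so .static should return
    # the DynamicPositionRotaryEmbedding subclass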
config = Mock(
rope_theta=10000,
max_position_embeddings=2048,
rope_scaling={"type": "dynamic", "factor": 2.0},
)
weights = Mock(device=torch.device("cpu"))

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, DynamicPositionRotaryEmbedding)
assert result.scaling_factor == 2.0
assert result.max_position_embeddings == 2048


def test_position_rotary_embedding_static_yarn_scaling():
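    # YaRN interpolates frequencies band-by-band to extend the context
    # window; .static should return the YarnPositionRotaryEmbedding subclass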
config = Mock(
rope_theta=10000,
max_position_embeddings=2048,
rope_scaling={
"type": "yarn",
"factor": 1.5,
"original_max_position_embeddings": 2048,
},
)
weights = Mock(device=torch.device("cpu"))

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, YarnPositionRotaryEmbedding)
assert result.scaling_factor == 1.5
assert result.max_position_embeddings == 2048


def test_position_rotary_embedding_static_invalid_scaling():
config = Mock(
rope_theta=10000,
max_position_embeddings=2048,
rope_scaling={"type": "invalid", "factor": 2.0},
)
weights = Mock(device=torch.device("cpu"))

with pytest.raises(NotImplementedError):
PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)


def test_position_rotary_embedding_static_llama3_scaling():
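    # llama3-style scaling rescales inv_freq directly rather than wrapping
    # the embedding, so the base PositionRotaryEmbedding class is returned
    # with no scaling_factor set (matching the asserts below)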
config = Mock(
rope_theta=10000,
max_position_embeddings=2048,
rope_scaling={
"rope_type": "llama3",
"factor": 2.0,
"low_freq_factor": 4,
"high_freq_factor": 32,
"original_max_position_embeddings": 2048,
},
)
weights = Mock(device=torch.device("cpu"))

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, PositionRotaryEmbedding)
assert result.scaling_factor is None


def test_position_rotary_embedding_max_tokens_exceed_max_position_embeddings():
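    # _get_rope_config is patched so the config behaves as if a dynamic
    # rope scaling section were present, exercising the dynamic path without
    # a real model config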
config = Mock(
rope_theta=10000,
max_position_embeddings=4096,
rope_scaling=None,
)
weights = Mock(device=torch.device("cpu"))

with patch(
"text_generation_server.layers.rotary._get_rope_config"
) as mock_get_rope_config:
mock_get_rope_config.return_value = {"type": "dynamic", "factor": 2.0}

result = PositionRotaryEmbedding.static(
config=config,
dim=64,
base=config.rope_theta,
device=weights.device,
)

assert isinstance(result, DynamicPositionRotaryEmbedding)
assert result.scaling_factor == 2.0
assert result.max_position_embeddings == 4096


# Test the application of the rotary embedding
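#
# For reference, a sketch of what one rotation step computes per dimension
# pair, assuming the interleaved-pair RoPE convention (the kernel's actual
# memory layout may differ by backend):
#
#   x1, x2 = x[..., ::2], x[..., 1::2]
#   out_even = x1 * cos - x2 * sin
#   out_odd  = x1 * sin + x2 * cos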


def position_rotary_embedding_no_rope_config():
head_dim = 64
base = 10000
max_position_embeddings = 2048
num_heads = 16
batch_size = 2
seq_len = 128

device = "cuda"
dtype = torch.float16

config = Mock(
rope_theta=base,
max_position_embeddings=max_position_embeddings,
rope_scaling=None,
)

# create PositionRotaryEmbedding instance
rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=head_dim, base=base, device=device
)

# generate position IDs
position_ids = torch.arange(seq_len).unsqueeze(0)
position_ids = position_ids.to(device).to(torch.int32).view(-1)

# get cos and sin values for the position IDs
cos, sin = rotary_emb.get_cos_sin(
position_ids=position_ids,
max_s=seq_len,
dtype=dtype,
)

# create query and key tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)

# clone to compare later
original_query = query.clone()
original_key = key.clone()

# apply rotary embedding
rotary_emb(query, key, cos, sin)

    # the rotation is applied in place, so `query`/`key` now hold the rotated
    # values; rebind names so the asserts compare rotated vs. original tensors
q_rotated = query
k_rotated = key
query = original_query
key = original_key

assert (
q_rotated.shape == query.shape
), "query shape should not change after rotation"
assert k_rotated.shape == key.shape, "key shape should not change after rotation"
assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"


def position_rotary_embedding_with_dynamic_scaling():
head_dim = 64
base = 10000
max_position_embeddings = 2048
num_heads = 16
batch_size = 2
seq_len = 128

device = "cuda"
dtype = torch.float16

config = Mock(
rope_theta=base,
max_position_embeddings=max_position_embeddings,
rope_scaling={"type": "dynamic", "factor": 1.0},
)

# create PositionRotaryEmbedding instance
rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=head_dim, base=base, device=device
)

# generate position IDs
position_ids = torch.arange(seq_len).unsqueeze(0)
position_ids = position_ids.to(device).to(torch.int32).view(-1)

# get cos and sin values for the position IDs
cos, sin = rotary_emb.get_cos_sin(
position_ids=position_ids,
max_s=seq_len,
dtype=dtype,
)

# create query and key tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)
key = torch.randn(batch_size, seq_len, num_heads, head_dim).to(device).to(dtype)

# clone to compare later
original_query = query.clone()
original_key = key.clone()

# apply rotary embedding
rotary_emb(query, key, cos, sin)

    # the rotation is applied in place, so `query`/`key` now hold the rotated
    # values; rebind names so the asserts compare rotated vs. original tensors
q_rotated = query
k_rotated = key
query = original_query
key = original_key

assert (
q_rotated.shape == query.shape
), "query shape should not change after rotation"
assert k_rotated.shape == key.shape, "key shape should not change after rotation"
assert not torch.allclose(q_rotated, query), "query should be modified by rotation"
assert not torch.allclose(k_rotated, key), "key should be modified by rotation"


if SYSTEM == "cuda":
    # the helpers above run on a CUDA device, so only register them as
    # pytest tests when a GPU backend is available

    def test_position_rotary_embedding_with_dynamic_scaling():
        position_rotary_embedding_with_dynamic_scaling()

    def test_position_rotary_embedding_no_rope_config():
        position_rotary_embedding_no_rope_config()