Implement NTK-Aware scaled and dynamically scaled RoPE for PositionRotaryEmbedding #529

Closed · wants to merge 3 commits
34 changes: 32 additions & 2 deletions launcher/src/main.rs
@@ -162,7 +162,7 @@ struct Args {
/// Limits the number of tokens for the prefill operation.
/// Since this operation takes the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)]
#[clap(default_value = "2048", long, env)]
max_batch_prefill_tokens: u32,

/// **IMPORTANT** This is one critical control to allow maximum usage
@@ -182,7 +182,7 @@ struct Args {
/// depends on other parameters like if you're using quantization, flash attention
/// or the model implementation, text-generation-inference cannot infer this number
/// automatically.
#[clap(default_value = "16000", long, env)]
#[clap(default_value = "8192", long, env)]
max_batch_total_tokens: u32,

/// This setting defines how many tokens can be passed before forcing the waiting
@@ -280,6 +280,19 @@ struct Args {
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,

/// NTK-Aware Scaled RoPE is a method proposed in https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
/// The scale factor, or "α", is used in combination with a non-linearity to scale the base used to calculate the parameter "θ", the angle of rotation in RoPE.
/// This increases how many input tokens can be represented within the same portion of a positional embedding, with the non-linearity used to increase token separability.
#[clap(default_value = "1", long, env)]
rope_scale_factor: usize,

/// Dynamic scaling of the "α" factor in NTK-Aware Scaled Rope was introduced in https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
/// The idea is that instead of setting alpha statically, it is calculated as a function of the current sequence length and the model's base sequence length.
/// This is a means to both increase performance on shorter sequence lengths and smooth the perplexity explosion experienced by both linearly scaled and NTK-Aware scaled RoPE.
/// If this is enabled, the above "rope_scale_factor" will be ignored.
#[clap(default_value = "false", long, env)]
rope_dynamic_scaling: bool,
}
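For reference, a minimal sketch of the math behind these two flags, mirroring the formulas used by `_get_inv_freq` and `_update_cos_sin_cache` later in this PR (the standalone function names are illustrative, not part of the change):

```python
def ntk_scaled_base(base: float, dim: int, alpha: float) -> float:
    # NTK-aware scaling: grow the RoPE base so low-frequency components spread
    # over a longer context while high-frequency components change little.
    return base * alpha ** (dim / (dim - 2))


def dynamic_alpha(alpha: float, seq_len: int, original_max_seq_len: int) -> float:
    # Dynamic scaling: recompute alpha from the current sequence length;
    # it collapses to 1.0 when seq_len == original_max_seq_len.
    return (alpha * seq_len / original_max_seq_len) - (alpha - 1)
```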

#[derive(Debug)]
@@ -293,6 +306,8 @@ fn shard_manager(
model_id: String,
revision: Option<String>,
quantize: Option<Quantization>,
rope_scale_factor: usize,
rope_dynamic_scaling: bool,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
@@ -422,6 +437,16 @@ fn shard_manager(
envs.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}

// RoPE Scaling
envs.push((
"ROPE_SCALE_FACTOR".into(),
rope_scale_factor.to_string().into(),
));
envs.push((
"ROPE_DYNAMIC_SCALING".into(),
rope_dynamic_scaling.to_string().into(),
));
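Assuming clap's default kebab-case naming for derive-based arguments, these fields surface on the launcher as `--rope-scale-factor` and `--rope-dynamic-scaling`; their values are then forwarded to every shard through the `ROPE_SCALE_FACTOR` and `ROPE_DYNAMIC_SCALING` environment variables consumed by the Python server code below.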

// Start process
tracing::info!("Starting shard {rank}");
let mut p = match Command::new("text-generation-server")
@@ -776,11 +801,16 @@ fn spawn_shards(
let disable_custom_kernels = args.disable_custom_kernels;
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let rope_scale_factor = args.rope_scale_factor;
let rope_dynamic_scaling = args.rope_dynamic_scaling;

thread::spawn(move || {
shard_manager(
model_id,
revision,
quantize,
rope_scale_factor,
rope_dynamic_scaling,
dtype,
trust_remote_code,
uds_path,
4 changes: 2 additions & 2 deletions router/src/main.rs
@@ -35,9 +35,9 @@ struct Args {
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
#[clap(default_value = "2048", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "16000", long, env)]
#[clap(default_value = "8192", long, env)]
Comment on lines +38 to +40

Collaborator: This does not belong in this PR. We can discuss changing the defaults, but it's a separate concern.

Author: Oh yup fair, I don't want to change them; I meant to clean this out. I'll remove!

max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.distributed

@@ -41,6 +42,12 @@
TensorParallelHead,
)

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))

if os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true":
    ROPE_DYNAMIC_SCALING = True
else:
    ROPE_DYNAMIC_SCALING = False
Collaborator: Nothing should be model specific.
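One way to address this, assuming a shared utility module is acceptable (the module path and helper name below are hypothetical, not part of this PR), would be to parse the environment once and import the result in each model file:

```python
# hypothetical server/text_generation_server/utils/rope_env.py
import os


def rope_scaling_from_env() -> tuple:
    """Read the RoPE scaling settings the launcher exports for every shard."""
    scale_factor = int(os.getenv("ROPE_SCALE_FACTOR", "1"))
    dynamic_scaling = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
    return scale_factor, dynamic_scaling
```

Each modeling file could then call `rope_scaling_from_env()` instead of duplicating the `os.getenv` parsing.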


class LlamaRMSNorm(nn.Module):
def __init__(self, prefix, weights, eps=1e-6):
@@ -105,10 +112,18 @@ def __init__(
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
if self.scale_factor > 1 or self.dynamic_scaling:
# Base before scaling is 10000 per the original RoPE paper
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
)
else:
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size**-0.5

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.distributed

@@ -44,6 +45,9 @@
get_linear,
)

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"


def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -102,10 +106,18 @@ def __init__(self, config, prefix, weights):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
if self.scale_factor > 1 or self.dynamic_scaling:
# Base before scaling is 10000 per the original RoPE paper
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
)
else:
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size ** (-0.5)

@@ -1,5 +1,7 @@
import os
import torch
import torch.distributed
import warnings

from torch import nn
from transformers.modeling_utils import PreTrainedModel
@@ -23,6 +25,8 @@
get_linear,
)

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -113,10 +117,13 @@ def __init__(
self.num_heads_kv = config.n_head_kv
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
)

self.softmax_scale = self.head_size ** (-0.5)

if self.num_heads % weights.process_group.size() != 0:
@@ -239,9 +246,11 @@ def __init__(

self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.scale_factor = ROPE_SCALE_FACTOR
self.dynamic_scaling = ROPE_DYNAMIC_SCALING

self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
dim=self.head_size, base=10000.0, device=weights.device, scale_factor=self.scale_factor, dynamic_scaling=self.dynamic_scaling
)
self.softmax_scale = self.head_size ** (-0.5)

@@ -60,6 +60,8 @@
if not CUSTOM_KERNELS_ENABLED:
logger.warning("We're not using custom kernels.")

ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
49 changes: 39 additions & 10 deletions server/text_generation_server/utils/layers.py
@@ -369,7 +369,7 @@ def forward(self, hidden_states, residual=None):
import rotary_emb

class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
def __init__(self, inv_freq, scale_factor=1, dynamic_scaling=False, max_seq_len=2048, dim=None, base=None):
Collaborator: Can we have at most 1 extra argument? A lot of information should be extractable directly from inv_freq.

Author: Yup, I can try to simplify this.

super().__init__()

self.inv_freq = inv_freq
@@ -379,32 +379,61 @@ def __init__(self, inv_freq):
self._cos_k_cached = None
self._sin_k_cached = None

@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
self.scale_factor = scale_factor
self.dynamic_scaling = dynamic_scaling
self.original_max_seq_len = max_seq_len
self.max_seq_len = max_seq_len * scale_factor
self.dim = dim
self.base = base

@classmethod
def static(cls, dim, base, device, scale_factor=1, dynamic_scaling=False, max_seq_len=2048):
inv_freq = cls._get_inv_freq(dim, base, device, scale_factor)
return cls(inv_freq, scale_factor, dynamic_scaling, max_seq_len, dim, base)

@classmethod
def load(cls, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32

inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)

@staticmethod
def _get_inv_freq(dim, base, device, scale_factor=1):
base = base * scale_factor ** (dim / (dim-2))

inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)

return inv_freq
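As a quick sanity check on the effect of this rescaling (approximate values, for illustration only):

```python
# Effective RoPE base for a 128-dimensional head under NTK-aware scaling,
# using the same formula as _get_inv_freq above: base * alpha ** (dim / (dim - 2)).
for alpha in (1, 2, 4, 8):
    print(alpha, round(10000.0 * alpha ** (128 / 126)))
# roughly: 1 -> 10000, 2 -> ~20200, 4 -> ~40900, 8 -> ~82700
```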

def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)

length = seqlen
max_seq_len = self.max_seq_len
inv_freq = self.inv_freq

if self.dynamic_scaling:
scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
max_seq_len = self.original_max_seq_len * scale_factor
self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
Collaborator: This is not really OK, I think. You're ditching the original self.inv_freq entirely, which unfortunately for us is sometimes different from the calculation proposed (that's why not all models use static and some use load). Llama most notably has a different saved inv_freq (not sure why, but it's indeed the case).

Author (iantbutler01, Jul 18, 2023): Part of dynamic scaling is calculating the new inv_freq; looking at the dynamic scaling implementation in Transformers, I don't see them preserving this value either.

Author: What would you suggest alternatively?

Collaborator: I was thinking interpolation when I wrote this. Now that I reflect more, it would make the code even more complex, which is not the desired effect. Can we maybe move the scaling factor out of get_inv_freq and keep it directly here (since it just seems to be rescaling the base)?

Collaborator: And so let's keep rewriting inv_freq. It has some undesirable effects on those models, but the other way is even worse.

Author: That sounds reasonable, I'll make this change after work.

if self.scale_factor > 1 and not self.dynamic_scaling:
length = max(seqlen, max_seq_len)

if (
seqlen > self._seq_len_cached
length > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
self._seq_len_cached = length
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
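To summarize the cache-update behavior above for readers skimming the diff, here is a condensed, self-contained sketch (simplified and illustrative only: it mirrors the PR's fields but is not the actual class, and it builds cos/sin with plain torch ops rather than the rotary_emb kernel path):

```python
import torch


class RotaryCacheSketch:
    """Simplified view of the scaled _update_cos_sin_cache logic in this PR."""

    def __init__(self, dim, base=10000.0, scale_factor=1, dynamic_scaling=False, max_seq_len=2048):
        self.dim = dim
        self.base = base
        self.scale_factor = scale_factor
        self.dynamic_scaling = dynamic_scaling
        self.original_max_seq_len = max_seq_len
        self.max_seq_len = max_seq_len * scale_factor
        self._seq_len_cached = 0
        self.inv_freq = self._inv_freq(scale_factor)

    def _inv_freq(self, alpha):
        # Same base rescaling as _get_inv_freq.
        base = self.base * alpha ** (self.dim / (self.dim - 2))
        return 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))

    def update(self, seqlen):
        length = seqlen
        if self.dynamic_scaling:
            # Recompute alpha (and therefore inv_freq) from the current length.
            alpha = (self.scale_factor * seqlen / self.original_max_seq_len) - (self.scale_factor - 1)
            self.inv_freq = self._inv_freq(alpha)
        elif self.scale_factor > 1:
            # Static scaling: size the cache for the full scaled context window.
            length = max(seqlen, self.max_seq_len)
        if length > self._seq_len_cached:
            self._seq_len_cached = length
            t = torch.arange(length, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)
            self._cos_cached, self._sin_cached = freqs.cos(), freqs.sin()


# Example: with scale_factor=4 and dynamic scaling, alpha stays 1.0 at 2048 tokens
# and grows as the sequence exceeds the original context window.
rope = RotaryCacheSketch(dim=128, scale_factor=4, dynamic_scaling=True)
rope.update(2048)
rope.update(6144)
```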