Skip to content

Commit

Permalink
Adding Rope scaling. (#741)
Browse files Browse the repository at this point in the history
# What does this PR do?


- Adds Rope NTK scaling.

Done because
#529 was
closed
Took some code from
huggingface/transformers#24653

- `--rope-scaling` and `--rope-factor` are added separately. I
considered having a single one and parsing something line ("linear:4.0"
, or "dynamic") but decided against
it because it would push more parsing+validation a bit everywhere (both
in the launcher and the server).


Fixes #512




<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
  • Loading branch information
Narsil authored Jul 31, 2023
1 parent b9633c4 commit 932bdd9
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 14 deletions.
61 changes: 61 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ impl std::fmt::Display for Dtype {
}
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum RopeScaling {
Linear,
Dynamic,
}

impl std::fmt::Display for RopeScaling {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// To keep in track with `server`.
match self {
RopeScaling::Linear => {
write!(f, "linear")
}
RopeScaling::Dynamic => {
write!(f, "dynamic")
}
}
}
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
Expand Down Expand Up @@ -250,6 +270,26 @@ struct Args {
#[clap(default_value = "1.0", long, env)]
cuda_memory_fraction: f32,

/// Rope scaling will only be used for RoPE models
/// and allow rescaling the position rotary to accomodate for
/// larger prompts.
///
/// Goes together with `rope_factor`.
///
/// `--rope-factor 2.0` gives linear scaling with a factor of 2.0
/// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0
/// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed
/// basically)
///
/// `--rope-scaling linear --rope-factor` fully describes the scaling you want
#[clap(long, env)]
rope_scaling: Option<RopeScaling>,

/// Rope scaling will only be used for RoPE models
/// See `rope_scaling`
#[clap(long, env)]
rope_factor: Option<f32>,

/// Outputs the logs in JSON format (useful for telemetry)
#[clap(long, env)]
json_output: bool,
Expand Down Expand Up @@ -305,6 +345,8 @@ fn shard_manager(
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
cuda_memory_fraction: f32,
rope_scaling: Option<RopeScaling>,
rope_factor: Option<f32>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
Expand Down Expand Up @@ -358,6 +400,12 @@ fn shard_manager(
shard_args.push(revision)
}

let rope = match (rope_scaling, rope_factor) {
(None, None) => None,
(Some(scaling), None) => Some((scaling, 1.0)),
(Some(scaling), Some(factor)) => Some((scaling, factor)),
(None, Some(factor)) => Some((RopeScaling::Linear, factor)),
};
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_args.push("--otlp-endpoint".to_string());
Expand Down Expand Up @@ -395,6 +443,15 @@ fn shard_manager(
envs.push(("HUGGING_FACE_HUB_TOKEN".into(), api_token.into()))
};

// Detect rope scaling
// Sending as env instead of CLI args to not bloat everything
// those only can be used by RoPE models, so passing information around
// for all models will complexify code unnecessarily
if let Some((scaling, factor)) = rope {
envs.push(("ROPE_SCALING".into(), scaling.to_string().into()));
envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
}

// If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache {
Expand Down Expand Up @@ -784,6 +841,8 @@ fn spawn_shards(
let watermark_gamma = args.watermark_gamma;
let watermark_delta = args.watermark_delta;
let cuda_memory_fraction = args.cuda_memory_fraction;
let rope_scaling = args.rope_scaling;
let rope_factor = args.rope_factor;
thread::spawn(move || {
shard_manager(
model_id,
Expand All @@ -802,6 +861,8 @@ fn spawn_shards(
watermark_gamma,
watermark_delta,
cuda_memory_fraction,
rope_scaling,
rope_factor,
otlp_endpoint,
status_sender,
shutdown,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size**-0.5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, config, prefix, weights):
self.num_heads = self.num_heads // weights.process_group.size()

self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
config=config, prefix=f"{prefix}.rotary_emb", weights=weights
)

self.softmax_scale = self.head_size ** (-0.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
self.head_size = self.hidden_size // self.num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)

Expand Down Expand Up @@ -247,7 +247,7 @@ def __init__(
self.head_size = hidden_size // num_heads

self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
config=config, dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)

Expand Down
86 changes: 76 additions & 10 deletions server/text_generation_server/utils/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,33 +381,65 @@ def forward(self, hidden_states, residual=None):
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb

def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq

def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
return rope_scaling
return getattr(config, "rope_scaling", None)

class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
def __init__(self, inv_freq, scaling_factor):
super().__init__()

self.inv_freq = inv_freq
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
self.scaling_factor = scaling_factor
self.dynamic_args = None

@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)

@classmethod
def load(cls, prefix, weights):
def load(cls, config, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)

scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
return cls(inv_freq, scaling_factor)

def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
Expand All @@ -419,8 +451,11 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)

freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
Expand All @@ -446,5 +481,36 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x

class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = create_inv_freq(dim, base, device)
super().__init__(inv_freq, scaling_factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base

def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
if self.scaling_factor is not None:
t /= self.scaling_factor
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)

freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)


except ImportError:
pass

0 comments on commit 932bdd9

Please sign in to comment.