Update conditionals for dynamic scaling
iantbutler01 committed Jul 17, 2023
1 parent f01c11b commit 0ec4d81
Showing 5 changed files with 7 additions and 23 deletions.
@@ -115,7 +115,7 @@ def __init__(
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
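For context, a minimal sketch of what the loosened conditional changes (the helper below is illustrative and not part of the repository): with the old check, setting only ROPE_DYNAMIC_SCALING=true while leaving ROPE_SCALE_FACTOR at its default of 1 never reached the scaled PositionRotaryEmbedding.static path; with the new check, either option selects it.

# Illustrative only: mirrors the conditional above, outside the server code.
def uses_scaled_rope(scale_factor: int, dynamic_scaling: bool) -> bool:
    # Old check: scale_factor > 1. New check: scale_factor > 1 or dynamic_scaling.
    return scale_factor > 1 or dynamic_scaling

assert uses_scaled_rope(1, True)       # dynamic scaling alone now takes the scaled path
assert uses_scaled_rope(4, False)      # static scaling still does
assert not uses_scaled_rope(1, False)  # defaults keep the unscaled embedding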
@@ -45,13 +45,8 @@
     get_linear,
 )
 
-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 
 def load_row(config, prefix: str, weights, bias: bool):
@@ -114,7 +109,7 @@ def __init__(self, config, prefix, weights):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
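A note on the flag parsing above, with a standalone sketch (same environment variable, but the snippet is not the repository code): os.getenv returns a string when the variable is set and the given default otherwise, so defaulting to the string "false" keeps the .lower() call safe, whereas the old default of the bool False raised AttributeError whenever ROPE_DYNAMIC_SCALING was unset.

import os

# Make sure the variable is unset for the demonstration.
os.environ.pop("ROPE_DYNAMIC_SCALING", None)

# Old pattern would fail here, because bool has no .lower():
#     os.getenv("ROPE_DYNAMIC_SCALING", False).lower()

# New pattern: a string default keeps .lower() valid and yields a clean bool.
ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
assert ROPE_DYNAMIC_SCALING is False

os.environ["ROPE_DYNAMIC_SCALING"] = "True"
assert os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"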
@@ -26,11 +26,7 @@
 )
 
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
@@ -60,14 +60,8 @@
 if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")
 
-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
-
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"
 
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
server/text_generation_server/utils/layers.py (2 additions & 3 deletions)
@@ -423,10 +423,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
         if self.dynamic_scaling:
             scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
             max_seq_len = self.original_max_seq_len * scale_factor
-            inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-            self.register_buffer("inv_freq", inv_freq)
+            self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
 
-        if self.scale_factor > 1:
+        if self.scale_factor > 1 and not self.dynamic_scaling:
             length = max(seqlen, max_seq_len)
 
         if (
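A small worked sketch of the dynamic-scaling arithmetic in _update_cos_sin_cache (the numbers below are assumptions for illustration, not values from the repository): once the requested length exceeds original_max_seq_len, the effective scale factor grows linearly with the sequence length, max_seq_len grows with it, and the inverse frequencies are recomputed for the longer window.

# Illustrative only: mirrors the formula above with assumed values.
def dynamic_scale(configured_factor, length, original_max_seq_len):
    # (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
    return (configured_factor * length / original_max_seq_len) - (configured_factor - 1)

original_max_seq_len = 2048  # assumed original context window
for length in (2048, 4096, 8192):
    s = dynamic_scale(2, length, original_max_seq_len)
    print(length, s, original_max_seq_len * s)
# 2048 -> scale 1.0 (no stretching at the original context length)
# 4096 -> scale 3.0 (max_seq_len 6144); 8192 -> scale 7.0 (max_seq_len 14336)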
