
Commit 0ec4d81

Update conditionals for dynamic scaling
1 parent f01c11b

5 files changed: +7 -23 lines

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __init__(
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling
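
The guard above is the core of the commit: previously the scaled rotary embedding was only constructed when ROPE_SCALE_FACTOR was raised above 1, so setting ROPE_DYNAMIC_SCALING alone never reached the scaled path. A minimal sketch of the new condition (uses_scaled_rope is a stand-in function for this note, not from the repository):

def uses_scaled_rope(scale_factor: int, dynamic_scaling: bool) -> bool:
    # Old condition: scale_factor > 1
    # New condition also enters the scaled path when only dynamic scaling is on.
    return scale_factor > 1 or dynamic_scaling

assert uses_scaled_rope(1, False) is False  # defaults: plain RoPE
assert uses_scaled_rope(4, False) is True   # static scaling only
assert uses_scaled_rope(1, True) is True    # dynamic-only, fixed by this commit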

server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

Lines changed: 2 additions & 7 deletions
@@ -45,13 +45,8 @@
     get_linear,
 )

-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"


 def load_row(config, prefix: str, weights, bias: bool):
@@ -114,7 +109,7 @@ def __init__(self, config, prefix, weights):
         self.scale_factor = ROPE_SCALE_FACTOR
         self.dynamic_scaling = ROPE_DYNAMIC_SCALING

-        if self.scale_factor > 1:
+        if self.scale_factor > 1 or self.dynamic_scaling:
             # Base before scaling is 10000 per the original RoPE paper
             self.rotary_emb = PositionRotaryEmbedding.static(
                 self.head_size, 10000, weights.device, self.scale_factor, self.dynamic_scaling, config.max_position_embeddings
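
Beyond collapsing five lines into one, the rewritten parse fixes a latent crash: with a default of the bool False, os.getenv returns False when the variable is unset, and False.lower() raises AttributeError. A self-contained sketch of both patterns (the _DEMO variable name is just for this illustration):

import os

# Old pattern: fails when the variable is not set, because the default is
# the bool False, which has no .lower() method.
try:
    os.getenv("ROPE_DYNAMIC_SCALING_DEMO", False).lower()
except AttributeError as exc:
    print(f"old pattern fails: {exc}")

# New pattern: a string default parses safely whether or not the variable
# is set, and case-insensitively accepts "true"/"True"/"TRUE".
flag = os.getenv("ROPE_DYNAMIC_SCALING_DEMO", "false").lower() == "true"
print(f"dynamic scaling enabled: {flag}")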

server/text_generation_server/models/custom_modeling/flash_rw_modeling.py

Lines changed: 1 addition & 5 deletions
@@ -26,11 +26,7 @@
 )

 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

 def load_row(config, prefix: str, weights, bias: bool):
     weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

server/text_generation_server/models/custom_modeling/neox_modeling.py

Lines changed: 1 addition & 7 deletions
@@ -60,14 +60,8 @@
 if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")

-
 ROPE_SCALE_FACTOR = int(os.getenv("ROPE_SCALE_FACTOR", 1))
-
-if os.getenv("ROPE_DYNAMIC_SCALING", False).lower() == "true":
-    ROPE_DYNAMIC_SCALING = True
-else:
-    ROPE_DYNAMIC_SCALING = False
-
+ROPE_DYNAMIC_SCALING = os.getenv("ROPE_DYNAMIC_SCALING", "false").lower() == "true"

 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int

server/text_generation_server/utils/layers.py

Lines changed: 2 additions & 3 deletions
@@ -423,10 +423,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
         if self.dynamic_scaling:
             scale_factor = (self.scale_factor * length / self.original_max_seq_len) - (self.scale_factor - 1)
             max_seq_len = self.original_max_seq_len * scale_factor
-            inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)
-            self.register_buffer("inv_freq", inv_freq)
+            self.inv_freq = self._get_inv_freq(self.dim, self.base, inv_freq.device, scale_factor)

-        if self.scale_factor > 1:
+        if self.scale_factor > 1 and not self.dynamic_scaling:
             length = max(seqlen, max_seq_len)

         if (
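
In the dynamic path the effective scale factor is recomputed from the current sequence length rather than fixed at startup, and the inverse frequencies are rebuilt via self._get_inv_freq (in PyTorch, assigning a tensor to a name that is already a registered buffer updates that buffer, so re-registering it on every call is unnecessary); the second hunk then keeps the static max(seqlen, max_seq_len) override from clobbering the dynamically computed max_seq_len. A worked example of the scale-factor formula, with illustrative numbers that are assumptions rather than values from the commit:

# Hypothetical values: a configured scale factor of 2 and a model trained
# with 2048 positions.
scale_factor = 2
original_max_seq_len = 2048

def dynamic_scale(length: int) -> float:
    # Same formula as in _update_cos_sin_cache above.
    return (scale_factor * length / original_max_seq_len) - (scale_factor - 1)

assert dynamic_scale(2048) == 1.0  # at the trained length: no scaling
assert dynamic_scale(4096) == 3.0  # past it, the factor grows linearly
# max_seq_len follows as original_max_seq_len * dynamic_scale(length),
# e.g. 2048 * 3.0 == 6144.0 for a 4096-token sequence.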
