Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RoPE: relaxed rope validation #32182

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,41 +354,46 @@ def _compute_llama3_parameters(

def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's gracefully handle it
if "rope_type" not in received_keys and "type" in received_keys:
received_keys -= {"type"}
received_keys.add("rope_type")

missing_keys = required_keys - received_keys
if missing_keys:
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")

if optional_keys is not None:
unused_keys = received_keys - required_keys - optional_keys
else:
unused_keys = received_keys - received_keys
unused_keys = received_keys - required_keys
if unused_keys:
raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the linked issues - it seems the problem was just with rope_type vs type: does this mean older configs also have other keys?

Copy link
Member Author

@gante gante Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts No other keys that we should support in our transformers code (hence the original exception). The only change was the renaming of (type -> rope_type), as suggested in the PR review.

However, users may be using the base config class in their own custom projects. e.g. Phi3 used custom fields for rope scaling (which are no longer custom because we merged them). For that reason, this PR:

  • no longer requires a validation function for a given rope type, it throws a warning if the function doesn't exist. Users can be working on a new rope type.
  • No longer raises an exception on unrecognized keys. Users might be working on a variant of an existing rope type

The only exceptions that persist are the ones mentioned at the top of this PR page, which would result in other logic errors down the line. I hope this clears out the issues we're seeing, while preventing as many future issues as possible from bad utilization.



def _validate_default_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)


def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
Expand All @@ -397,43 +402,43 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
beta_fast = rope_scaling.get("beta_fast")
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
beta_slow = rope_scaling.get("beta_slow")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

if (beta_fast or 32) < (beta_slow or 1):
raise ValueError(
logger.warning(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)


def _validate_longrope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
Expand All @@ -445,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):

short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
if not len(short_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")

long_factor = rope_scaling.get("long_factor")
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
if not len(long_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")

# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
Expand All @@ -468,48 +473,48 @@ def _validate_longrope_parameters(config: PretrainedConfig):
else:
factor = rope_scaling.get("factor")
if factor is None:
raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)


def _validate_llama3_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)

factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
if low_freq_factor is None or not isinstance(low_freq_factor, float):
raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
raise ValueError(
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
)

original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
raise ValueError(
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
f"{original_max_position_embeddings}"
)
if original_max_position_embeddings >= config.max_position_embeddings:
raise ValueError(
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
)
Expand All @@ -534,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
if rope_scaling is None:
return

possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
if rope_type is None:
raise ValueError(
f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
)

# BC: "rope_type" was originally "type"
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config)
else:
raise ValueError(
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)
3 changes: 3 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def __init__(
self.mlp_bias = mlp_bias

# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling["type"])
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
Expand Down
54 changes: 54 additions & 0 deletions tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,60 @@ def test_rope_class_retrocompatibility(self):
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)

def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
# steps.
base_config_dict = base_config.to_dict()
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
return new_config

# from untouched config -> ✅
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
original_model = LlamaForCausalLM(base_config).to(torch_device)
original_model(**model_inputs)

# from a config with the expected rope configuration -> ✅
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)

# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)

# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
config = _reinitialize_config(
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
)
self.assertTrue(config.rope_scaling["type"] == "linear")
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)

# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("factor field", logs.output[0])

# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
)
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("Unrecognized keys", logs.output[0])

# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are not testing the "type"=xxxx, "rope_scaling"


@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
Expand Down
Loading