
Commit 1cea763

feat: extract rev in attn_implementation kernels via @ (#40009)
* feat: extract rev in attn_implementation kernels via @
* fix: adjust for ruff
* fix: update regex and add explanatory comment
* fix: move attn_implementation kernel doc
* fix: remove extra line
1 parent: e29919f

File tree

1 file changed (+22, -2 lines)


src/transformers/modeling_utils.py

Lines changed: 22 additions & 2 deletions
@@ -2721,7 +2721,7 @@ def _check_and_adjust_attn_implementation(
             None to sdpa (to potentially eager).
         """
         applicable_attn_implementation = "sdpa" if attn_implementation is None else attn_implementation
-        if re.match(r"^[^/:]+/[^/:]+:?[^/:]+$", applicable_attn_implementation):
+        if re.match(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation):
             if not is_kernels_available():
                 raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
             attention_wrapper = None
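
For illustration (a standalone sketch, not part of the patch), the updated pattern accepts the reference forms documented further down, while plain implementation names such as "sdpa" still fall through to the regular handling:

import re

# Pattern copied from the hunk above; the constant name is ours:
# <namespace>/<repo_name>[@<revision>][:<kernel_name>]
KERNEL_REF = r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$"

refs = [
    "org/model",                       # plain repo
    "org/model@main",                  # pinned to a branch
    "org/model:custom_kernel",         # selects a kernel function
    "org/model@v1.2.3:custom_kernel",  # revision and kernel name together
    "sdpa",                            # built-in implementation: no match
]
for ref in refs:
    print(ref, "->", bool(re.match(KERNEL_REF, ref)))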
@@ -2738,8 +2738,12 @@ def _check_and_adjust_attn_implementation(
                 repo_id = applicable_attn_implementation
                 kernel_name = None
             repo_id = repo_id.strip()
+            # extract the rev after the @ if it exists
+            repo_id, _, rev = repo_id.partition("@")
+            repo_id = repo_id.strip()
+            rev = rev.strip() if rev else None
             try:
-                kernel = get_kernel(repo_id)
+                kernel = get_kernel(repo_id, revision=rev)
                 if hasattr(kernel, "flash_attn_varlen_func"):
                     if attention_wrapper is None:
                         attention_wrapper = flash_attention_forward
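
A minimal sketch of the new revision extraction in isolation (the input string is hypothetical; in the real code the ":<kernel_name>" suffix and any "<wrapper>|" prefix are handled by the surrounding lines):

repo_id = "org/model@v1.2.3"  # hypothetical kernel reference
repo_id, _, rev = repo_id.partition("@")
repo_id = repo_id.strip()
rev = rev.strip() if rev else None
print(repo_id, rev)  # -> org/model v1.2.3
# Without an "@", str.partition returns an empty third element, so rev stays None
# and no specific revision is pinned when get_kernel(repo_id, revision=rev) is called.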
@@ -4494,6 +4498,22 @@ def from_pretrained(
             attn_implementation (`str`, *optional*):
                 The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

+                Accept HF kernel references in the form:
+                <namespace>/<repo_name>[@<revision>][:<kernel_name>]
+
+                - <namespace> and <repo_name> are any non-"/" and non-":" sequences.
+                - "@<revision>" is optional (branch, tag, or commit-ish), e.g. "@main", "@v1.2.0", "@abc123".
+                - ":<kernel_name>" is optional and selects a function inside the kernel repo.
+                - Both options can appear together and in this order only: @revision first, then :kernel_name.
+                - We intentionally allow a leading "<wrapper>|" prefix (e.g., "flash|...") because the code
+                  strips it before loading; '|' is not excluded in the character classes here.
+
+                Examples that match:
+                "org/model"
+                "org/model@main"
+                "org/model:custom_kernel"
+                "org/model@v1.2.3:custom_kernel"
+
             > Parameters for big model inference

             torch_dtype (`str` or `torch.dtype`, *optional*):
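
A hedged usage sketch of the documented format (the model and kernel repository names below are placeholders; per the check above, `kernels` must be installed, e.g. via `pip install kernels`):

from transformers import AutoModelForCausalLM

# "some-org/some-model" and "kernel-org/kernel-repo" are placeholder names.
# The attn_implementation string pins the kernel repo to revision v1.2.3 and
# selects the function named custom_kernel inside it.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",
    attn_implementation="kernel-org/kernel-repo@v1.2.3:custom_kernel",
)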
