
Commit 963d4d8

Merge pull request #7 from huggingface/megablocks_moe
feat: add megablocks moe mlp kernel
2 parents: 0203846 + a3b22ce

2 files changed (+8, -37 lines)

src/transformers/integrations/hub_kernels.py (7 additions, 37 deletions)
@@ -23,9 +23,7 @@
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
     )
-    from kernels import (
-        use_kernel_forward_from_hub as original_use_kernel_forward_from_hub,
-    )
+    from kernels import use_kernel_forward_from_hub
 
     _hub_kernels_available = True
 
@@ -56,44 +54,16 @@
                 layer_name="TritonLlamaMLP",
             )
         },
+        "MegaBlocksMoeMLP": {
+            "cuda": LayerRepository(
+                repo_id="kernels-community/megablocks",
+                layer_name="MegaBlocksMoeMLP",
+            )
+        },
     }
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
-    def use_kernel_forward_from_hub(*args, **kwargs):
-        """
-        Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed
-        when `kernels` supports `torch.compile`.
-
-        If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the
-        kernel.
-        """
-
-        def decorator_with_compile_path(cls):
-            # Keeps a reference to the original forward method
-            original_forward = cls.forward
-
-            # Applies the original decorator
-            decorator = original_use_kernel_forward_from_hub(*args, **kwargs)
-            cls = decorator(cls)
-
-            # Replaces the kernel forward with a compile-friendly version
-            kernel_forward = cls.forward
-
-            def forward_with_compile_path(*forward_args, **forward_kwargs):
-                disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None)
-                if is_torchdynamo_compiling() or disable_custom_kernels:
-                    return original_forward(*forward_args, **forward_kwargs)
-                else:
-                    return kernel_forward(*forward_args, **forward_kwargs)
-
-            cls.forward = forward_with_compile_path
-
-            return cls
-
-        return decorator_with_compile_path
-
-
 except ImportError:
     # Stub to make decorators int transformers work when `kernels`
     # is not installed.
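
For context, the entry added to `_KERNEL_MAPPING` above follows the same shape as the existing ones: a logical layer name maps, per device type, to a `LayerRepository` that points at a kernel hosted on the Hugging Face Hub, and `register_kernel_mapping` makes those entries visible to any class decorated with `use_kernel_forward_from_hub`. Below is a minimal standalone sketch of that pattern. The repo id and layer name are taken from the diff; `MyMoeMLP` and its forward are hypothetical placeholders, and the exact point at which the forward is swapped depends on the installed `kernels` version.

import torch
from torch import nn

from kernels import (
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)

# Route the logical name "MegaBlocksMoeMLP" to a kernel hosted on the Hub,
# keyed by device type. This mirrors the entry added to _KERNEL_MAPPING above.
register_kernel_mapping(
    {
        "MegaBlocksMoeMLP": {
            "cuda": LayerRepository(
                repo_id="kernels-community/megablocks",
                layer_name="MegaBlocksMoeMLP",
            )
        }
    }
)


# A layer opts in by name. When a kernel is registered for the current device,
# the hub implementation can take over this forward; otherwise the plain
# PyTorch body below runs.
@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
class MyMoeMLP(nn.Module):  # hypothetical layer, for illustration only
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states

With the mapping registered in hub_kernels.py, model code only needs the one-line decorator, which is exactly what the second file in this commit adds.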

src/transformers/models/openai_moe/modeling_openai_moe.py (1 addition, 0 deletions)
@@ -123,6 +123,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights
         return next_states
 
 
+@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
 class OpenAIMoeMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
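
The decorator has to remain importable even when `kernels` is not installed; the `except ImportError:` branch visible at the end of the hub_kernels.py diff keeps a stub for exactly that case, so decorating `OpenAIMoeMLP` cannot break installs without the package. A rough sketch of such a no-op stub, assuming it simply returns the class unchanged (the body here is an assumption based on that comment, not the actual transformers code):

def use_kernel_forward_from_hub(*args, **kwargs):
    """No-op fallback for environments without the `kernels` package:
    the decorated class is returned unchanged, so OpenAIMoeMLP keeps its
    plain PyTorch forward. (Assumed stub body, for illustration only.)
    """

    def decorator(cls):
        return cls

    return decorator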
