|
23 | 23 | register_kernel_mapping, |
24 | 24 | replace_kernel_forward_from_hub, |
25 | 25 | ) |
26 | | - from kernels import ( |
27 | | - use_kernel_forward_from_hub as original_use_kernel_forward_from_hub, |
28 | | - ) |
| 26 | + from kernels import use_kernel_forward_from_hub |
29 | 27 |
|
30 | 28 | _hub_kernels_available = True |
31 | 29 |
|
|
56 | 54 | layer_name="TritonLlamaMLP", |
57 | 55 | ) |
58 | 56 | }, |
| 57 | + "MegaBlocksMoeMLP": { |
| 58 | + "cuda": LayerRepository( |
| 59 | + repo_id="kernels-community/megablocks", |
| 60 | + layer_name="MegaBlocksMoeMLP", |
| 61 | + ) |
| 62 | + }, |
59 | 63 | } |
60 | 64 |
|
61 | 65 | register_kernel_mapping(_KERNEL_MAPPING) |
62 | 66 |
|
63 | | - def use_kernel_forward_from_hub(*args, **kwargs): |
64 | | - """ |
65 | | - Expands `kernels`' `use_kernel_forward_from_hub` to NOT use a kernel at compile time. This should be removed |
66 | | - when `kernels` supports `torch.compile`. |
67 | | -
|
68 | | - If the layer has a `config` attribute, we can also set `config.disable_custom_kernels = True` to disable the |
69 | | - kernel. |
70 | | - """ |
71 | | - |
72 | | - def decorator_with_compile_path(cls): |
73 | | - # Keeps a reference to the original forward method |
74 | | - original_forward = cls.forward |
75 | | - |
76 | | - # Applies the original decorator |
77 | | - decorator = original_use_kernel_forward_from_hub(*args, **kwargs) |
78 | | - cls = decorator(cls) |
79 | | - |
80 | | - # Replaces the kernel forward with a compile-friendly version |
81 | | - kernel_forward = cls.forward |
82 | | - |
83 | | - def forward_with_compile_path(*forward_args, **forward_kwargs): |
84 | | - disable_custom_kernels = hasattr(cls, "config") and getattr(cls.config, "disable_custom_kernels", None) |
85 | | - if is_torchdynamo_compiling() or disable_custom_kernels: |
86 | | - return original_forward(*forward_args, **forward_kwargs) |
87 | | - else: |
88 | | - return kernel_forward(*forward_args, **forward_kwargs) |
89 | | - |
90 | | - cls.forward = forward_with_compile_path |
91 | | - |
92 | | - return cls |
93 | | - |
94 | | - return decorator_with_compile_path |
95 | | - |
96 | | - |
97 | 67 | except ImportError: |
98 | 68 | # Stub to make decorators int transformers work when `kernels` |
99 | 69 | # is not installed. |
|
0 commit comments