Skip to content

Commit

Permalink
[Ir][Primitives] add exp2 (#410)
Browse files Browse the repository at this point in the history
- add a primitive of exp2 for float types.  

This primitive could be useful when optimizing the flash attention.
Specifically, flash attention rewrites the exponential function as
```
# log2_e = 1.44269504
exp(a) = exp2(a * log2_e)
```
The transformation can hint the nvcc compiler to generate better code
(more ffma instructions instead of fmuls and fadds) .

---------

Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
  • Loading branch information
xiaocenxiaocen and xiaocenxiaocen authored Aug 14, 2024
1 parent 06180c9 commit 3e323e3
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/hidet/ir/primitives/cuda/math/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def register(self):
'sin': ['hsin', 1],
'cos': ['hcos', 1],
'exp': ['hexp', 1],
'exp2': ['hexp2', 1],
'sqrt': ['hsqrt', 1],
'rsqrt': ['hrsqrt', 1],
'log': ['hlog', 1],
Expand Down Expand Up @@ -158,6 +159,9 @@ def tanh(self, a: Expr) -> Expr:
def exp(self, a: Expr) -> Expr:
return self.call('cuda_f16_exp', a)

def exp2(self, a: Expr) -> Expr:
return self.call('cuda_f16_exp2', a)

def erf(self, a: Expr) -> Expr:
# use float32 erf to delegate the float16 erf
from hidet.ir.expr import cast
Expand Down
4 changes: 4 additions & 0 deletions python/hidet/ir/primitives/cuda/math/float32.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def register():
'acosh': 'acoshf',
'atanh': 'atanhf',
'exp': '__expf', # fast math
'exp2': 'exp2f',
'erf': 'erff',
'sqrt': 'sqrtf',
'rsqrt': 'rsqrtf',
Expand Down Expand Up @@ -81,6 +82,9 @@ def tanh(self, a: Expr) -> Expr:
def exp(self, a: Expr) -> Expr:
return self.call('exp', a)

def exp2(self, a: Expr) -> Expr:
return self.call('exp2', a)

def erf(self, a: Expr) -> Expr:
return self.call('erf', a)

Expand Down
11 changes: 11 additions & 0 deletions python/hidet/ir/primitives/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def atanh(self, a: Expr) -> Expr:
def exp(self, a: Expr) -> Expr:
raise NotImplementedError()

def exp2(self, a: Expr) -> Expr:
raise NotImplementedError()

def expm1(self, a: Expr) -> Expr:
raise NotImplementedError()

Expand Down Expand Up @@ -217,6 +220,7 @@ def register():
'acosh',
'atanh',
'exp',
'exp2',
'expm1',
'erf',
'sqrt',
Expand Down Expand Up @@ -292,6 +296,9 @@ def atanh(self, a: Expr) -> Expr:
def exp(self, a: Expr) -> Expr:
return self.call('exp', a)

def exp2(self, a: Expr) -> Expr:
return self.call('exp2', a)

def expm1(self, a: Expr) -> Expr:
return self.call('expm1', a)

Expand Down Expand Up @@ -426,6 +433,10 @@ def exp(a: Expr) -> Expr:
return generic_math_function_set.exp(a)


def exp2(a: Expr) -> Expr:
return generic_math_function_set.exp2(a)


def expm1(a: Expr) -> Expr:
return generic_math_function_set.expm1(a)

Expand Down
79 changes: 79 additions & 0 deletions tests/ir/primitives/cuda/test_exp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import torch
import hidet


def test_exp2():
from hidet.lang import attrs
from hidet.ir.primitives.math import exp2
from hidet.ir.dtypes import f32
from hidet.lang.cuda import threadIdx

with hidet.script_module() as script_module:

@hidet.script
def func(out: f32[32]):
attrs.func_kind = "cuda_kernel"
attrs.cuda.block_dim = 32
attrs.cuda.grid_dim = 1

t = threadIdx.x
out[threadIdx.x] = exp2(f32(t))

func = script_module.build()

out = torch.empty((32,), dtype=torch.float32, device="cuda")
out = hidet.from_torch(out)
func(out)
import numpy as np

groundtruth = np.array([2**i for i in range(32)], dtype=np.float32)
np.testing.assert_equal(out.cpu().numpy(), groundtruth)


def test_exp2_f16():
from hidet.lang import attrs
from hidet.ir.primitives.math import exp2
from hidet.ir.dtypes import f16
from hidet.lang.cuda import threadIdx

with hidet.script_module() as script_module:

@hidet.script
def func(out: f16[4]):
attrs.func_kind = "cuda_kernel"
attrs.cuda.block_dim = 32
attrs.cuda.grid_dim = 1

t = threadIdx.x
if t < 4:
out[threadIdx.x] = exp2(f16(t))

func = script_module.build()

out = torch.empty((4,), dtype=torch.float16, device="cuda")
out = hidet.from_torch(out)
func(out)
import numpy as np

groundtruth = np.array([2**i for i in range(4)], dtype=np.float16)
np.testing.assert_equal(out.cpu().numpy(), groundtruth)


if __name__ == "__main__":
hidet.option.cache_dir("./exp2")
hidet.option.save_lower_ir(True)

pytest.main([__file__])

0 comments on commit 3e323e3

Please sign in to comment.