CPU AVX implementation for Softmax, Norm #357
Conversation
Commits:
- works on 8x8 at least but bad exp
- save for omp changes
- working and faster than pytorch
- works and is fast but exp is WIP
- remove useless files
- minor changes for rebase
- delete trash
- fix trash
- fix trash
- initial commit
- change imports
- fix for diff size, compiledmodule error fix
Thanks @fishingguy456 for your first PR to hidet!
I left some comments. In general,
- do not forget to run tests, lint and format before submitting the PR.
- for the new operators, we should add some tests for them. See the examples in tests/operators/... .
- our current design allows one task to have both a CPU and a CUDA implementation, and they share the same properties for whether prologue and epilogue are allowed. When we want to change these allow properties, it is better to create a new task that overrides the original one, so that it does not interfere with the original operator. In the future, we might want to add a device parameter to these functions (like allow_prologue(self, device) -> bool) so that we do not need to create a new class. But for now, let's create a new class and add a resolve rule.
def run_batch_matmul(self, a: Tensor, b: Tensor, is_cpu: bool) -> Tensor:
    if is_cpu:
We can directly check the device of the tensor, without needing to pass is_cpu as a parameter:
is_cpu = a.device.is_cpu()
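As a sketch, the quoted helper could then drop the flag entirely (the branch bodies below are placeholders, not the PR's code):

def run_batch_matmul(self, a: Tensor, b: Tensor) -> Tensor:
    # The tensor already carries its device, so the flag can be derived here
    # instead of being threaded through every call site.
    is_cpu = a.device.is_cpu()
    if is_cpu:
        ...  # CPU path (unchanged from the PR)
    else:
        ...  # CUDA path (unchanged from the PR)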
def allow_epilogue(self) -> bool:
    return True
If the CPU and CUDA tasks have different behavior, it is better to create a subclass of the task and override the relevant methods in the subclass:
class CPUNormalizeTask(NormalizeTask):
...
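Spelled out a bit further, the subclass could simply pin the allow_* properties for the CPU backend (a sketch reusing the overrides quoted elsewhere in this PR, not code taken from it):

class CPUNormalizeTask(NormalizeTask):
    # CPU-specific task: disable prologue/epilogue fusion without changing
    # the behaviour of the original (CUDA-oriented) NormalizeTask.
    def allow_prologue(self) -> bool:
        return False

    def allow_epilogue(self) -> bool:
        return False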
And implement the resolve rule to convert the normalize operator to the corresponding cpu_normalize operator.
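A rough sketch of such a rule, assuming hidet's ResolveRule / register_resolve_rule mechanism; NormalizeOp and cpu_normalize are illustrative names, not definitions from this PR:

from typing import List, Optional

from hidet.graph import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule


@register_resolve_rule(NormalizeOp)  # NormalizeOp: the existing normalize operator class (name assumed)
class CPUNormalizeResolveRule(ResolveRule):
    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
        x = op.inputs[0]
        if x.device.is_cpu():
            # cpu_normalize is a hypothetical factory that builds an operator
            # backed by CPUNormalizeTask instead of the default task.
            return [cpu_normalize(x, **op.attrs)]
        return None  # keep the original operator on other devices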
norm_cpu_kernel.kind = "cpu_kernel"
avx_f32x8_find_sum.kind = "cpu_internal"
It is better to avoid setting the function attributes outside of the function's definition. Instead, use:
from hidet.lang import attrs

@hidet.script
def norm_cpu_kernel(...):
    attrs.func_kind = "cpu_kernel"
    ...
python/hidet/graph/ops/softmax.py (outdated)
@@ -16,6 +16,9 @@
 from hidet.ir.builders import StmtBuilder
 from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
 from .utils import Task, TensorNode, compute, reduce
+from typing import List, Union
Move to the top.
Remember to run format & lint, see https://docs.hidet.org/stable/developer-guides/contributing.html#contributing
python/hidet/graph/ops/softmax.py (outdated)
@@ -153,3 +156,143 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
        ir_module = module.ir_module()
        return ir_module

    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        # if not all(is_constant(dim) for dim in self.inputs[0].shape)\
Suggested change: remove the commented-out line "# if not all(is_constant(dim) for dim in self.inputs[0].shape)\".
def allow_epilogue(self) -> bool:
    return False

def allow_prologue(self) -> bool:
    return False
Create a CPU version of the operator, because the CUDA version allows prologue & epilogue.
python/hidet/graph/ops/softmax.py (outdated)
softmax_cpu_kernel.kind = "cpu_kernel"
apply_exponent.kind = "cpu_internal"
ditto
python/hidet/ir/expr.py (outdated)
if not (isinstance(func_var, Var) and isinstance(args, tuple)):
    print(func_var, args)
    print(type(args[0]))
    print(type(func_var), type(args))
Suggested change: remove these debug print statements.
from hidet.ir.func import Function


@script
def avx_x86_f32x8_find_sum(x: f32x8) -> f32:
Is there any convention for using "find" in the function name?
If not, I would prefer to name them directly as "avx_x86_f32x8_sum" and "avx_x86_f32x8_max".
Thanks @fishingguy456 ! Could you also add a test for softmax? Hi @BolinSNLHM, could you have a look at this PR? I did not check the kernel implementation details.
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])


assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)


@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
Would it be possible to only declare the primitives (e.g., avx_x86_f32x8_extract_half) in this file, and then define functions like avx_x86_f32x8_sum as helper functions in Hidet Script in a separate file where they are needed? The code should work as it is, but it looks a bit odd to have the hidet.script decorator and multiple calls to call_primitive_func here...
@@ -73,7 +73,7 @@ def __init__(self, a: TensorNode, b: TensorNode):
         )

     def allow_epilogue(self) -> bool:
-        return True
+        return False
Why should we change this to False? 🤔
from hidet.lang import script, attrs
from hidet.ir.dtypes import f32x8, f32
from hidet.ir.func import Function


@script
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])


assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)


@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
I recommend moving these user-defined functions built on top of AVX (not the ones provided by the underlying vector library), like avx_x86_f32x8_sum, to another file called avx_helpers.py.
For functions like avx_x86_float32x4_extract_last, we also need to define a wrapper function like:
def avx_x86_float32x4_extract_last(x: Expr) -> Call:
    return call_primitive_func('avx_x86_float32x4_extract_last', [x])
In the new file, we directly use avx_x86_float32x4_extract_last(...) in the Hidet script, instead of calling call_primitive_func.
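For illustration, a hedged sketch of how avx_x86_f32x8_sum could read in such an avx_helpers.py once per-primitive wrappers exist; the wrapper import path is an assumption, and the reduction steps simply mirror the code quoted above:

# avx_helpers.py (proposed; not part of this PR)
from hidet.lang import script, attrs
from hidet.ir.dtypes import f32, f32x8
# Assumed location of the thin wrappers suggested above, each of which just
# forwards to call_primitive_func for the primitive of the same name.
from hidet.ir.primitives.cpu.avx import (
    avx_x86_float32x8_extract_half,
    avx_x86_float32x4_add,
    avx_x86_float32x4_hadd,
    avx_x86_float32x4_extract_last,
)


@script
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = 'cpu_internal'
    attrs.func_name = 'avx_x86_float32x8_sum'
    # Same horizontal-sum reduction as in the PR, written against the wrappers
    # instead of repeated call_primitive_func calls.
    sum_vec = avx_x86_float32x4_add(
        avx_x86_float32x8_extract_half(x, 0b0), avx_x86_float32x8_extract_half(x, 0b1)
    )
    sum_vec = avx_x86_float32x4_hadd(sum_vec, sum_vec)
    sum_vec = avx_x86_float32x4_hadd(sum_vec, sum_vec)
    return avx_x86_float32x4_extract_last(sum_vec)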
Thanks @fishingguy456 !
Working but inefficient batch matmul. It takes the path of matmul_f32_x86 instead of the CPU auto-scheduler.