CPU AVX implementation for Softmax, Norm #357

Merged
merged 76 commits into from
Jan 9, 2024

fishingguy456
Contributor

Working but inefficient batch matmul. Takes path of matmul_f32_x86 instead of cpu autoscheduler.

works on 8x8 at least but bad exp

save for omp changes

working and faster than pytorch

works and is fast but exp is WIP

remove useless files

minor changes for rebase

delete trash

fix trash

fix trash

initial commit

change imports

fix for diff size, compiledmodule error fix

Member

@yaoyaoding yaoyaoding left a comment

Thanks @fishingguy456 for your first PR to hidet!

I left some comments. In general,

  1. Do not forget to run tests, lint, and format before submitting the PR.
  2. For the new operators, we should add some tests. See the examples in tests/operators/....
  3. Our current design allows one task to have both CPU and CUDA implementations, and they share the same properties for whether a prologue or epilogue is allowed. When we want to change these properties, it is better to create a new task that overrides the original one, so that it does not interfere with our original operator. In the future, we might want to add a device parameter to these functions (like allow_prologue(self, device) -> bool) so that we do not need to create a new class. But for now, let's create a new class and add a resolve rule.

Comment on lines 100 to 101
def run_batch_matmul(self, a: Tensor, b: Tensor, is_cpu: bool) -> Tensor:
    if is_cpu:
Member

We can directly check the device of tensor, without the need to pass is_cpu as parameter.

is_cpu = a.device.is_cpu()

Comment on lines 110 to 112
def allow_epilogue(self) -> bool:
    return True

Member

If the CPU and CUDA tasks have different behavior, it is better to create a subclass of the task and override the relevant methods in the subclass:

class CPUNormalizeTask(NormalizeTask):
    ...

Member

Then implement a resolve rule that converts the normalize operator to the corresponding cpu_normalize operator.
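
For readers following the thread, the subclass-plus-resolve pattern suggested here can be sketched in plain Python. This is only an illustration under stated assumptions: `Tensor`, `NormalizeTask`, and `resolve_normalize` below are hypothetical stand-ins written for this sketch, not hidet's actual classes or its resolve-rule API.

```python
from dataclasses import dataclass


@dataclass
class Tensor:
    # Hypothetical stand-in: only records which device the data lives on.
    device: str


class NormalizeTask:
    # Hypothetical stand-in for the original task; fusion is allowed by default.
    def __init__(self, x: Tensor):
        self.x = x

    def allow_prologue(self) -> bool:
        return True

    def allow_epilogue(self) -> bool:
        return True


class CPUNormalizeTask(NormalizeTask):
    # The CPU variant overrides only the properties that differ,
    # so the original task keeps its behavior untouched.
    def allow_prologue(self) -> bool:
        return False

    def allow_epilogue(self) -> bool:
        return False


def resolve_normalize(x: Tensor) -> NormalizeTask:
    # Hypothetical resolve step: pick the CPU task for CPU tensors,
    # otherwise fall back to the original task.
    if x.device == "cpu":
        return CPUNormalizeTask(x)
    return NormalizeTask(x)


cpu_task = resolve_normalize(Tensor("cpu"))
cuda_task = resolve_normalize(Tensor("cuda"))
print(type(cpu_task).__name__, cpu_task.allow_epilogue())
print(type(cuda_task).__name__, cuda_task.allow_epilogue())
```

The point of the design is that the device-specific decision lives in one place (the resolve step), while the original task class is not modified.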

Comment on lines 443 to 444
norm_cpu_kernel.kind = "cpu_kernel"
avx_f32x8_find_sum.kind = "cpu_internal"
Member

Better to avoid setting the function attributes outside its definition. Instead, use

from hidet.lang import attrs

@hidet.script
def norm_cpu_kernel(...):
    attrs.func_kind = "cpu_kernel"
    ...

@@ -16,6 +16,9 @@
from hidet.ir.builders import StmtBuilder
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from .utils import Task, TensorNode, compute, reduce
from typing import List, Union
Member

Move to the top.

Remember to run format & lint, see https://docs.hidet.org/stable/developer-guides/contributing.html#contributing

@@ -153,3 +156,143 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
        ir_module = module.ir_module()

        return ir_module

    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        # if not all(is_constant(dim) for dim in self.inputs[0].shape)\
Member

Suggested change
# if not all(is_constant(dim) for dim in self.inputs[0].shape)\

Comment on lines +166 to +170
def allow_epilogue(self) -> bool:
    return False

def allow_prologue(self) -> bool:
    return False
Member

Create a CPU version of the operator, because the CUDA version allows prologue and epilogue.

Comment on lines 294 to 295
softmax_cpu_kernel.kind = "cpu_kernel"
apply_exponent.kind = "cpu_internal"
Member

ditto

Comment on lines 442 to 445
if not (isinstance(func_var, Var) and isinstance(args, tuple)):
    print(func_var, args)
    print(type(args[0]))
    print(type(func_var), type(args))
Member

Suggested change
if not (isinstance(func_var, Var) and isinstance(args, tuple)):
    print(func_var, args)
    print(type(args[0]))
    print(type(func_var), type(args))

from hidet.ir.func import Function

@script
def avx_x86_f32x8_find_sum(x: f32x8) -> f32:
Member

Is there a convention for using "find" in the function name?

If not, I would prefer to name them directly as "avx_x86_f32x8_sum" and "avx_x86_f32x8_max".

@yaoyaoding
Member

Thanks @fishingguy456 ! Could you also add a test for softmax?

Hi @BolinSNLHM, could you take a look at this PR? I did not check the kernel implementation details.

Comment on lines 95 to 122
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])

assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)

@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
Contributor

Would it be possible to only declare the primitives (e.g., avx_x86_f32x8_extract_half) in this file, and then define functions like avx_x86_f32x8_sum as helper functions in Hidet Script in a separate file where it would be needed? The code should work as it is, but it looks a bit odd to have hidet.script decorator and multiple calls to call_primitive_func here...
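
As a reading aid for the two reductions quoted above, here is a small pure-Python emulation of the lane shuffles. It is a sketch, not the PR's code: it assumes the hidet primitives behave like the AVX intrinsics their names suggest (`_mm256_extractf128_ps`, `_mm_hadd_ps`, `_mm256_permute2f128_ps`, `_mm256_permute_ps`), models vectors as Python lists, and ignores the zeroing bits of the 128-bit permute. All function names below are made up for the illustration.

```python
def extract_half(x, imm):
    # imm = 0 selects lanes 0..3, imm = 1 selects lanes 4..7.
    return x[4 * imm: 4 * imm + 4]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def vmax(a, b):
    return [max(p, q) for p, q in zip(a, b)]


def hadd4(a, b):
    # Horizontal add of adjacent pairs, first from a, then from b.
    return [a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3]]


def permute_2f128(a, b, imm):
    # Bits 1:0 pick the low 128-bit half, bits 5:4 pick the high half.
    halves = [a[:4], a[4:], b[:4], b[4:]]
    return halves[imm & 0x3] + halves[(imm >> 4) & 0x3]


def permute(a, imm):
    # Within each 128-bit half, output lane i takes input lane (imm >> 2i) & 3.
    return [a[base + ((imm >> (2 * i)) & 0x3)] for base in (0, 4) for i in range(4)]


def f32x8_sum_lanes(x):
    s = add(extract_half(x, 0), extract_half(x, 1))  # fold 8 lanes into 4
    s = hadd4(s, s)                                  # 4 partial sums into 2
    return hadd4(s, s)                               # every lane holds the total


def f32x8_max_lanes(x):
    m1 = vmax(x, permute_2f128(x, x, 1))     # compare against swapped halves
    m3 = vmax(m1, permute(m1, 0b01001110))   # compare against swapped pairs
    return vmax(m3, permute(m3, 0b10110001)) # compare against swapped neighbours


x = [3.0, -1.0, 7.5, 2.0, 0.5, 9.0, -4.0, 1.0]
print(f32x8_sum_lanes(x))
print(f32x8_max_lanes(x))
```

In this model every lane of the final vector holds the reduced value, so the result does not depend on which lane the final extract step reads.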

@@ -73,7 +73,7 @@ def __init__(self, a: TensorNode, b: TensorNode):
        )

    def allow_epilogue(self) -> bool:
-        return True
+        return False
Contributor

Why should we change this to False? 🤔

Comment on lines 90 to 122
from hidet.lang import script, attrs
from hidet.ir.dtypes import f32x8, f32
from hidet.ir.func import Function

@script
def avx_x86_f32x8_sum(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_sum"
    sum_vec = call_primitive_func(
        'avx_x86_float32x4_add',
        [
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b0]),
            call_primitive_func('avx_x86_float32x8_extract_half', [x, 0b1]),
        ],
    )
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    sum_vec = call_primitive_func('avx_x86_float32x4_hadd', [sum_vec, sum_vec])
    return call_primitive_func('avx_x86_float32x4_extract_last', [sum_vec])

assert isinstance(avx_x86_f32x8_sum, Function)
register_primitive_function(avx_x86_f32x8_sum.name, avx_x86_f32x8_sum)

@script
def avx_x86_f32x8_scalar_max(x: f32x8) -> f32:
    attrs.func_kind = "cpu_internal"
    attrs.func_name = "avx_x86_float32x8_scalar_max"
    y = call_primitive_func('avx_x86_float32x8_permute_2f128', [x, x, 1])
    m1 = call_primitive_func('avx_x86_float32x8_max', [x, y])
    m2 = call_primitive_func('avx_x86_float32x8_permute', [m1, 0b01001110])
    m3 = call_primitive_func('avx_x86_float32x8_max', [m1, m2])
    m4 = call_primitive_func('avx_x86_float32x8_permute', [m3, 0b10110001])
    m = call_primitive_func('avx_x86_float32x8_max', [m3, m4])
    return call_primitive_func('avx_x86_float32x8_extract_last', [m])
Member

@yaoyaoding yaoyaoding Jan 8, 2024

I recommend moving the user-defined functions built on top of AVX (not the ones provided by the underlying vector library), such as avx_x86_f32x8_sum, to another file called avx_helpers.py.

For functions like avx_x86_float32x4_extract_last, we also need to define a wrapper function like

def avx_x86_float32x4_extract_last(x: Expr) -> Call:
    return call_primitive_func('avx_x86_float32x4_extract_last', [x])

Member

In the new file, we directly use avx_x86_float32x4_extract_last(...) in the hidet script, instead of calling call_primitive_func.

@yaoyaoding
Member

Thanks @fishingguy456 !

@yaoyaoding yaoyaoding merged commit 52fe368 into hidet-org:main Jan 9, 2024
2 checks passed