CPU AVX implementation for Softmax, Norm #357
Merged
Commits (76 total; changes shown from 73 commits)
14cbb3b
initial commit
fishingguy456 7896c45
works on multidimensional, axis=-1
fishingguy456 ff90ed5
initial commit
fishingguy456 fc61204
change imports
fishingguy456 f84201f
fix for diff size, compiledmodule error fix
fishingguy456 6f2e43c
works on multidimensional, axis=-1
fishingguy456 25f22cf
initial commit
fishingguy456 aafbb0f
initial commit
fishingguy456 44993e2
change imports
fishingguy456 a86d866
fix for diff size, compiledmodule error fix
fishingguy456 b59ffa2
works on multidimensional, axis=-1
fishingguy456 7edf0eb
wrap up softmax, starting layernorm
fishingguy456 44c04b3
layernorm kinda works but not rly
fishingguy456 2ccc4b6
better code for softmax
fishingguy456 13ea5dc
layernorm works for last layer
fishingguy456 d89036d
move find sum and find max to registered function
fishingguy456 b0659f6
find max in registered func
fishingguy456 904760b
not working softmax on not last dim, minor changes
fishingguy456 29b7ba7
layernorm works for any dims
fishingguy456 0c8dc3a
comments
fishingguy456 77fe8d9
tuning, fix for flowgraph operator resolve
fishingguy456 ac40695
softmax works
fishingguy456 4938a1f
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 1d447cf
actually works rn frfr so fast :100:
fishingguy456 30224ce
cleanup
fishingguy456 67d4d56
more cleanup
fishingguy456 09ca2f8
random testing stuff
fishingguy456 8352dd8
allow epilogue
fishingguy456 27f6cbb
better epiloguing
fishingguy456 cce1d42
janky matmul resolve
fishingguy456 f92de53
still epilogue problem?
fishingguy456 63dfed4
initial commit
fishingguy456 73a063a
works on multidimensional, axis=-1
fishingguy456 1c129c0
initial commit
fishingguy456 bf8a5b5
change imports
fishingguy456 3aa5cb6
fix for diff size, compiledmodule error fix
fishingguy456 b849ebf
works on multidimensional, axis=-1
fishingguy456 12fdbd1
initial commit
fishingguy456 9c7ecd0
initial commit
fishingguy456 b155bbd
change imports
fishingguy456 de72bc6
fix for diff size, compiledmodule error fix
fishingguy456 17b8d76
works on multidimensional, axis=-1
fishingguy456 1b52167
wrap up softmax, starting layernorm
fishingguy456 e479db7
layernorm kinda works but not rly
fishingguy456 c623630
better code for softmax
fishingguy456 b44b69e
layernorm works for last layer
fishingguy456 29ea558
move find sum and find max to registered function
fishingguy456 339e549
find max in registered func
fishingguy456 88c423c
not working softmax on not last dim, minor changes
fishingguy456 9c91875
layernorm works for any dims
fishingguy456 6e0d8e5
comments
fishingguy456 552aebb
tuning, fix for flowgraph operator resolve
fishingguy456 dc258e3
softmax works
fishingguy456 95f6be7
commented tensors dont work, i.e. axis is not last 2 AND not multiple…
fishingguy456 d0b99a4
actually works rn frfr so fast :100:
fishingguy456 67a43a5
cleanup
fishingguy456 4443780
more cleanup
fishingguy456 4088fc6
random testing stuff
fishingguy456 7430696
allow epilogue
fishingguy456 8a1167e
better epiloguing
fishingguy456 0f4876f
janky matmul resolve
fishingguy456 49c072f
still epilogue problem?
fishingguy456 0bd13d8
Merge remote-tracking branch 'origin/main'
fishingguy456 de74231
clean up for pr
fishingguy456 9ab0bac
fix test
fishingguy456 f779a1d
lint
fishingguy456 124fb09
minor pr edits
fishingguy456 6c4efd9
pytests, cpu child class
fishingguy456 40fd71f
potential fix for failing tests? but prob not will have to investigat…
fishingguy456 90c4ffb
weird diff
fishingguy456 587ba64
merge conflict resolve build.py
fishingguy456 89d5646
remove shady batch mat mul
fishingguy456 a3a4b03
lint thing
fishingguy456 aec95d2
move helpers to new file
fishingguy456 7a41b5c
lint
fishingguy456 dcc6a45
change tolerance for flaky test for test_dynamic_shape
fishingguy456
python/hidet/graph/ops/softmax.py

@@ -9,12 +9,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from hidet.ir.module import IRModule
from hidet.ir import primitives as prim
from hidet.ir.expr import is_constant
from hidet.ir.stmt import Stmt, AssignStmt
from hidet.ir.builders import StmtBuilder
from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync
from hidet.ir.dtypes import float32
from hidet.ir.library import tune
from .utils import Task, TensorNode, compute, reduce
@@ -153,3 +156,159 @@ def softmax_kernel(xs: xdtype[shape], ys: xdtype[shape]):
        ir_module = module.ir_module()

        return ir_module

    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        if self.inputs[0].type.dtype != float32:
            return NotImplemented  # use auto-scheduler
        return tune.extract_ir_modules(self.schedule_softmax_cpu)


class CPUSoftmaxTask(SoftmaxTask):
    def allow_epilogue(self) -> bool:
        return False

    def allow_prologue(self) -> bool:
        return False
Review comment (on lines +167 to +171): Create a CPU version of the operator, because the CUDA version allows prologue & epilogue.
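For context, a minimal plain-Python sketch (not hidet API) of what epilogue fusion means and why the hand-written kernel opts out: with an epilogue allowed, a consumer elementwise op is folded into the kernel's store loop, which a fixed AVX kernel cannot accommodate. The relu consumer below is a hypothetical example.

import numpy as np

def softmax_rows(x):
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)

x = np.random.randn(4, 16).astype(np.float32)

# Unfused: the softmax kernel writes its full output, then a separate
# pass applies the consumer op (here relu).
unfused = np.maximum(softmax_rows(x), 0.0)

# Fused equivalent: the consumer is applied as each output row is
# produced, saving a pass over memory. A hand-written kernel would have
# to be regenerated per fusion pattern to support this.
fused = np.empty_like(x)
for i in range(x.shape[0]):
    fused[i] = np.maximum(softmax_rows(x[i]), 0.0)

assert np.allclose(unfused, fused)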
    @tune.space(2, nthreads=['', 4, 8, 16, 32, 64, 96])
    @tune.space(1, nthreads=['', 8, 16])
    def schedule_softmax_cpu(self, nthreads='') -> IRModule:
        import hidet
        from hidet.ir.primitives.cpu.avx import (
            avx_f32x8_subtract,
            avx_f32x8_load,
            avx_f32x8_setzero,
            avx_f32x8_store,
            avx_f32x8_add,
            avx_f32x8_max,
            avx_f32x8_set1,
            avx_f32x8_divide,
            avx_f32x8_sum,
            avx_f32x8_scalar_max,
        )
        from hidet.lang import tensor, attrs, grid
        from hidet.ir.stmt import DeclareScope
        from hidet.lang.mapping import spatial
        from hidet.utils import prod
        from hidet.ir.dtypes import float32x8

        shape = self.inputs[0].shape
        head = shape[: self.axis]
        tail = shape[self.axis :] if self.axis == len(shape) - 1 else shape[self.axis + 1 :]
        head_size = prod(head)
        tail_size = prod(tail)
        axis_size = shape[self.axis]

        with hidet.script_module() as module:

            @hidet.script
            def apply_exponent(vec: float32x8) -> float32x8:
                attrs.func_kind = "cpu_internal"
                arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[8])
                avx_f32x8_store(arr, vec)
                for n in range(8):
                    arr[n] = prim.exp(arr[n])
                return avx_f32x8_load(arr)

            @hidet.script
            def softmax_cpu_kernel(x: float32[shape], out: float32[shape]):
                attrs.func_kind = "cpu_kernel"
                para = 'p' + str(nthreads)
                for k in grid(head_size, attrs=para):
                    head_idx = spatial(*head).map(k)
                    if self.axis == len(shape) - 1:  # last dim
                        temp_exp = tensor(dtype=float32, shape=tail)
                        max_val = x[head_idx][0]
                        if tail_size >= 8:
                            # vectorized find max value
                            max_vec = avx_f32x8_load(~x[head_idx][0])
                            for i in range(tail_size // 8):
                                data_vec = avx_f32x8_load(~x[head_idx][i * 8])
                                max_vec = avx_f32x8_max(max_vec, data_vec)
                            max_val = avx_f32x8_scalar_max(max_vec)
                        for i in range(tail_size % 8):
                            # max value of remaining unvectorized parts
                            max_val = (
                                max_val
                                if max_val > x[head_idx][tail_size - tail_size % 8 + i]
                                else x[head_idx][tail_size - tail_size % 8 + i]
                            )

                        # subtract max, take exp and find exp sum
                        sum_value = 0.0
                        if tail_size >= 8:
                            sum_exp_vec = avx_f32x8_setzero()
                            max_vec = avx_f32x8_set1(max_val)
                            for i in range(tail_size // 8):
                                val_vec = avx_f32x8_load(~x[head_idx][i * 8])
                                val_vec = avx_f32x8_subtract(val_vec, max_vec)
                                val_vec = apply_exponent(val_vec)
                                avx_f32x8_store(~temp_exp[i * 8], val_vec)
                                sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
                            sum_value = avx_f32x8_sum(sum_exp_vec)
                        for i in range(tail_size % 8):
                            temp_exp[tail_size - tail_size % 8 + i] = prim.exp(
                                x[head_idx][tail_size - tail_size % 8 + i] - max_val
                            )
                            sum_value += temp_exp[tail_size - tail_size % 8 + i]

                        # divide by exp sum
                        if tail_size >= 8:
                            # divide
                            sum_vec8 = avx_f32x8_set1(sum_value)
                            for i in range(tail_size // 8):
                                avx_f32x8_store(
                                    ~temp_exp[i * 8], avx_f32x8_divide(avx_f32x8_load(~temp_exp[i * 8]), sum_vec8)
                                )
                        for i in range(tail_size % 8):
                            temp_exp[tail_size - tail_size % 8 + i] /= sum_value
                        for i in range(tail_size):
                            out[head_idx][i] = temp_exp[i]
                    else:  # not last dim
                        temp_exp = tensor(dtype=float32, shape=[shape[self.axis]] + tail)
                        # vectorized operations across all contiguous memory for relevant axis
                        for g in range(tail_size // 8):
                            tail_idx = spatial(*tail).map(g * 8)
                            max_vec = avx_f32x8_load(~x[head_idx][0][tail_idx])
                            for i in range(axis_size):
                                data_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
                                max_vec = avx_f32x8_max(max_vec, data_vec)
                            sum_exp_vec = avx_f32x8_setzero()
                            for i in range(axis_size):
                                val_vec = avx_f32x8_load(~x[head_idx][i][tail_idx])
                                val_vec = avx_f32x8_subtract(val_vec, max_vec)
                                val_vec = apply_exponent(val_vec)
                                avx_f32x8_store(~temp_exp[i][tail_idx], val_vec)
                                sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
                            for i in range(axis_size):
                                avx_f32x8_store(
                                    ~temp_exp[i][tail_idx],
                                    avx_f32x8_divide(avx_f32x8_load(~temp_exp[i][tail_idx]), sum_exp_vec),
                                )
                                for j in range(8):
                                    tail_end_idx = spatial(*tail).map(g * 8 + j)
                                    out[head_idx][i][tail_end_idx] = temp_exp[i][tail_end_idx]
                        # unvectorized operations for the remaining elements
                        max_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
                        for j in range(tail_size % 8):
                            max_arr[j] = 0.0
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                max_arr[j] = prim.max(max_arr[j], x[head_idx][p][last_idx])  # TODO: index
                        sum_exp_arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[tail_size % 8])
                        for j in range(tail_size % 8):
                            sum_exp_arr[j] = 0.0
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                out[head_idx][p][last_idx] = prim.exp(x[head_idx][p][last_idx] - max_arr[j])
                                sum_exp_arr[j] += out[head_idx][p][last_idx]
                        for p in range(axis_size):
                            for j in range(tail_size % 8):
                                last_idx = spatial(*tail).map(tail_size - tail_size % 8 + j)
                                out[head_idx][p][last_idx] = out[head_idx][p][last_idx] / sum_exp_arr[j]

        assert isinstance(softmax_cpu_kernel, hidet.ir.Function)
        ir_module = module.ir_module()
        return ir_module
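For reference, a minimal NumPy oracle (not part of the PR) that mirrors the kernel's three passes — running max, subtract/exponentiate/accumulate, divide — and can be used to spot-check both the last-dim and non-last-dim paths, including the tail_size % 8 remainder handling:

import numpy as np

def softmax_reference(x: np.ndarray, axis: int) -> np.ndarray:
    # Pass 1: max along the reduced axis (avx_f32x8_max + scalar tail).
    m = x.max(axis=axis, keepdims=True)
    # Pass 2: subtract the max, exponentiate, accumulate the sum
    # (avx_f32x8_subtract + apply_exponent + avx_f32x8_add).
    e = np.exp(x - m)
    s = e.sum(axis=axis, keepdims=True)
    # Pass 3: normalize (avx_f32x8_divide + scalar tail).
    return e / s

x = np.random.randn(3, 5, 17).astype(np.float32)  # 17 exercises the % 8 tail
for axis in (-1, 1):  # last-dim path and the strided non-last-dim path
    y = softmax_reference(x, axis)
    assert np.allclose(y.sum(axis=axis), 1.0, atol=1e-5)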
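And a hedged end-to-end sketch of exercising the new schedule: hidet.randn, hidet.ops.softmax, and Tensor.numpy are assumed here from hidet's public API, so exact names and signatures may differ.

import hidet

# Assumed API: hidet.randn(shape, dtype=..., device=...) and hidet.ops.softmax.
x = hidet.randn([2, 3, 16], dtype='float32', device='cpu')  # float32 takes the AVX path
y = hidet.ops.softmax(x, axis=-1)
print(y.numpy().sum(axis=-1))  # each slice along the axis should sum to ~1.0

# Non-float32 inputs make implement_cpu return NotImplemented and fall back
# to the auto-scheduler, per the diff above.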
Review comment: Why should we change this to False? 🤔