Softmax cpu #323

Closed · wants to merge 5 commits
50 changes: 50 additions & 0 deletions include/hidet/runtime/cpu/avx_helper.h
@@ -0,0 +1,50 @@
#include <immintrin.h>

/* Bit-exact reinterpretation between u32 and f32 vectors through a union,
 * avoiding the strict-aliasing pitfalls of pointer casts. */
static inline __m256
as_v8_f32_u32(__m256i x)
{
union {
__m256i _xi; __m256 _xf;
} val = { ._xi = x };

return val._xf;
}

static inline __m256i
as_v8_u32_f32(__m256 x)
{
union {
__m256i _xi; __m256 _xf;
} val = { ._xf = x };

return val._xi;
}

/*
* p(x) = c7*x^7 + c6*x^6 + c5*x^5 + c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
* = ((c6+c7*x)*x2 + (c4+c5*x))*x4 + ((c2+c3*x)*x2 + (c0+c1*x))
*/

#define POLY_EVAL_7(x, c0, c1, c2, c3, c4, c5, c6, c7) ({ \
__typeof(x) x2 = x * x; \
__typeof(x) x4 = x2 * x2; \
__typeof(x) q = mul_add(mul_add(mul_add(c7, x, c6), \
x2, \
mul_add(c5, x, c4)), \
x4, \
mul_add(mul_add(c3, x, c2), \
x2, \
mul_add(c1, x, c0))); \
q; \
})

/* Type-generic fused multiply-add: computes x*y + z, dispatched on the type of x. */
#define mul_add(x, y, z) \
_Generic((x), \
float : _mm_fmadd_ss, \
double : _mm_fmadd_sd, \
__m128 : _mm_fmadd_ps, \
__m128d: _mm_fmadd_pd, \
__m256 : _mm256_fmadd_ps, \
__m256d: _mm256_fmadd_pd, \
__m512 : _mm512_fmadd_ps, \
__m512d: _mm512_fmadd_pd)((x), (y), (z))
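
A minimal usage sketch (not part of this PR) of how `POLY_EVAL_7` composes with `mul_add`. It assumes GCC/Clang (statement expressions, vector `*`) compiled with `-mavx2 -mfma`; `poly_demo` and the coefficient values are illustrative — these are the plain Taylor terms of exp, not tuned minimax constants:

```c
#include <immintrin.h>
#include <hidet/runtime/cpu/avx_helper.h>

/* evaluate a degree-7 polynomial at eight floats at once */
static __m256 poly_demo(__m256 x)
{
    __m256 c0 = _mm256_set1_ps(1.0f),          c1 = _mm256_set1_ps(1.0f);
    __m256 c2 = _mm256_set1_ps(0.5f),          c3 = _mm256_set1_ps(1.0f / 6.0f);
    __m256 c4 = _mm256_set1_ps(1.0f / 24.0f),  c5 = _mm256_set1_ps(1.0f / 120.0f);
    __m256 c6 = _mm256_set1_ps(1.0f / 720.0f), c7 = _mm256_set1_ps(1.0f / 5040.0f);
    /* Estrin scheme: fma levels over x2/x4 instead of 7 dependent Horner steps */
    return POLY_EVAL_7(x, c0, c1, c2, c3, c4, c5, c6, c7);
}
```

With these coefficients the result approximates exp(x) near zero, which is what the softmax kernel below ultimately wants from this header.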
7 changes: 5 additions & 2 deletions python/hidet/backend/codegen.py
@@ -566,10 +566,11 @@ def visit_DataType(self, t: DataType):
'complex128': 'complex128_t',
'float32x4': '__m128',
'float32x8': '__m256',
'uint32x8': '__m256i'
}

self.require_complex = self.require_complex or t.name in ['complex64', 'complex128']
self.require_immintrin = self.require_immintrin or t.name in ['float32x4', 'float32x8']
self.require_immintrin = self.require_immintrin or t.name in ['float32x4', 'float32x8', 'uint32x8']
self.require_bf16 = self.require_bf16 or t.name == 'bfloat16'
self.require_fp16 = self.require_fp16 or t.name == 'float16'
self.require_tf32 = self.require_tf32 or t.name == 'tfloat32'
@@ -625,6 +626,7 @@ def require_headers(self) -> Doc:
doc += Text('#include <stdint.h>') + NewLine()
if self.require_immintrin:
doc += Text('#include <immintrin.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/avx_helper.h>') + NewLine()
if self.require_fp16:
doc += Text('#include <cuda_fp16.h>') + NewLine()
if self.require_bf16:
@@ -704,7 +706,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <stdint.h>') + NewLine()
doc += Text('#include <math.h>') + NewLine()
if self.require_immintrin:
doc += Text('#include <immintrin.h>')
doc += Text('#include <immintrin.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/avx_helper.h>') + NewLine()
doc += Text('#include <hidet/runtime/symbols.h>') + NewLine()
doc += Text('#include <hidet/runtime/memory_planner.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/context.h>') + NewLine()
5 changes: 5 additions & 0 deletions python/hidet/driver.py
@@ -104,6 +104,7 @@ def build_task(task: Task, target_device='cuda', load=True) -> Optional[Compiled
f.write(hidet.__version__)

# implement task
# TODO: Look into this and get it working on CPU!
ir_module = task.implement(target=target_device, working_dir=task_dir)

# compile ir module
@@ -113,6 +114,7 @@ def build_task(task: Task, target_device='cuda', load=True) -> Optional[Compiled
save_ir=option.get_option('save_lower_ir'),
load=False,
use_hash_dir=False,
target_device=target_device,
)
if load:
compiled_module = load_compiled_module(task_dir)
@@ -175,13 +177,16 @@ def build_ir_module(
save_ir: bool = True,
load: bool = True,
use_hash_dir: bool = True,
target_device: str = "cpu",
) -> Optional[CompiledModule]:
if use_hash_dir:
hash_dir = sha256(str(ir_module).encode()).hexdigest()[:16]
output_dir = os.path.join(output_dir, hash_dir)

src_path = (
os.path.join(output_dir, 'source.cu') if hidet.cuda.available() else os.path.join(output_dir, 'source.cc')
# TODO: revisit whether to key this on hidet.cuda.available() or on the requested target_device:
# os.path.join(output_dir, 'source.cu') if target_device == "cuda" else os.path.join(output_dir, "source.cc")
)
lib_path = os.path.join(output_dir, 'lib.so')

65 changes: 55 additions & 10 deletions python/hidet/graph/ops/definitions/normalize/norm.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Union
from hidet.ir import IRModule
from hidet.ir.primitives import active_mask, shfl_down_sync
from hidet.ir.compute import reduce
@@ -105,6 +105,35 @@ def allow_prologue(self) -> bool:
def allow_epilogue(self) -> bool:
return False

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
import hidet
# return NotImplemented
x, y = self.inputs[0], self.outputs[0]
input_shape: List[int] = list(x.const_shape)
dims = self.dims

spatial_shape = [v for i, v in enumerate(input_shape) if i not in dims]
reduce_shape = [input_shape[i] for i in dims]
dim_zeros = [0] * len(dims)

reduce_extent = prod(reduce_shape) # product of all elems

accumulate_dtype = data_type(self.attrs['accumulate_dtype']) # float32

with hidet.script_module() as module:
@hidet.script
def welford_update_mean_var(
mean: TensorType(dtype=accumulate_dtype, shape=[1]),
m2: TensorType(dtype=accumulate_dtype, shape=[1]),
count: TensorType(dtype=i32, shape=[1]),
a: f32,
):
# WIP sketch: Welford's online update, folding sample a into the running mean/M2
count[0] = count[0] + 1
delta = a - mean[0]
mean[0] = mean[0] + delta / cast(count[0], f32)
m2[0] = m2[0] + delta * (a - mean[0])

@hidet.script
def norm_kernel(x: f32[x.const_shape], y: f32[y.const_shape]):
# TODO: implement with welford_update_mean_var; parallelize and try various vector widths for loads
pass

ir_module = module.ir_module()
return ir_module


def implement_cuda(self, working_dir: str) -> IRModule:
import hidet
import math
@@ -117,8 +146,7 @@ def implement_cuda(self, working_dir: str) -> IRModule:
reduce_shape = [input_shape[i] for i in dims]
dim_zeros = [0] * len(dims)

reduce_extent = prod(reduce_shape) # product of all elems
warp_size = 32
block_size = min(max(warp_size, reduce_extent), 1024)
block_size = math.ceil(block_size / warp_size) * warp_size
@@ -133,18 +161,34 @@
used_smem_bytes_per_block = shm_count

stages = math.ceil(math.log(block_size) / math.log(warp_size))
print(x, y)
print(input_shape)
print(dims)
print(spatial_shape)
print(reduce_shape)
print(dim_zeros)
print(reduce_extent)
print(block_size)
print(repeat_reduction)
print(task_layout)
print(grid_size)
print(accumulate_dtype)
print(shm_count)
print(used_smem_bytes_per_block)
print(stages)
# assert False
assert stages <= 2

with hidet.script_module() as module:

@hidet.script
def welford_combine(
mean_a: TensorType(dtype=accumulate_dtype, shape=[1]),
m2_a: TensorType(dtype=accumulate_dtype, shape=[1]),
count_a: TensorType(dtype=i32, shape=[1]),
mean_b: TensorType(dtype=accumulate_dtype, shape=[1]),
m2_b: TensorType(dtype=accumulate_dtype, shape=[1]),
count_b: TensorType(dtype=i32, shape=[1]),
):
count = count_a[0] + count_b[0]
if count == 0:
@@ -153,7 +197,8 @@ def welford_combine(

mean_a[0] = mean_a[0] + delta * cast(count_b[0], f32) / cast(count, f32)
m2_a[0] = (
m2_a[0] + m2_b[0] + delta * delta * cast(count_a[0], f32) * cast(count_b[0], f32) / cast(count, f32)
)
count_a[0] = count

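For reference, `welford_combine` implements Chan et al.'s parallel-variance merge. A self-contained C sketch of the same arithmetic (illustrative, not part of the PR; `welford_t` and `welford_merge` are made-up names):

```c
/* running statistics of one partition: mean, sum of squared deviations (M2), count */
typedef struct { float mean; float m2; int count; } welford_t;

/* merge partition b into a; mirrors welford_combine above */
static void welford_merge(welford_t *a, const welford_t *b)
{
    int count = a->count + b->count;
    if (count == 0)
        return;
    float delta = b->mean - a->mean;
    a->mean += delta * (float)b->count / (float)count;
    a->m2   += b->m2 + delta * delta * (float)a->count * (float)b->count / (float)count;
    a->count = count;
}
```

The delta²·n_a·n_b/n correction term is what lets per-warp partial means and M2s be folded together without a second pass over the data.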
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/definitions/reduce/reduce.py
@@ -29,7 +29,7 @@ def __init__(
y_shape.append(1)
else:
y_shape.append(x.shape[i])

print(y_shape)
def fcompute(*indices):
def reduce_fcompute(*reduce_indices):
x_indices = []
127 changes: 127 additions & 0 deletions python/hidet/graph/ops/definitions/softmax.py
@@ -13,6 +13,9 @@
from hidet.ir import primitives as prim
from hidet.ir.expr import is_constant
from .utils import Task, TensorNode, compute, reduce
from typing import List, Union
from hidet.ir.dtypes import float32
from hidet.graph.ops.definitions.utils import tune


class SoftmaxTask(Task):
@@ -66,3 +69,127 @@ def implement_cuda(self, working_dir: str) -> IRModule:
return NotImplemented # use auto-scheduler

return softmax_cuda_schedule(self)

def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
if not all(is_constant(dim) for dim in self.inputs[0].shape) or (self.axis != len(self.x_shape) - 1 and
self.axis != -1): # the AVX path needs the softmax axis to be the innermost (row-major contiguous) dim
return NotImplemented # use auto-scheduler
return self.schedule_softmax_cpu()
# return tune.extract_ir_modules(self.schedule_softmax_cpu)

# @tune.space(2, 'nthreads', [4, 8, 16, 32, 64, 96])
# @tune.space(1, 'nthreads', [8, 16])
def schedule_softmax_cpu(self, nthreads=4) -> IRModule:
import hidet
from hidet.ir.primitives.cpu.avx import avx_f32x8_subtract, avx_f32x8_load, avx_f32x8_setzero, avx_f32x8_store,\
avx_f32x8_add, avx_f32x8_max, avx_f32x8_permute, avx_f32x8_permute_2f128, avx_f32x8_extract_last,\
avx_f32x8_extract_half, avx_f32x4_add, avx_f32x4_hadd, avx_f32x4_extract_last, avx_f32x8_broadcast,\
avx_f32x8_divide, avx_f32x8_to_u32x8, avx_u32x8_to_f32x8
from hidet.ir.dtypes import float32x8
from hidet.lang.constructs.type import tensor
from hidet.ir.stmt import DeclareScope
from hidet.lang import grid
row_size, col_size = self.x_shape[-2], self.x_shape[-1]

with hidet.script_module() as module:
@hidet.script
def find_max(max_vec: float32x8) -> float32:
y = avx_f32x8_permute_2f128(max_vec, max_vec, 1) # swap first and last 4
m1 = avx_f32x8_max(max_vec, y)
m2 = avx_f32x8_permute(m1, 0b01001110) # reshuffle to 2 elems per vec and compare
m3 = avx_f32x8_max(m1, m2)
m4 = avx_f32x8_permute(m3, 0b10110001) # reshuffle to 1 elem per vec and compare
m = avx_f32x8_max(m3, m4) # max val
return avx_f32x8_extract_last(m)

@hidet.script
def find_sum(x: float32x8) -> float32:
sum_vec = avx_f32x4_add(avx_f32x8_extract_half(x, 0b0), avx_f32x8_extract_half(x, 0b1))
sum_vec = avx_f32x4_hadd(sum_vec, sum_vec)
sum_vec = avx_f32x4_hadd(sum_vec, sum_vec)
return avx_f32x4_extract_last(sum_vec)

# WIP: vectorized exp via range reduction (see the C sketch at the end of this file)
# @hidet.script
# def avx_exp(x: float32x8) -> float32x8:
# vx = avx_f32x8_to_u32x8(x)
# vx = vx & MASK
# cond = vx > ARG_MAX # these comparisons should also be AVX ops
# z = x * TBL_LN2
# dn = z + EXP_HUGE
# r1 = x - (dn * LN2_TBL_H)
# r2 = dn * LN2_TBL_T
# r = r1 - r2
# m = (n + EXPF_BIAS) << 23
# poly = POLY_EVAL_7(...) # TODO: expose the C macro POLY_EVAL_7 to hidet script
# result = poly * avx_u32x8_to_f32x8(m)
#
# # if cond is not satisfied, fall back to regular scalar expf
# return result

@hidet.script
def softmax_cpu(x: float32[row_size, col_size], out: float32[row_size, col_size]):
para = 'p' + str(nthreads) # e.g. 'p4': parallelize the row loop across nthreads threads
for i in grid(row_size, attrs=para):
# find the max of row i
max_val = x[i, 0]
if col_size >= 8:
max_vec = avx_f32x8_load(x + i * col_size) # safe: col_size >= 8 here
for j in range(col_size//8):
data_vec = avx_f32x8_load(x + i * col_size + j * 8)
max_vec = avx_f32x8_max(max_vec, data_vec)
max_val = find_max(max_vec)
for j in range(col_size % 8): # scalar tail: the last col_size % 8 elements
max_val = max_val if max_val > x[i, col_size//8 * 8 + j] else x[i, col_size//8 * 8 + j]

# subtract max, take exp and find exp sum
sum_value = 0.0
if col_size >= 8:
sum_exp_vec = avx_f32x8_setzero()
max_vec = avx_f32x8_broadcast(~max_val)
for j in range(col_size//8):
val_vec = avx_f32x8_load(x + i * col_size + j * 8)
val_vec = avx_f32x8_subtract(val_vec, max_vec)
# apply exp elementwise; scalar store/exp/load round-trip until a vectorized avx_exp lands
arr = tensor(scope=DeclareScope.Default, dtype=float32, shape=[8])
avx_f32x8_store(arr, val_vec)
for k in range(8):
arr[k] = prim.exp(arr[k])
val_vec = avx_f32x8_load(arr)
# val_vec = avx_exp(val_vec)
avx_f32x8_store(out + i * col_size + j * 8, val_vec)
sum_exp_vec = avx_f32x8_add(sum_exp_vec, val_vec)
sum_value = find_sum(sum_exp_vec)
for j in range(col_size % 8): # scalar tail
out[i, col_size//8 * 8 + j] = prim.exp(x[i, col_size//8 * 8 + j] - max_val)
sum_value += out[i, col_size//8 * 8 + j]

# divide by exp sum
if col_size >= 8:
# divide
sum_vec8 = avx_f32x8_broadcast(~sum_value)
for j in range(col_size//8):
avx_f32x8_store(out + i * col_size + j * 8,
avx_f32x8_divide(avx_f32x8_load(out + i * col_size + j * 8), sum_vec8))
for j in range(col_size % 8): # scalar tail
out[i, col_size//8 * 8 + j] = out[i, col_size//8 * 8 + j] / sum_value

softmax_cpu.kind = "cpu_kernel"
find_max.kind = "cpu_internal"
find_sum.kind = "cpu_internal"
# avx_exp.kind = "cpu_internal"
# avx_exp_dumb.kind = "cpu_internal"
ir_module = module.ir_module()
return ir_module
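
A scalar C reference of the same row-softmax, handy for checking the AVX kernel (including the col_size % 8 tails); `softmax_ref` is an illustrative name, not part of the PR:

```c
#include <math.h>

/* numerically stable softmax over each row of a rows x cols matrix */
static void softmax_ref(const float *x, float *out, int rows, int cols)
{
    for (int i = 0; i < rows; i++) {
        const float *xi = x + (long)i * cols;
        float *oi = out + (long)i * cols;
        float mx = xi[0];
        for (int j = 1; j < cols; j++)       /* row max */
            if (xi[j] > mx) mx = xi[j];
        float sum = 0.0f;
        for (int j = 0; j < cols; j++) {     /* exp(x - max) and running sum */
            oi[j] = expf(xi[j] - mx);
            sum += oi[j];
        }
        for (int j = 0; j < cols; j++)       /* normalize */
            oi[j] /= sum;
    }
}
```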

# C-intrinsics reference for find_sum (horizontal sum of a __m256):
# sum = _mm_add_ps(_mm256_extractf128_ps(vector, 0), _mm256_extractf128_ps(vector, 1));
# sum = _mm_hadd_ps(sum, sum);
# sum = _mm_hadd_ps(sum, sum);
# return _mm_cvtss_f32(sum);


# C-intrinsics reference for find_max (horizontal max of a __m256):
# __m256 y = _mm256_permute2f128_ps(x, x, 1); // 8 5 3 6 8 5 3 6
# __m256 m1 = _mm256_max_ps(x, y); // 8 7 3 6 8 5 3 6
# __m256 m2 = _mm256_permute_ps(m1, 0b01001110); // swap 2, 3 and 0, 1: 3 6 8 7 8 5 3 6
# __m256 m3 = _mm256_max_ps(m1, m2); // 8 7 8 7 8 5 3 6
# __m256 m4 = _mm256_permute_ps(m3, 0b10110001); // 7 8 8 7 8 5 3 6
# __m256 m = _mm256_max_ps(m3, m4); // max elem available in all lanes of m
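
On the commented-out `avx_exp` above: the standard recipe is range reduction n = round(x/ln2), r = x − n·ln2, then exp(x) = 2ⁿ·p(r). Below is a C sketch of that recipe with AVX2/FMA intrinsics, reusing the PR's `as_v8_f32_u32` helper for the 2ⁿ bit-trick. It is illustrative only: plain Taylor coefficients stand in for tuned minimax ones, and it omits the overflow/NaN guard that the ARG_MAX check above is meant to handle:

```c
#include <immintrin.h>
#include <hidet/runtime/cpu/avx_helper.h>

/* exp(x) ~= 2^n * p(r), with n = round(x/ln2), r = x - n*ln2, |r| <= ln2/2 */
static inline __m256 avx2_expf_sketch(__m256 x)
{
    const __m256 log2e = _mm256_set1_ps(1.44269504f);  /* 1/ln(2) */
    const __m256 ln2   = _mm256_set1_ps(0.69314718f);

    __m256 n = _mm256_round_ps(_mm256_mul_ps(x, log2e),
                               _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256 r = _mm256_fnmadd_ps(n, ln2, x);            /* r = x - n*ln2 */

    /* p(r): degree-5 Taylor polynomial for exp on [-ln2/2, ln2/2] */
    __m256 p = _mm256_set1_ps(1.0f / 120.0f);
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(1.0f / 24.0f));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(1.0f / 6.0f));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(0.5f));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(1.0f));
    p = _mm256_fmadd_ps(p, r, _mm256_set1_ps(1.0f));

    /* 2^n: put (n + 127) into the float exponent field and bit-cast back */
    __m256i biased = _mm256_add_epi32(_mm256_cvtps_epi32(n), _mm256_set1_epi32(127));
    return _mm256_mul_ps(p, as_v8_f32_u32(_mm256_slli_epi32(biased, 23)));
}
```

Swapping something like this in for the scalar store/exp/load round-trip in `softmax_cpu` would keep each row entirely in registers.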
11 changes: 0 additions & 11 deletions python/hidet/graph/ops/schedules/cpu/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion python/hidet/graph/ops/schedules/cuda/reduce.py
@@ -123,7 +123,6 @@ def cuda_schedule_reduce_by_default(task: ReduceTask) -> IRModule:

x_dtype = task.inputs[0].ttype.dtype
accumulate_dtype = task.attrs['accumulate_dtype']

with FunctionBuilder(
name=task.name + '_grid', kind='cuda_kernel', grid_dim=grid_size, block_dim=block_size, label='reduce schedule'
) as fb: