Skip to content

Commit

Permalink
avoid changing function attributes from outside
Browse files Browse the repository at this point in the history
  • Loading branch information
BolinSNLHM committed Nov 17, 2023
1 parent 656bbd0 commit ebcc78f
Showing 1 changed file with 17 additions and 32 deletions.
49 changes: 17 additions & 32 deletions python/hidet/graph/ops/matmul/matmul_f32_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from hidet.ir.library import tune
from hidet.graph.operator import Operator, Tensor
from hidet.graph.ops.utils import broadcast_indices
from hidet.lang import attrs


class MatmulF32Taskx86(Task):
Expand Down Expand Up @@ -123,12 +124,11 @@ def schedule_matmulf32_x86(self, MC=2016, NC=384, KC=560, ways=(1, 4, 2, 1)) ->

# Zero-initialise the first `size` entries of the `sense` and `arrived`
# int32 buffers.
# NOTE(review): presumably these are the per-barrier state arrays later
# passed to thrcomm_barrier -- confirm against the call sites.
@hidet.script
def init_thr(sense: ~int32, arrived: ~int32, size: int32):
    # Declare the function kind from inside the script body (via
    # hidet.lang.attrs) instead of assigning `.kind` on the function object
    # after it has been defined.
    attrs.func_kind = 'cpu_internal'
    for i in range(size):
        sense[i] = 0
        arrived[i] = 0

init_thr.kind = "cpu_internal"

# Helpers
packed_a_type = tensor_type('float32', layout=row_major(MC // MR, 1) * column_major(MR, KC))
packed_b_type = tensor_type('float32', layout=row_major(1, NC // NR) * row_major(KC, NR))
Expand All @@ -146,6 +146,7 @@ def init_thr(sense: ~int32, arrived: ~int32, size: int32):

@hidet.script
def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~int32, end: ~int32):
attrs.func_kind = "cpu_internal"
if n_way == 1:
start[0] = 0
end[0] = n
Expand Down Expand Up @@ -186,20 +187,18 @@ def thread_range_sub(n_way: int32, work_id: int32, n: int32, bf: int32, start: ~
end[0] += n_bf_left
end[0] = min(end[0], all_end)

thread_range_sub.kind = "cpu_internal"

# Write one worker's loop bounds through the three output pointers:
#   start[0] <- work_id, end[0] <- n, inc[0] <- n_way.
# The `bf` parameter is accepted but not read in this body.
# NOTE(review): presumably the caller then iterates
# `start, start + inc, ... < end`, i.e. iterations are interleaved across
# `n_way` workers -- confirm against the loops that consume these values.
@hidet.script
def thread_range_jrir(
    work_id: int32, n_way: int32, n: int32, bf: int32, start: ~int32, end: ~int32, inc: ~int32
):
    # Function kind is declared inside the script body via hidet.lang.attrs.
    attrs.func_kind = "cpu_internal"
    start[0] = work_id
    end[0] = n
    inc[0] = n_way

thread_range_jrir.kind = "cpu_internal"

@hidet.script
def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32:
attrs.func_kind = 'cpu_internal'
dim_left_now = dim - i
b_now = -1
if dim_left_now <= b_alg:
Expand All @@ -209,23 +208,20 @@ def determine_blocksize_f_sub(i: int32, dim: int32, b_alg: int32) -> int32:
assert b_now >= 0
return b_now

determine_blocksize_f_sub.kind = "cpu_internal"

# Return True unless `i` is the last iteration index (`n_iter - 1`) and
# `n_left` is nonzero; equivalently, True when `i` is not the last index or
# when `n_left == 0`.
# NOTE(review): `n_left` looks like the size of a partial trailing block --
# confirm with the callers.
@hidet.script
def not_edge(i: int32, n_iter: int32, n_left: int32) -> bool:
    # Function kind is declared inside the script body via hidet.lang.attrs.
    attrs.func_kind = 'cpu_internal'
    return i != n_iter - 1 or n_left == 0

not_edge.kind = 'cpu_internal'

# Combine the two work ids into a single flat index:
#   work_id_loop5 * loop3_nways + work_id_loop3.
# `loop3_nways` is not a parameter; it is read from the enclosing scope when
# the script is processed.
# NOTE(review): presumably this selects a per-group packed-A buffer --
# confirm against where the result is used.
@hidet.script
def packa_index(work_id_loop5: int32, work_id_loop3: int32) -> int32:
    # Function kind is declared inside the script body via hidet.lang.attrs.
    attrs.func_kind = 'cpu_internal'
    return work_id_loop5 * loop3_nways + work_id_loop3

packa_index.kind = 'cpu_internal'

# Thread barrier
@hidet.script
def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_threads: int32):
attrs.func_kind = 'cpu_internal'
if num_threads == 1:
return
orig_sense = cpu_atomic_load_n(barrier_sense, 0) # _ATOMIC_RELAXED
Expand All @@ -240,8 +236,6 @@ def thrcomm_barrier(barrier_sense: ~int32, barrier_threads_arrived: ~int32, num_
while cpu_atomic_load_n(barrier_sense, 2) == orig_sense: # _ATOMIC_ACQUIRE
pass

thrcomm_barrier.kind = 'cpu_internal'

@hidet.script
def micro_kernel(
a: packed_a_type,
Expand All @@ -252,6 +246,7 @@ def micro_kernel(
nsize: int32,
is_first: bool,
):
attrs.func_kind = 'cpu_internal'
c = as_tensor_pointer(c_ptr, dtype=float32, shape=[msize, nsize])
c0 = avx_f32x8_load(~c[0, 0])
c08 = avx_f32x8_load(~c[0, 8])
Expand Down Expand Up @@ -343,11 +338,6 @@ def micro_kernel(
packed_a_individual_height = min(MC, (m_size + MR - 1) // MR * MR)
packed_a_total_height = packed_a_individual_height * packed_a_buffers_needed

# packed_a_width = KC
# if packed_a_width > k_size:
# packed_a_width = k_size
# # pad this to be able to use the aligned version of the avx store
# packed_a_width = (packed_a_width + 8 - 1) // 8 * 8
packed_a_width = min(KC, (k_size + 8 - 1) // 8 * 8)

packed_a_total_size = packed_a_total_height * packed_a_width
Expand All @@ -369,6 +359,7 @@ def gemm_pack_a(
packed_a_buf: ~float32,
work_id_packa: int32,
):
attrs.func_kind = 'cpu_internal'
packed_a_tensor = as_tensor_pointer(
packed_a_buf,
float32,
Expand Down Expand Up @@ -463,6 +454,7 @@ def gemm_pack_b(
packed_b_buf: ~float32,
work_id_packb: int32,
):
attrs.func_kind = 'cpu_internal'
npanels_full_b = loop4_partition_b_width // NR
npanels_b_remainder = loop4_partition_b_width % NR

Expand Down Expand Up @@ -569,10 +561,6 @@ def gemm_pack_b(
packed_b_remaining_buf_curr += 1
zero_fill_col += 1

gemm_pack_b.kind = "cpu_internal"
gemm_pack_a.kind = "cpu_internal"
micro_kernel.kind = "cpu_internal"

@hidet.script
def gemm_macro(
packed_a: ~float32,
Expand All @@ -589,6 +577,7 @@ def gemm_macro(
work_id_macro: int32,
is_first: bool,
):
attrs.func_kind = 'cpu_internal'
comm_id_1st_loop = comm_id_macro % loop1_nthreads
work_id_1st_loop = comm_id_1st_loop // (loop1_nthreads // loop1_nways)

Expand Down Expand Up @@ -650,8 +639,6 @@ def gemm_macro(
i += ir_inc
j += jr_inc

gemm_macro.kind = "cpu_internal"

@hidet.script
def gemm_3rd_loop(
a: float32[m_size, k_size],
Expand All @@ -666,6 +653,7 @@ def gemm_3rd_loop(
is_first: bool,
work_id_5th_loop: int32,
):
attrs.func_kind = 'cpu_internal'
comm_id_macro = comm_id_3rd_loop % macro_nthreads
work_id_macro = comm_id_macro // (macro_nthreads // macro_nways)
work_id_packa = comm_id_macro
Expand Down Expand Up @@ -727,8 +715,6 @@ def gemm_3rd_loop(
)
ii += b_alg_loop3

gemm_3rd_loop.kind = "cpu_internal"

@hidet.script
def gemm_4th_loop(
a: float32[m_size, k_size],
Expand All @@ -739,6 +725,7 @@ def gemm_4th_loop(
comm_id_4th_loop: int32,
work_id_5th_loop: int32,
):
attrs.func_kind = 'cpu_internal'
i_loop4 = 0

comm_id_3rd_loop = comm_id_4th_loop % loop3_nthreads
Expand Down Expand Up @@ -805,8 +792,6 @@ def gemm_4th_loop(

i_loop4 += b_alg_loop4

gemm_4th_loop.kind = "cpu_internal"

@hidet.script
def gemm_5th_loop(
a: float32[m_size, k_size],
Expand All @@ -815,6 +800,7 @@ def gemm_5th_loop(
work_id_5th_loop: int32,
comm_id_5th_loop: int32,
):
attrs.func_kind = 'cpu_internal'
comm_id_4th_loop = comm_id_5th_loop % loop4_nthreads

loop5_my_start = -1
Expand All @@ -839,13 +825,12 @@ def gemm_5th_loop(
)
loop5_iter += b_alg_loop5

gemm_5th_loop.kind = 'cpu_internal'

################### Start of the main kernel ###################
@hidet.script
def matmul_kernel_x86_v3(
a: float32[m_size, k_size], b: float32[k_size, n_size], c: float32[m_size, n_size]
):
attrs.func_kind = 'cpu_kernel'

init_thr(packa_thrcomm_barrier_sense, packa_thrcomm_threads_arrived, loop3_nways)
init_thr(packb_thrcomm_barrier_sense, packb_thrcomm_barrier_threads_arrived, loop5_nways)
Expand All @@ -860,7 +845,7 @@ def matmul_kernel_x86_v3(
gemm_5th_loop(a, b, c, work_id_5th_loop, comm_id_5th_loop)

assert isinstance(matmul_kernel_x86_v3, hidet.ir.Function)
matmul_kernel_x86_v3.kind = "cpu_kernel"
# matmul_kernel_x86_v3.kind = "cpu_kernel"
ir_module = module.ir_module()
return ir_module

Expand Down

0 comments on commit ebcc78f

Please sign in to comment.