[Operator] Adding CPU support for matrix multiplication #250

Closed
wants to merge 105 commits
Changes from all commits

Commits (105)
8d0aae4
now remember to backup...
BolinSNLHM Apr 25, 2023
8f8bcca
...
BolinSNLHM Apr 25, 2023
3b8f9c1
change 4x4 kernel to avx intrinsics
BolinSNLHM Apr 26, 2023
7536d30
added some type info
BolinSNLHM Apr 26, 2023
a4ef3e9
commit before changing the compilation command
BolinSNLHM Apr 26, 2023
088970c
now can compile with avx intrinsics
BolinSNLHM Apr 26, 2023
0d406a2
added 32x8 primitives for CPU
BolinSNLHM Apr 29, 2023
9916989
added O3 compiler option
BolinSNLHM Apr 29, 2023
edc5e67
...
BolinSNLHM Apr 29, 2023
9a0aa6a
added more primitives
BolinSNLHM Apr 29, 2023
db2c683
...
BolinSNLHM Apr 29, 2023
33b4451
slight modification of opt88 file
BolinSNLHM Apr 29, 2023
1ad9e7d
added 32x8 imports where necessary
BolinSNLHM Apr 29, 2023
46f1d63
modified two scratch files
BolinSNLHM Apr 29, 2023
3d69a5f
five2: quite some speedup compared to how little has been down in add…
BolinSNLHM Apr 29, 2023
1302698
..... fixed dumb error
BolinSNLHM Apr 29, 2023
f799531
..
BolinSNLHM Apr 29, 2023
b93b408
8x8 kernel: efficiency improved again
BolinSNLHM Apr 29, 2023
3bbc4bd
reordering: some improvements
BolinSNLHM Apr 30, 2023
b053dc5
reordering loop gets a slight boost
BolinSNLHM Apr 30, 2023
9ffa73f
working on packing: back up midway
BolinSNLHM Apr 30, 2023
b29d61a
commented out redundant codes
BolinSNLHM Apr 30, 2023
613e3e2
a version of packing that does not yield much benefit...
BolinSNLHM Apr 30, 2023
a1a6c5e
...
BolinSNLHM Apr 30, 2023
035dca8
fix conflicts
BolinSNLHM Apr 30, 2023
3be1845
resolved conflict
BolinSNLHM Apr 30, 2023
0372647
......
BolinSNLHM Apr 30, 2023
6c67af0
working on packing B: some bugs for now:
BolinSNLHM Apr 30, 2023
fb3ca73
still hasn't figured out packing of B... move to using pointer?
BolinSNLHM May 1, 2023
4982ddf
first version of packing works?
BolinSNLHM May 1, 2023
894ee8a
really strange behavior regarding those definitions...
BolinSNLHM May 1, 2023
47980ce
seems like there's benefit in setting MC large
BolinSNLHM May 2, 2023
578b925
seems like aligning didn't do much
BolinSNLHM May 2, 2023
d0ba954
performance still not satisfactory yet; try to handle general case fo…
BolinSNLHM May 2, 2023
01a33e3
working on general: now at least in the work-in-progress the nice siz…
BolinSNLHM May 2, 2023
9d441a6
finally support for arbitrary size...
BolinSNLHM May 3, 2023
79c1c09
...
BolinSNLHM May 3, 2023
4da2612
working on refactoring; backup
BolinSNLHM May 3, 2023
e9132ff
first version of refactoring macrokernel
BolinSNLHM May 3, 2023
4d487be
what... segfault for only one case after refactoring
BolinSNLHM May 3, 2023
d7ba0f1
finished refactoring macro-kernel
BolinSNLHM May 3, 2023
c8faf37
refactored macro-kernel
BolinSNLHM May 3, 2023
273c0fd
why is it slower after refactoring??
BolinSNLHM May 3, 2023
5fe11a3
finished refactoring out the micro-kernel
BolinSNLHM May 3, 2023
a28d300
little details
BolinSNLHM May 3, 2023
12faa70
change MC to 2048
BolinSNLHM May 3, 2023
aec8b5f
...
BolinSNLHM May 3, 2023
c950603
10x8 does not work so well
BolinSNLHM May 3, 2023
a461637
6x16 really makes a difference
BolinSNLHM May 3, 2023
4596028
Merge branch 'hidet-org:main' into main
BolinSNLHM May 4, 2023
5a5f8f9
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM May 4, 2023
d25195a
start working on parallel
BolinSNLHM May 4, 2023
e416720
start workng on parallel
BolinSNLHM May 4, 2023
b9fc9d4
so far the best got...
BolinSNLHM May 4, 2023
e2b34c7
first try... need to experiment more
BolinSNLHM May 4, 2023
b74cbc6
... play with nthreads, go to paper
BolinSNLHM May 4, 2023
66cb61b
nthreads=24 currently promising
BolinSNLHM May 4, 2023
35821ff
stop playing with block sizes for now...
BolinSNLHM May 4, 2023
db3fb2a
exploring parallelizing the third loop
BolinSNLHM May 4, 2023
61a3a44
Merge branch 'hidet-org:main' into main
BolinSNLHM May 6, 2023
8243937
...
BolinSNLHM May 6, 2023
bd872ff
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM May 6, 2023
58f6edc
Merge branch 'main' into bolin
BolinSNLHM May 6, 2023
99828d9
eliminate for loops
BolinSNLHM May 6, 2023
fc926f5
removed that parallelizing 3rd loop: seems like a bad idea for some r…
BolinSNLHM May 6, 2023
58b23ce
strange error; push for backup
BolinSNLHM May 8, 2023
afda10a
Merge branch 'hidet-org:main' into main
BolinSNLHM May 8, 2023
a4f2ca8
Merge branch 'hidet-org:main' into bolin
BolinSNLHM May 8, 2023
cc47d1b
finished debugging; seems like they ran slower than before?
BolinSNLHM May 9, 2023
64ac8a3
worked out the first version of the schedule template; the issue w/ o…
BolinSNLHM May 10, 2023
618b0c1
first benchmark...
BolinSNLHM May 10, 2023
4b33400
trying tvm
BolinSNLHM May 15, 2023
5b3e1e3
moving to the server
BolinSNLHM May 15, 2023
e141cf2
...
BolinSNLHM May 15, 2023
03c5ea2
some more trying files...
BolinSNLHM May 17, 2023
9340268
Merge branch 'hidet-org:main' into main
BolinSNLHM May 17, 2023
8e95190
commit before checking out to main...
BolinSNLHM May 21, 2023
bb03b68
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM May 21, 2023
5b0f01f
Merge branch 'main' into bolin
BolinSNLHM May 21, 2023
61cd1c7
...
BolinSNLHM May 21, 2023
3e4a16c
working on replicating the oneDNN ref impl in hidet script
BolinSNLHM May 21, 2023
7912abd
Merge branch 'hidet-org:main' into main
BolinSNLHM May 22, 2023
18278cd
Merge branch 'hidet-org:main' into main
BolinSNLHM May 23, 2023
aa2cc45
commit b4 pulling for pointer arithmetic
BolinSNLHM May 23, 2023
ef115e5
solving merge conflict
BolinSNLHM May 23, 2023
7d73e8e
.
BolinSNLHM May 23, 2023
529c07a
..
BolinSNLHM May 23, 2023
1cb9bd6
Merge branch 'hidet-org:main' into main
BolinSNLHM May 23, 2023
6fd08a0
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM May 23, 2023
435a401
.
BolinSNLHM May 23, 2023
771c15b
Merge branch 'hidet-org:main' into main
BolinSNLHM May 23, 2023
2d8f8bd
..
BolinSNLHM May 23, 2023
1be5c12
Merge branch 'main' of github.com:BolinSNLHM/hidet into main
BolinSNLHM May 23, 2023
888a285
Merge branch 'main' into bolin
BolinSNLHM May 23, 2023
0b3e45a
.
BolinSNLHM May 24, 2023
54cd1b6
changed codegen to use dynamic
BolinSNLHM May 25, 2023
4424c7d
I should try smaller blocks?
BolinSNLHM May 25, 2023
087eae1
still something wrong with packing with pointer arithmetics...
BolinSNLHM May 26, 2023
ef36d60
.
BolinSNLHM May 26, 2023
322a082
.
BolinSNLHM May 26, 2023
9b46a2d
.
BolinSNLHM May 26, 2023
7a94c2d
deleting
BolinSNLHM May 26, 2023
e3210ab
deleting
BolinSNLHM May 26, 2023
2af5bbf
cleanup
BolinSNLHM May 27, 2023
df0158f
lint
BolinSNLHM May 27, 2023
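
Taken together, the commit history above traces the usual GotoBLAS/BLIS progression for a CPU matmul: AVX micro-kernels (4x4, then 32x8, 8x8, and finally a 6x16 register tile), loop reordering, packing of A and B, cache-level blocking with MC/KC/NC, and OpenMP parallelization of an outer loop. The NumPy sketch below is illustrative only; it is not the PR's Hidet Script implementation, and the block sizes are placeholders rather than the values tuned in these commits.

import numpy as np

MC, KC, NC = 256, 256, 512   # placeholder cache-level block sizes
MR, NR = 6, 16               # micro-kernel tile (the "6x16" the commits found fastest)

def matmul_blocked(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    c = np.zeros((m, n), dtype=a.dtype)
    for jc in range(0, n, NC):
        nc = min(NC, n - jc)
        for pc in range(0, k, KC):
            kc = min(KC, k - pc)
            b_panel = np.ascontiguousarray(b[pc:pc + kc, jc:jc + nc])       # "packing B"
            for ic in range(0, m, MC):
                mc = min(MC, m - ic)
                a_block = np.ascontiguousarray(a[ic:ic + mc, pc:pc + kc])   # "packing A"
                # macro-kernel: sweep MR x NR micro-tiles of this block of C
                for jr in range(0, nc, NR):
                    nr = min(NR, nc - jr)
                    for ir in range(0, mc, MR):
                        mr = min(MR, mc - ir)
                        # micro-kernel: in the PR this is AVX fmadd intrinsics;
                        # here it is a plain dot product for clarity
                        c[ic + ir:ic + ir + mr, jc + jr:jc + jr + nr] += (
                            a_block[ir:ir + mr, :] @ b_panel[:, jr:jr + nr]
                        )
    return c

# quick check against NumPy's reference matmul, with non-multiple sizes
a = np.random.rand(333, 257).astype(np.float32)
b = np.random.rand(257, 411).astype(np.float32)
assert np.allclose(matmul_blocked(a, b), a @ b, atol=1e-3)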
Files changed
1 change: 1 addition & 0 deletions .gitignore
@@ -204,3 +204,4 @@ build-release

# intermediate files
/gallery/**/*.json
/python/opt9.py
8 changes: 6 additions & 2 deletions python/hidet/backend/build.py
@@ -121,13 +121,13 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
# optimize host side code via -O3
'-O3',
# enable openmp support for cpu kernels
'-Xcompiler -fopenmp',
'-Xcompiler -fopenmp,-fPIC,-m64,-mavx2,-march=native,-O3,-funroll-loops,-ffast-math',
# the target PTX and SASS version.
'-gencode arch=compute_{cc},code=sm_{cc}'.format(cc=cc_code),
# allow ptxas (PTX assembler) to output information like register/smem usage.
'--ptxas-options=-v',
# compile into position independent code.
'--compiler-options -fPIC',
# '--compiler-options -fPIC,-m64,-mavx2,-march=native, -O3',
# embed the line information into the binary, allow Nsight Compute to get the source code for profiling.
'-lineinfo',
# ftz=true and prec-div=false for fast math
@@ -184,6 +184,10 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
*['-L{}'.format(library_dir) for library_dir in self.library_dirs],
# apply -O3 optimization.
'-O3',
# support avx intrinsics
'-mavx2',
'-m64',
'-march=native',
# compile into position independent code.
'-fPIC',
# enable OpenMP.
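
The host-compiler flags added here ('-mavx2', '-m64', '-march=native') assume the build machine actually supports AVX2. A minimal, hypothetical guard (Linux-only, and not part of this PR) could gate the flags on the CPU feature list:

def host_has_avx2() -> bool:
    # Hypothetical helper, not in the PR: report whether the build host advertises AVX2.
    try:
        with open('/proc/cpuinfo') as f:
            return 'avx2' in f.read()
    except OSError:
        return False

avx_flags = ['-mavx2', '-m64', '-march=native'] if host_has_avx2() else []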
68 changes: 67 additions & 1 deletion python/hidet/backend/codegen.py
@@ -441,7 +441,9 @@ def visit_ForStmt(self, stmt: ForStmt):
doc += NewLine() + '#pragma unroll'
elif stmt.attr.parallel:
if stmt.attr.parallel_threads:
doc += NewLine() + '#pragma omp parallel for num_threads({})'.format(stmt.attr.parallel_threads)
doc += NewLine() + '#pragma omp parallel for schedule(dynamic) num_threads({})'.format(
stmt.attr.parallel_threads
)
else:
doc += NewLine() + '#pragma omp parallel for'
doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') '
@@ -555,6 +557,8 @@ def visit_DataType(self, t: DataType):
'tfloat32': 'tfloat32_t',
'complex64': 'complex64_t',
'complex128': 'complex128_t',
'float32x4': '__m128',
'float32x8': '__m256',
}
return Text(scalar_type_map[t.name])

@@ -613,6 +617,8 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cuda/complex.h>') + NewLine()
doc += Text('#include <hidet/runtime/cuda/context.h>') + NewLine()

doc += Text('#include <immintrin.h>') + NewLine()

# nvcc use float to 'store' tfloat32 data
doc += Text('typedef float tfloat32_t;') + NewLine()
doc += Text('typedef __nv_bfloat16 bfloat16_t;') + NewLine()
@@ -684,9 +690,69 @@ def require_headers(self) -> Doc:
doc += Text('#include <hidet/runtime/cpu/float16.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/bfloat16.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu/complex.h>') + NewLine()
doc += Text('#include <immintrin.h>')
doc += NewLine()
return doc

def visit_ScalarType(self, t: DataType):
# float16, bfloat16 and tfloat32 are not supported on CPU yet
# https://moocaholic.medium.com/fp64-fp32-fp16-bfloat16-tf32-and-other-members-of-the-zoo-a1ca7897d407
scalar_type_map = {
'bool': 'bool',
'uint8': 'uint8_t',
'uint16': 'uint16_t',
'uint32': 'uint32_t',
'uint64': 'uint64_t',
'int8': 'int8_t',
'int16': 'int16_t',
'int32': 'int32_t',
'int64': 'int64_t',
'float16': 'half',
'float32': 'float',
'float64': 'double',
'bfloat16': 'bfloat16_t',
'tfloat32': 'float',
'float32x4': '__m128',
'float32x8': '__m256',
}
return Text(scalar_type_map[t.name])

def visit_IRModule(self, module: IRModule) -> Doc:
self.ir_module = module
doc = Doc()
# todo: only add necessary headers
doc += Text('#include <stdint.h>') + NewLine()
doc += Text('#include <hidet/runtime/cpu_context.h>') + NewLine()
doc += Text('#include <math.h>') + NewLine()
# float16 and bfloat16 emulation
doc += Text('#include <hidet/cpu/float16.h>') + NewLine()
doc += Text('#include <hidet/cpu/bfloat16.h>') + NewLine()

# Headers for avx intrinsics
doc += Text('#include <immintrin.h>') + NewLine()

if module.task is not None:
doc += '/*' + NewLine()
doc += str(module.task) + NewLine()
doc += '*/' + NewLine()

doc += Text('extern "C" {') + NewLine()

# add namespace to activate data type and function
doc += Text('using float16::Half;') + NewLine()
doc += Text('using bfloat16::BFloat16;') + NewLine()

# use typedef to map half and bfloat16 type
doc += Text('typedef Half half;') + NewLine()
doc += Text('typedef BFloat16 bfloat16_t;') + NewLine()

call_graph = CallGraph(module)
for node in call_graph.reversed_order:
doc += self(node.func) + NewLine()

doc += NewLine() + '}'
return doc

def visit_Function(self, func: Function) -> Doc:
self.namer.clear()

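For reference, a trivial formatting check of the pragma string that the parallel-for branch earlier in this file's diff builds when a loop carries an explicit thread count; 24 matches the nthreads value experimented with in the commit history, and this is not output captured from the codegen itself.

parallel_threads = 24
pragma = '#pragma omp parallel for schedule(dynamic) num_threads({})'.format(parallel_threads)
assert pragma == '#pragma omp parallel for schedule(dynamic) num_threads(24)'
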
2 changes: 1 addition & 1 deletion python/hidet/ffi/runtime_api.py
@@ -21,7 +21,7 @@ class RuntimeAPI:
_register_callback = get_func('register_callback', [c_char_p, c_void_p], None)
_allocate_cuda_storage = get_func('allocate_cuda_storage', [c_uint64], c_uint64)
_free_cuda_storage = get_func('free_cuda_storage', [c_uint64], None)
_reset_symbol_table = get_func('reset_symbol_table', [], None)
# _reset_symbol_table = get_func('reset_symbol_table', [], None)
_get_symbol_value = get_func('get_symbol_value', [c_char_p], c_int32)
_set_symbol_value = get_func('set_symbol_value', [c_char_p, c_int32], None)

3 changes: 3 additions & 0 deletions python/hidet/graph/ops/__init__.py
@@ -46,6 +46,9 @@
from .definitions.fusion import fused_operator
from .definitions.special import barrier

from .definitions.matmul import matmul_x86
from .definitions.matmul import matmul_x86_onednn

from .definitions import utils

from . import schedules
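
With these re-exports, the new CPU operator becomes reachable as hidet.ops.matmul_x86. The sketch below is a hedged usage example, not taken from the PR's tests; it assumes hidet's public hidet.randn API for creating CPU tensors.

import hidet

a = hidet.randn([1024, 1024], dtype='float32', device='cpu')
b = hidet.randn([1024, 1024], dtype='float32', device='cpu')
c = hidet.ops.matmul_x86(a, b)   # CPU matmul operator exposed by this PR
print(c.shape)                   # expected: [1024, 1024]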
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/definitions/__init__.py
@@ -34,7 +34,7 @@
from .conv3d_transpose import conv3d_transpose
from .matmul import batch_matmul, matmul

from .matmul import BatchMatmulOp, MatmulOp
from .matmul import BatchMatmulOp, MatmulOp, Matmulx86Op
from .conv2d import Conv2dOp
from .arithmetic import ErfOp, PowOp, AddOp, SubtractOp, MultiplyOp, DivideOp, WhereOp
from .compare import EqualOp
6 changes: 6 additions & 0 deletions python/hidet/graph/ops/definitions/matmul/__init__.py
@@ -12,3 +12,9 @@
from .matmul import matmul, MatmulOp, MatmulTask
from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask
from . import resolve

from .matmul_f32_x86 import matmul_x86
from .matmul_f32_x86_v2 import matmul_x86_onednn

from .matmul_f32_x86 import MatmulF32Taskx86, Matmulx86Op
from .matmul_f32_x86_v2 import MatmulF32Taskx86OneDNN, MatmulX86OneDNNOp
411 changes: 411 additions & 0 deletions python/hidet/graph/ops/definitions/matmul/matmul_f32_x86.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions python/hidet/ir/dtypes/__init__.py
@@ -15,9 +15,9 @@
from .floats import float16, float32, float64, bfloat16, tfloat32
from .floats import f16, f32, f64, bf16, tf32
from .boolean import boolean
from .vector import float16x2, float32x4, float32x8
from .complex import complex64, complex128
from .vector import float16x2, float32x4
from .vector import f16x2, f32x4
from .vector import f16x2, f32x4, f32x8
from .promotion import promote_type
from .utils import dtype_to_numpy, finfo, iinfo

@@ -39,6 +39,7 @@
'complex64': complex64,
'complex128': complex128,
'float32x4': float32x4,
'float32x8': float32x8,
'float16x2': float16x2,
}

@@ -60,6 +61,7 @@
'c64': complex64,
'c128': complex128,
'f32x4': f32x4,
'f32x8': f32x8,
'f16x2': f16x2,
}

2 changes: 2 additions & 0 deletions python/hidet/ir/dtypes/vector.py
@@ -72,7 +72,9 @@ def max_value(self):


float32x4 = VectorType(float32, 4)
float32x8 = VectorType(float32, 8)
float16x2 = VectorType(float16, 2)

f32x4 = float32x4
f32x8 = float32x8
f16x2 = float16x2
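
A small sanity check of the new dtype registration, assuming the name and alias tables shown in the dtypes/__init__.py hunk above are in effect:

from hidet.ir.dtypes import float32x8, f32x8

assert f32x8 is float32x8              # short alias points at the same dtype object
assert float32x8.name == 'float32x8'   # the name codegen.py maps to __m256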
2 changes: 2 additions & 0 deletions python/hidet/ir/primitives/__init__.py
@@ -24,6 +24,8 @@

# cpu primitive functions
from . import cpu
from .cpu import avx_f32x4_store, avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_setzero
from .cpu import avx_free, avx_malloc

# cuda primitive functions and variables
from . import cuda
4 changes: 4 additions & 0 deletions python/hidet/ir/primitives/cpu/__init__.py
@@ -10,3 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import math

from .avx import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero
from .avx import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero
from .avx import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc
104 changes: 104 additions & 0 deletions python/hidet/ir/primitives/cpu/avx.py
@@ -0,0 +1,104 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from hidet.ir.expr import Expr, Call
from hidet.ir.type import FuncType, VoidType, PointerType
from hidet.ir.primitives.func import register_primitive_function
from hidet.utils import initialize
from hidet.ir.primitives.func import call_primitive_func


@initialize()
def register_primitive_functions():
functions = [
('avx_x86_float32x4_broadcast', '_mm_broadcast_ss', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_fmadd', '_mm_fmadd_ps', FuncType(['float32x4', 'float32x4', 'float32x4'], 'float32x4')),
('avx_x86_float32x4_load', '_mm_loadu_ps', FuncType([PointerType('float32')], 'float32x4')),
('avx_x86_float32x4_store', '_mm_storeu_ps', FuncType([PointerType('float32'), 'float32x4'], VoidType())),
('avx_x86_float32x4_setzero', '_mm_setzero_ps', FuncType([], 'float32x4')),
('avx_x86_float32x8_broadcast', '_mm256_broadcast_ss', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_fmadd', '_mm256_fmadd_ps', FuncType(['float32x8', 'float32x8', 'float32x8'], 'float32x8')),
('avx_x86_float32x8_load', '_mm256_loadu_ps', FuncType([PointerType('float32')], 'float32x8')),
('avx_x86_float32x8_store', '_mm256_storeu_ps', FuncType([PointerType('float32'), 'float32x8'], VoidType())),
('avx_x86_float32x8_setzero', '_mm256_setzero_ps', FuncType([], 'float32x8')),
('avx_x86_malloc', '_mm_malloc', FuncType(['uint64', 'uint64'], PointerType(VoidType()))),
('avx_x86_free', '_mm_free', FuncType([PointerType(VoidType())], VoidType())),
('x86_memset', 'memset', FuncType([PointerType(VoidType()), 'int32', 'uint64'], PointerType(VoidType()))),
(
'x86_memcpy',
'memcpy',
FuncType([PointerType(VoidType()), PointerType(VoidType()), 'uint64'], PointerType(VoidType())),
),
]
for name, codegen_name, func_type in functions:
register_primitive_function(name=name, func_or_type=func_type, codegen_name=codegen_name)


def aligned_alloc(alignment: Union[int, Expr], size: Union[int, Expr]):
return call_primitive_func('aligned_alloc', [alignment, size])


def x86_memcpy(dst: Expr, src: Expr, num: Union[Expr, int]) -> Call:
return call_primitive_func('x86_memcpy', [dst, src, num])


def x86_memset(dst: Expr, val: Union[int, Expr], num: Union[Expr, int]) -> Call:
return call_primitive_func('x86_memset', [dst, val, num])


def avx_malloc(size: Union[Expr, int], align: Union[Expr, int]) -> Call:
return call_primitive_func('avx_x86_malloc', [size, align])


def avx_free(p: Expr) -> Call:
return call_primitive_func('avx_x86_free', [p])


def avx_f32x4_setzero() -> Call:
return call_primitive_func('avx_x86_float32x4_setzero', [])


def avx_f32x8_setzero() -> Call:
return call_primitive_func('avx_x86_float32x8_setzero', [])


def avx_f32x4_broadcast(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_broadcast', [addr])


def avx_f32x8_broadcast(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_broadcast', [addr])


def avx_f32x4_fmadd(a: Expr, b: Expr, c: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_fmadd', [a, b, c])


def avx_f32x8_fmadd(a: Expr, b: Expr, c: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_fmadd', [a, b, c])


def avx_f32x4_load(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_load', [addr])


def avx_f32x8_load(addr: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_load', [addr])


def avx_f32x4_store(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x4_store', [addr, src])


def avx_f32x8_store(addr: Expr, src: Expr) -> Call:
return call_primitive_func('avx_x86_float32x8_store', [addr, src])
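
As a hedged illustration of how these wrappers compose, the sketch below expresses one 8-wide fused multiply-add step. The names a_ptr, b_ptr and c_ptr are hypothetical float32 pointer expressions that would exist inside a Hidet script body; the function only builds IR call expressions and is not the micro-kernel from this PR.

from hidet.ir.primitives.cpu import (
    avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store,
)

def fma_lane(a_ptr, b_ptr, c_ptr):
    acc = avx_f32x8_load(c_ptr)        # load 8 accumulators from a row of C
    a = avx_f32x8_broadcast(a_ptr)     # broadcast one element of A to all 8 lanes
    b = avx_f32x8_load(b_ptr)          # load 8 contiguous elements of B
    acc = avx_f32x8_fmadd(a, b, acc)   # acc = a * b + acc
    avx_f32x8_store(c_ptr, acc)        # write the updated accumulators back to C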
3 changes: 3 additions & 0 deletions python/hidet/lang/__init__.py
@@ -38,6 +38,9 @@
spatial = row_spatial
repeat = row_repeat

ConstExpr = Union[Expr, int]


# def var_of_function(func: Function) -> Var:
# # pylint: disable=import-outside-toplevel
# from hidet.lang.script import ScriptModuleContext
9 changes: 9 additions & 0 deletions python/hidet/lang/avx.py
@@ -0,0 +1,9 @@
from typing import Union, Optional, Sequence
from hidet.ir.type import DataType, tensor_type
from hidet.ir.expr import Expr
from hidet.ir.stmt import DeclareScope
from hidet.ir.layout import DataLayout

from hidet.ir.primitives.cpu import avx_f32x4_broadcast, avx_f32x4_fmadd, avx_f32x4_load, avx_f32x4_store, avx_f32x4_setzero
from hidet.ir.primitives.cpu import avx_f32x8_broadcast, avx_f32x8_fmadd, avx_f32x8_load, avx_f32x8_store, avx_f32x8_setzero
from hidet.ir.primitives.cpu import avx_free, avx_malloc, x86_memcpy, x86_memset, aligned_alloc