
cinn(py-dsl): add runtime module to python dsl #58009

Merged
merged 1 commit on Oct 16, 2023
14 changes: 14 additions & 0 deletions python/cinn/compiler/compiler.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cinn

from ..runtime import CinnLowerLevelIrJit
from .compute_code_generator import ComputeCodeGenerator
@@ -31,6 +32,13 @@ def ast_to_llir(fn, inputs_signature):
return llir_schedule_generator.parse()


def llir_to_runtime_module(llir_func, target, function_name, arg_names):
cinn_builder = cinn.lang.Module.Builder(function_name, target)
cinn_builder.add_function(llir_func)
llir_module = cinn_builder.build()
return cinn.runtime.Module(llir_module, target, function_name, arg_names)


def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
if isinstance(fn, CinnLowerLevelIrJit):
llir_func = ast_to_llir(fn, jit_inputs_signature)
@@ -39,3 +47,9 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):

if just_convert:
return llir_func

rt_module = llir_to_runtime_module(
llir_func, kwargs["target"], fn.__name__, kwargs["arg_names"]
)

return rt_module
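
With this change, `compile` no longer stops at lowered IR: unless `just_convert` is set, it wraps the LLIR function in a `cinn.runtime.Module` and returns a callable. Below is a minimal sketch of the intended call path. It assumes `compile` is exported from `cinn.compiler`, that the inputs signature is picked up from the `DataArray` annotations, and that `target` and `arg_names` are forwarded through `**kwargs` exactly as read above; the kernel itself is hypothetical.

import numpy as np

from cinn import common, ir, to_cinn_llir
from cinn.compiler import compile  # export assumed; defined in compiler.py above
from cinn.runtime.data_array import DataArray


# Hypothetical elementwise kernel, written in the same style as the tests below.
@to_cinn_llir
def add_one(X: DataArray((8, 8)), Y: DataArray((8, 8))):
    for i in range(8):
        for j in range(8):
            with ir.ScheduleBlockContext("Y"):
                i1, j1 = ir.AxisMap("SS", [i, j])
                Y[i1, j1] = X[i1, j1] + 1.0


target = common.DefaultHostTarget()
# `target` and `arg_names` flow through **kwargs into llir_to_runtime_module.
rt_module = compile(add_one, target=target, arg_names=["X", "Y"])

x = DataArray.from_numpy(np.ones((8, 8), dtype="float32"), target)
y = DataArray.from_numpy(np.zeros((8, 8), dtype="float32"), target)
rt_module(x, y)  # executes via runtime.Module.__call__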
3 changes: 2 additions & 1 deletion python/cinn/runtime/__init__.py
@@ -68,5 +68,6 @@
)

from .cinn_jit import CinnLowerLevelIrJit
from .module import Module

__all__ = ["CinnLowerLevelIrJit"]
__all__ = ["CinnLowerLevelIrJit", "Module"]
55 changes: 31 additions & 24 deletions python/cinn/runtime/data_array.py
@@ -36,32 +36,39 @@ def to_numpy(self):
"""
Convert DataArray to numpy array
"""
cinn_dtype_to_np_dtype = {
np_dtype = "unk"
if self.dtype.is_bfloat16():
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
BFloat16(): "uint16",
BFloat16(): "bfloat16",
Float16(): "float16",
Float(32): "float32",
Float(64): "float64",
Int(8): "int8",
Int(16): "int16",
Int(32): "int32",
Int(64): "int64",
UInt(8): "uint8",
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
# "UInt(16): uint16"
UInt(32): "uint32",
UInt(64): "uint64",
Bool(): "bool",
}
for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items():
if isinstance(self.dtype, cinn_dtype):
np_arr = np.empty(self.shape, np_dtype)
assert np_arr.flags["C_CONTIGUOUS"]
self.data.copy_to(np_arr)
return np_arr
np_dtype = "uint16"
elif self.dtype.is_float16():
np_dtype = "float16"
elif self.dtype.is_float(32, common.Type.specific_type_t.UNK):
np_dtype = "float32"
elif self.dtype.is_float(64, common.Type.specific_type_t.UNK):
np_dtype = "float64"
elif self.dtype.is_int(8):
np_dtype = "int8"
elif self.dtype.is_int(16):
np_dtype = "int16"
elif self.dtype.is_int(32):
np_dtype = "int32"
elif self.dtype.is_int(64):
np_dtype = "int64"
elif self.dtype.is_uint(8):
np_dtype = "uint8"
elif self.dtype.is_uint(32):
np_dtype = "uint32"
elif self.dtype.is_uint(64):
np_dtype = "uint64"
elif self.dtype.is_bool():
np_dtype = "bool"
else:
raise TypeError(f"no support {self.dtype} in CINN")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"No support" -> "Not support" or "Unsupported".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be fixed in next PR


raise TypeError(f"no support {self._dtype} in CINN")
np_arr = np.empty(self.shape, np_dtype)
assert np_arr.flags["C_CONTIGUOUS"]
self.data.copy_to(np_arr)
return np_arr

@staticmethod
def from_numpy(np_array, target=common.DefaultHostTarget()):
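
For reference, a small round trip through the rewritten dtype dispatch, using the `from_numpy` helper shown above on the default host target; as the comment in the bfloat16 branch notes, bfloat16 data would come back as uint16.

import numpy as np

from cinn import common
from cinn.runtime.data_array import DataArray

x = np.arange(8, dtype="float32")
arr = DataArray.from_numpy(x, common.DefaultHostTarget())

# to_numpy() now picks the dtype via self.dtype.is_float(32, ...) instead of
# the old dict lookup; the buffer round-trips unchanged.
y = arr.to_numpy()
assert y.dtype == np.dtype("float32")
assert np.array_equal(x, y)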
37 changes: 37 additions & 0 deletions python/cinn/runtime/module.py
@@ -0,0 +1,37 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cinn
from cinn import framework
from cinn.backends import Compiler


class Module:
def __init__(self, llir_module, target, fn_name, arg_names):
self.arg_names = arg_names
self.fn_name = fn_name
self.compiler = Compiler.create(target)
self.compiler.build(llir_module)
self._instruction = framework.Instruction(
target, None, [], arg_names, fn_name
)

def __call__(self, *args):
name2pod = {}
for i, name in enumerate(self.arg_names):
if isinstance(args[i], cinn.runtime.data_array.DataArray):
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i].data)
else:
name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i])

self._instruction.run(self.compiler, self.fn_name, name2pod)
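
`__call__` zips the positional arguments with `arg_names`, converts each `DataArray` (or plain scalar) into a `cinn_pod_value_t`, and runs the single compiled instruction. A sketch of direct construction follows, mirroring `llir_to_runtime_module` from compiler.py above; `llir_func` is a hypothetical lowered function, and in practice `compile` builds this object for you.

import numpy as np

import cinn
from cinn import common
from cinn.runtime.data_array import DataArray
from cinn.runtime.module import Module

target = common.DefaultHostTarget()

# llir_func is hypothetical here; ast_to_llir in compiler.py produces one.
builder = cinn.lang.Module.Builder("add_one", target)
builder.add_function(llir_func)
rt = Module(builder.build(), target, "add_one", ["X", "Y"])

x = DataArray.from_numpy(np.ones((4,), dtype="float32"), target)
rt(x, x)  # each argument becomes a cinn_pod_value_t keyed by its arg name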
68 changes: 68 additions & 0 deletions test/cinn/ir/test_llir_schedule_bind.py
@@ -0,0 +1,68 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test.cinn.utils.testing import assert_llir_equal

from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


def test_bind_reduce():
@to_cinn_llir
def reduce_sum(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))):
for i1 in range(1):
for j1 in range(4):
for k1 in range(256):
with ir.ScheduleBlockContext("init") as init:
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
B[vi, vj, vk] = 0.0
for l1 in range(512):
with ir.ScheduleBlockContext("B"):
sch.bind(i1, "blockIdx.x")
sch.bind(j1, "threadIdx.y")
sch.bind(k1, "threadIdx.x")
vi1, vj1, vk1, vl1 = ir.AxisMap(
"SSSR", [i1, j1, k1, l1]
)
B[vi1, vj1, vk1] = (
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
)

@to_cinn_llir
def reduce_sum_expected(
A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))
):
for i1 in range(1):
for j1 in range(4):
for k1 in range(256):
with ir.ScheduleBlockContext("init") as init:
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
B[vi, vj, vk] = 0.0
for l1 in range(512):
with ir.ScheduleBlockContext("B"):
vi1, vj1, vk1, vl1 = ir.AxisMap(
"SSSR", [i1, j1, k1, l1]
)
B[vi1, vj1, vk1] = (
B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1]
)
sch.bind(init.i1, "blockIdx.x")
sch.bind(init.j1, "threadIdx.y")
sch.bind(init.k1, "threadIdx.x")

assert_llir_equal(reduce_sum, reduce_sum_expected)


if __name__ == "__main__":
test_bind_reduce()
99 changes: 99 additions & 0 deletions test/cinn/ir/test_llir_schedule_for_kind.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from test.cinn.utils.testing import assert_llir_equal

from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


# The current Python DSL cannot express a parallel `for`;
# this test only checks that the IR converts correctly
def test_elementwise_parallel():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("Y"):
i1, j1 = ir.AxisMap("SS", [i, j])
Y[i1, j1] = A[i1, j1] + 2.0
sch.parallel(A_block.i)

assert_llir_equal(elementwise_add, elementwise_add)


# The current Python DSL cannot express a vectorized `for`;
# this test only checks that the IR converts correctly
def test_elementwise_vectorize():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j0 in range(32):
for j1 in range(4):
with ir.ScheduleBlockContext("Y") as Y_block:
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
Y[i1, j1] = A[i1, j1] + 2.0
sch.vectorize(Y_block.j1, 1)

assert_llir_equal(elementwise_add, elementwise_add)


# The current Python DSL cannot express an unrolled `for`;
# this test only checks that the IR converts correctly
def test_elementwise_unroll():
@to_cinn_llir
def elementwise_add(
X: DataArray((128, 128)),
Y: DataArray((128, 128)),
A: DataArray((128, 128)),
):
for i in range(128):
for j in range(128):
with ir.ScheduleBlockContext("A") as A_block:
i1, j1 = ir.AxisMap("SS", [i, j])
A[i1, j1] = X[i1, j1] * 2.0
for i in range(128):
for j0 in range(32):
for j1 in range(4):
with ir.ScheduleBlockContext("Y") as Y_block:
i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1])
Y[i1, j1] = A[i1, j1] + 2.0
sch.unroll(Y_block.j1)

assert_llir_equal(elementwise_add, elementwise_add)


if __name__ == "__main__":
test_elementwise_parallel()
test_elementwise_vectorize()
test_elementwise_unroll()
57 changes: 57 additions & 0 deletions test/cinn/ir/test_llir_schedule_rfactor.py
@@ -0,0 +1,57 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from cinn import ir, to_cinn_llir
from cinn.runtime.data_array import DataArray
from cinn.schedule import IRSchedule as sch


def test_matmul():
@to_cinn_llir
def matmul(
A: DataArray((128, 128)),
B: DataArray((128, 128)),
C: DataArray((128, 128)),
):
for i0 in range(128):
for i1 in range(128):
with ir.ScheduleBlockContext("init"):
vi, vj = ir.AxisMap("SS", [i0, i1])
C[vi, vj] = 0.0
for i2_outer in range(4):
for i2_inner_outer in range(8):
for i2_inner_inner in range(4):
with ir.ScheduleBlockContext(
"compute"
) as Compute_block:
vi, vj, vk = ir.AxisMap(
"SSR",
[
i0,
i1,
i2_outer * 32
+ i2_inner_outer * 4
+ i2_inner_inner,
],
)
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
sch.rfactor(Compute_block.i2_inner_inner, 0)

    # TODO(6clc): rfactor schedule raises the error "iter_value not support complex reduce bindings"
# assert_llir_equal(matmul, matmul)


if __name__ == "__main__":
test_matmul()