Skip to content

Commit

Permalink
cinn(py-dsl): parse schedule of python dsl (PaddlePaddle#57981)
Browse files Browse the repository at this point in the history
拆分新特性:CINN Python DSL, 主PR和单测见:PaddlePaddle#56393

此PR只负责 解析python dsl中的schedule定义
  • Loading branch information
6clc authored and jiahy0825 committed Oct 16, 2023
1 parent c454fa2 commit 6630b62
Show file tree
Hide file tree
Showing 12 changed files with 951 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/cinn/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..runtime import CinnLowerLevelIrJit
from .compute_code_generator import ComputeCodeGenerator
from .schedule_code_generator import ScheduleCodeGenerator


def ast_to_llir(fn, inputs_signature):
Expand All @@ -24,7 +25,10 @@ def ast_to_llir(fn, inputs_signature):
fn, function_name, inputs_signature
)
cinn_llir_func = llir_compute_generator.parse()
return cinn_llir_func

# 2. Parse CINN Schedule
llir_schedule_generator = ScheduleCodeGenerator(fn, cinn_llir_func)
return llir_schedule_generator.parse()


def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):
Expand All @@ -35,4 +39,3 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs):

if just_convert:
return llir_func
return llir_func
189 changes: 189 additions & 0 deletions python/cinn/compiler/schedule_code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast

from cinn.schedule import IRSchedule

from .expr_executor import ExprExecutor, exec_assign
from .utils import (
VariableTable,
is_node_parsed_in_schedule,
node_is_schedule_block_context,
)


class ScheduleCodeGenerator(ast.NodeVisitor):
"""
Convert python ast to CINN Lower Level IR,
containing only the semantics of the schedule part
"""

def __init__(self, fn, cinn_llir_func):
self.fn = fn
self.cinn_llir_func = cinn_llir_func
self.scheduler = IRSchedule.make(self.cinn_llir_func)
self.variable_table = VariableTable()
self.global_variable_table = VariableTable()
# Set the schedule-related variable to global
self.extra_scope = {
"ScheduleBlockVariable": ScheduleBlockVariable,
"scheduler": self.scheduler,
}
self.loop_var_stack = []
self.block_stack = []
self.sch_block_tmp_var_name = "__CINN_SCHEDULE_BLOCK_VAR_NAME__"
self.tmp_var_count = 1

def parse(self):
with self.variable_table, self.global_variable_table:
ast_node = self.fn.parse()
for k, v in self.fn.scope.items():
self.variable_table.add(k, v)
for k, v in self.extra_scope.items():
self.variable_table.add(k, v)
self.visit(ast_node)
return self.cinn_llir_func

def visit_For(self, node):
assert isinstance(
node.target, ast.Name
), "Current only support range() to make ForLoop"
with self.variable_table:
self.loop_var_stack.append(node.target)
self.generic_visit(node)
self.loop_var_stack.pop()

def visit_compound_statement(self, stmts):
for stmt in stmts:
self.visit(stmt)

def visit_With(self, node):
with self.variable_table:
for item in node.items:
if isinstance(
item.context_expr, ast.Call
) and not node_is_schedule_block_context(item.context_expr):
continue
# 1. replace ScheduleBlockContext to ScheduleBlockVariable
sch_ctx_node = item.context_expr
sch_block_node = ast.copy_location(
ast.Call(
func=ast.Name(
id="ScheduleBlockVariable", ctx=ast.Load()
),
args=sch_ctx_node.args,
keywords=[],
starargs=None,
kwargs=None,
),
item.context_expr,
)
item.context_expr = sch_block_node

# 2. store ScheduleBlockVariable node
sch_block = ExprExecutor(self.variable_table.get()).exec(
item.context_expr
)
if item.optional_vars is None:
tmp_var_name = self.sch_block_tmp_var_name + str(
self.tmp_var_count
)
sch_block_var_node = ast.Name(
id=tmp_var_name, ctx=ast.Store()
)
item.optional_vars = sch_block_var_node
local_var_table = exec_assign(
target=item.optional_vars, source=sch_block
)
# 3. Set the block's loop to its attritbute
sch_block.set_scheduler(self.scheduler)
self.block_stack.append(sch_block)
for k, v in local_var_table.items():
self.variable_table.add(k, v)
self.global_variable_table.add(k, v)
for loop_var in self.loop_var_stack:
loop_var_value = ast.Attribute(
value=ast.Name(id=k, ctx=ast.Load()),
attr=loop_var.id,
ctx=ast.Load(),
)
loop_var_value = ExprExecutor(
self.variable_table.get()
).exec(loop_var_value)
for_loop_var_table = exec_assign(
loop_var, loop_var_value
)
for (
loop_var_k,
loop_var_v,
) in for_loop_var_table.items():
self.variable_table.add(loop_var_k, loop_var_v)

body = self.visit_compound_statement(node.body)

def visit_Assign(self, node):
if isinstance(node.value, ast.Call) and is_node_parsed_in_schedule(
node.value
):
sch_ret = self.exec_schedule_primitive(node.value)
local_var_table = exec_assign(
target=node.targets[0], source=sch_ret
)
for k, v in local_var_table.items():
self.variable_table.add(k, v)
return
self.generic_visit(node)

def visit_Call(self, node):
if isinstance(node, ast.Call) and is_node_parsed_in_schedule(node):
self.exec_schedule_primitive(node)
return

def exec_schedule_primitive(self, node):
# Reflect ScheduleBlockContext to ScheduleBlockVariable
sch_primitive = node
args = [ast.Name(id="scheduler", ctx=ast.Load()), *sch_primitive.args]
sch_primitive.args = args
all_variable_table = self.variable_table.get()
for k, v in self.global_variable_table.get().items():
all_variable_table[k] = v
sch_ret = ExprExecutor(all_variable_table).exec(node)

return sch_ret


class ScheduleBlockVariable:
"""
The parse Schedule process replaces ScheduleBlockContext with this class on the ast layer to improve schedule usability on the python layer
For example, split a loop in c++ requires two steps:
1. Gets the loop for the corresponding block: `x, y = sch.get_loops(block)`
2. Apply schedule to loop: tx, xi = sch.split(x, [2])
This class allows you to directly manipulate the loop name of a block
`sch.split(block.x, [2])`
"""

def __init__(self, name):
self.name = name
self.scheduler = None

def set_scheduler(self, scheduler):
self.scheduler = scheduler

def __getattr__(self, k):
if k == "block":
return self.scheduler.get_block(self.name)
else:
name2loops = self.scheduler.get_name2loops_dict(self.name)
return name2loops[k]
97 changes: 97 additions & 0 deletions python/cinn/runtime/data_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from cinn import common, runtime
from cinn.common import BFloat16, Bool, Float, Float16, Int, UInt


class DataArray:
"""
Provides Python encapsulation of the cinn_buffer_t
data interface in the CINN RunTime module.
"""

def __init__(
self,
shape: list,
dtype: common.Type = common.Float(32),
data: runtime.cinn_buffer_t = None,
) -> None:
self.shape = shape
self.dtype = dtype
self.data = data

def to_numpy(self):
"""
Convert DataArray to numpy array
"""
cinn_dtype_to_np_dtype = {
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
BFloat16(): "uint16",
BFloat16(): "bfloat16",
Float16(): "float16",
Float(32): "float32",
Float(64): "float64",
Int(8): "int8",
Int(16): "int16",
Int(32): "int32",
Int(64): "int64",
UInt(8): "uint8",
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
# "UInt(16): uint16"
UInt(32): "uint32",
UInt(64): "uint64",
Bool(): "bool",
}
for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items():
if isinstance(self.dtype, cinn_dtype):
np_arr = np.empty(self.shape, np_dtype)
assert np_arr.flags["C_CONTIGUOUS"]
self.data.copy_to(np_arr)
return np_arr

raise TypeError(f"no support {self._dtype} in CINN")

@staticmethod
def from_numpy(np_array, target=common.DefaultHostTarget()):
"""
Create DataArray form numpy array
"""
assert isinstance(np_array, np.ndarray)
data = runtime.cinn_buffer_t(np_array, target)
dtype_np_to_common = {
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
"uint16": BFloat16(),
"bfloat16": BFloat16(),
"float16": Float16(),
"float32": Float(32),
"float64": Float(64),
"int8": Int(8),
"int16": Int(16),
"int32": Int(32),
"int64": Int(64),
"uint8": UInt(8),
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle
# "uint16": UInt(16),
"uint32": UInt(32),
"uint64": UInt(64),
"bool": Bool(),
}
dtype_np = str(np_array.dtype).split(".")[-1]
assert str(dtype_np) in dtype_np_to_common, (
str(dtype_np) + " not support in CINN"
)
assert dtype_np in dtype_np_to_common.keys()

return DataArray(np_array.shape, dtype_np_to_common[dtype_np], data)
36 changes: 36 additions & 0 deletions test/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,40 @@ if(WITH_GPU)
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endforeach()

file(
GLOB CINN_RUNTIME_TEST
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"runtime/test_*.py")

foreach(runtime_test_name ${EXCLUDE_RUNTIME})
list(REMOVE_ITEM CINN_RUNTIME_TEST runtime/${runtime_test_name}.py)
endforeach()

foreach(runtime_test_name ${CINN_RUNTIME_TEST})
string(REGEX REPLACE ".py" "" runtime_test_name ${runtime_test_name})
add_test(
NAME ${runtime_test_name}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/${runtime_test_name}.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endforeach()

file(
GLOB CINN_IR_TEST
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"ir/test_*.py")

foreach(ir_test_name ${CINN_IR_TEST})
string(REGEX REPLACE ".py" "" ir_test_name ${ir_test_name})
add_test(
NAME ${ir_test_name}
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/${ir_test_name}.py
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endforeach()

endif()
36 changes: 36 additions & 0 deletions test/cinn/ir/test_llir_constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from cinn import ir, lang, to_cinn_llir
from cinn.runtime.data_array import DataArray


def test_call_extern():
@to_cinn_llir
def call_sinh(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))):
for i1 in range(1):
for j1 in range(4):
for k1 in range(256):
with ir.ScheduleBlockContext("init") as init:
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1])
B[vi, vj, vk] = lang.call_extern(
"sinh", [A[vi, vi, vj, vk]], {}
)

str(call_sinh)


if __name__ == "__main__":
test_call_extern()
Loading

0 comments on commit 6630b62

Please sign in to comment.