forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
cinn(py-dsl): parse schedule of python dsl (PaddlePaddle#57981)
拆分新特性:CINN Python DSL, 主PR和单测见:PaddlePaddle#56393 此PR只负责 解析python dsl中的schedule定义
- Loading branch information
Showing
12 changed files
with
951 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import ast | ||
|
||
from cinn.schedule import IRSchedule | ||
|
||
from .expr_executor import ExprExecutor, exec_assign | ||
from .utils import ( | ||
VariableTable, | ||
is_node_parsed_in_schedule, | ||
node_is_schedule_block_context, | ||
) | ||
|
||
|
||
class ScheduleCodeGenerator(ast.NodeVisitor): | ||
""" | ||
Convert python ast to CINN Lower Level IR, | ||
containing only the semantics of the schedule part | ||
""" | ||
|
||
def __init__(self, fn, cinn_llir_func): | ||
self.fn = fn | ||
self.cinn_llir_func = cinn_llir_func | ||
self.scheduler = IRSchedule.make(self.cinn_llir_func) | ||
self.variable_table = VariableTable() | ||
self.global_variable_table = VariableTable() | ||
# Set the schedule-related variable to global | ||
self.extra_scope = { | ||
"ScheduleBlockVariable": ScheduleBlockVariable, | ||
"scheduler": self.scheduler, | ||
} | ||
self.loop_var_stack = [] | ||
self.block_stack = [] | ||
self.sch_block_tmp_var_name = "__CINN_SCHEDULE_BLOCK_VAR_NAME__" | ||
self.tmp_var_count = 1 | ||
|
||
def parse(self): | ||
with self.variable_table, self.global_variable_table: | ||
ast_node = self.fn.parse() | ||
for k, v in self.fn.scope.items(): | ||
self.variable_table.add(k, v) | ||
for k, v in self.extra_scope.items(): | ||
self.variable_table.add(k, v) | ||
self.visit(ast_node) | ||
return self.cinn_llir_func | ||
|
||
def visit_For(self, node): | ||
assert isinstance( | ||
node.target, ast.Name | ||
), "Current only support range() to make ForLoop" | ||
with self.variable_table: | ||
self.loop_var_stack.append(node.target) | ||
self.generic_visit(node) | ||
self.loop_var_stack.pop() | ||
|
||
def visit_compound_statement(self, stmts): | ||
for stmt in stmts: | ||
self.visit(stmt) | ||
|
||
def visit_With(self, node): | ||
with self.variable_table: | ||
for item in node.items: | ||
if isinstance( | ||
item.context_expr, ast.Call | ||
) and not node_is_schedule_block_context(item.context_expr): | ||
continue | ||
# 1. replace ScheduleBlockContext to ScheduleBlockVariable | ||
sch_ctx_node = item.context_expr | ||
sch_block_node = ast.copy_location( | ||
ast.Call( | ||
func=ast.Name( | ||
id="ScheduleBlockVariable", ctx=ast.Load() | ||
), | ||
args=sch_ctx_node.args, | ||
keywords=[], | ||
starargs=None, | ||
kwargs=None, | ||
), | ||
item.context_expr, | ||
) | ||
item.context_expr = sch_block_node | ||
|
||
# 2. store ScheduleBlockVariable node | ||
sch_block = ExprExecutor(self.variable_table.get()).exec( | ||
item.context_expr | ||
) | ||
if item.optional_vars is None: | ||
tmp_var_name = self.sch_block_tmp_var_name + str( | ||
self.tmp_var_count | ||
) | ||
sch_block_var_node = ast.Name( | ||
id=tmp_var_name, ctx=ast.Store() | ||
) | ||
item.optional_vars = sch_block_var_node | ||
local_var_table = exec_assign( | ||
target=item.optional_vars, source=sch_block | ||
) | ||
# 3. Set the block's loop to its attritbute | ||
sch_block.set_scheduler(self.scheduler) | ||
self.block_stack.append(sch_block) | ||
for k, v in local_var_table.items(): | ||
self.variable_table.add(k, v) | ||
self.global_variable_table.add(k, v) | ||
for loop_var in self.loop_var_stack: | ||
loop_var_value = ast.Attribute( | ||
value=ast.Name(id=k, ctx=ast.Load()), | ||
attr=loop_var.id, | ||
ctx=ast.Load(), | ||
) | ||
loop_var_value = ExprExecutor( | ||
self.variable_table.get() | ||
).exec(loop_var_value) | ||
for_loop_var_table = exec_assign( | ||
loop_var, loop_var_value | ||
) | ||
for ( | ||
loop_var_k, | ||
loop_var_v, | ||
) in for_loop_var_table.items(): | ||
self.variable_table.add(loop_var_k, loop_var_v) | ||
|
||
body = self.visit_compound_statement(node.body) | ||
|
||
def visit_Assign(self, node): | ||
if isinstance(node.value, ast.Call) and is_node_parsed_in_schedule( | ||
node.value | ||
): | ||
sch_ret = self.exec_schedule_primitive(node.value) | ||
local_var_table = exec_assign( | ||
target=node.targets[0], source=sch_ret | ||
) | ||
for k, v in local_var_table.items(): | ||
self.variable_table.add(k, v) | ||
return | ||
self.generic_visit(node) | ||
|
||
def visit_Call(self, node): | ||
if isinstance(node, ast.Call) and is_node_parsed_in_schedule(node): | ||
self.exec_schedule_primitive(node) | ||
return | ||
|
||
def exec_schedule_primitive(self, node): | ||
# Reflect ScheduleBlockContext to ScheduleBlockVariable | ||
sch_primitive = node | ||
args = [ast.Name(id="scheduler", ctx=ast.Load()), *sch_primitive.args] | ||
sch_primitive.args = args | ||
all_variable_table = self.variable_table.get() | ||
for k, v in self.global_variable_table.get().items(): | ||
all_variable_table[k] = v | ||
sch_ret = ExprExecutor(all_variable_table).exec(node) | ||
|
||
return sch_ret | ||
|
||
|
||
class ScheduleBlockVariable: | ||
""" | ||
The parse Schedule process replaces ScheduleBlockContext with this class on the ast layer to improve schedule usability on the python layer | ||
For example, split a loop in c++ requires two steps: | ||
1. Gets the loop for the corresponding block: `x, y = sch.get_loops(block)` | ||
2. Apply schedule to loop: tx, xi = sch.split(x, [2]) | ||
This class allows you to directly manipulate the loop name of a block | ||
`sch.split(block.x, [2])` | ||
""" | ||
|
||
def __init__(self, name): | ||
self.name = name | ||
self.scheduler = None | ||
|
||
def set_scheduler(self, scheduler): | ||
self.scheduler = scheduler | ||
|
||
def __getattr__(self, k): | ||
if k == "block": | ||
return self.scheduler.get_block(self.name) | ||
else: | ||
name2loops = self.scheduler.get_name2loops_dict(self.name) | ||
return name2loops[k] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
from cinn import common, runtime | ||
from cinn.common import BFloat16, Bool, Float, Float16, Int, UInt | ||
|
||
|
||
class DataArray: | ||
""" | ||
Provides Python encapsulation of the cinn_buffer_t | ||
data interface in the CINN RunTime module. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
shape: list, | ||
dtype: common.Type = common.Float(32), | ||
data: runtime.cinn_buffer_t = None, | ||
) -> None: | ||
self.shape = shape | ||
self.dtype = dtype | ||
self.data = data | ||
|
||
def to_numpy(self): | ||
""" | ||
Convert DataArray to numpy array | ||
""" | ||
cinn_dtype_to_np_dtype = { | ||
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle | ||
BFloat16(): "uint16", | ||
BFloat16(): "bfloat16", | ||
Float16(): "float16", | ||
Float(32): "float32", | ||
Float(64): "float64", | ||
Int(8): "int8", | ||
Int(16): "int16", | ||
Int(32): "int32", | ||
Int(64): "int64", | ||
UInt(8): "uint8", | ||
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle | ||
# "UInt(16): uint16" | ||
UInt(32): "uint32", | ||
UInt(64): "uint64", | ||
Bool(): "bool", | ||
} | ||
for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items(): | ||
if isinstance(self.dtype, cinn_dtype): | ||
np_arr = np.empty(self.shape, np_dtype) | ||
assert np_arr.flags["C_CONTIGUOUS"] | ||
self.data.copy_to(np_arr) | ||
return np_arr | ||
|
||
raise TypeError(f"no support {self._dtype} in CINN") | ||
|
||
@staticmethod | ||
def from_numpy(np_array, target=common.DefaultHostTarget()): | ||
""" | ||
Create DataArray form numpy array | ||
""" | ||
assert isinstance(np_array, np.ndarray) | ||
data = runtime.cinn_buffer_t(np_array, target) | ||
dtype_np_to_common = { | ||
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle | ||
"uint16": BFloat16(), | ||
"bfloat16": BFloat16(), | ||
"float16": Float16(), | ||
"float32": Float(32), | ||
"float64": Float(64), | ||
"int8": Int(8), | ||
"int16": Int(16), | ||
"int32": Int(32), | ||
"int64": Int(64), | ||
"uint8": UInt(8), | ||
# numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle | ||
# "uint16": UInt(16), | ||
"uint32": UInt(32), | ||
"uint64": UInt(64), | ||
"bool": Bool(), | ||
} | ||
dtype_np = str(np_array.dtype).split(".")[-1] | ||
assert str(dtype_np) in dtype_np_to_common, ( | ||
str(dtype_np) + " not support in CINN" | ||
) | ||
assert dtype_np in dtype_np_to_common.keys() | ||
|
||
return DataArray(np_array.shape, dtype_np_to_common[dtype_np], data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from cinn import ir, lang, to_cinn_llir | ||
from cinn.runtime.data_array import DataArray | ||
|
||
|
||
def test_call_extern(): | ||
@to_cinn_llir | ||
def call_sinh(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))): | ||
for i1 in range(1): | ||
for j1 in range(4): | ||
for k1 in range(256): | ||
with ir.ScheduleBlockContext("init") as init: | ||
vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) | ||
B[vi, vj, vk] = lang.call_extern( | ||
"sinh", [A[vi, vi, vj, vk]], {} | ||
) | ||
|
||
str(call_sinh) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_call_extern() |
Oops, something went wrong.