Skip to content

Commit

Permalink
fix(CINN-LLIR): Parse ast.Assign in a more complete way
Browse files Browse the repository at this point in the history
  • Loading branch information
6clc committed Aug 28, 2023
1 parent cd3be0b commit da6b637
Showing 1 changed file with 89 additions and 33 deletions.
122 changes: 89 additions & 33 deletions python/cinn/compiler/compute_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import ast
from typing import Union

from cinn import ir

Expand All @@ -27,8 +28,7 @@ def __init__(self, function_name, inputs_signature):
self.function_name = function_name
self.inputs_signature = inputs_signature
self.cinn_llir_func = None
self.left_value_scope = {}
self.local_variables = {}
self.variables_table = {}

def visit_FunctionDef(self, node) -> None:
"""
Expand All @@ -44,6 +44,7 @@ def visit_FunctionDef(self, node) -> None:
# 1. Construct args of function
llir_args = []
for i, arg_name in enumerate(arg_names):
# Obj of Argument is ir::Buffer
if hasattr(self.inputs_signature[i], "dtype"):
llir_value = ir._Buffer_.make(
"_" + arg_name, self.inputs_signature[i].dtype
Expand All @@ -54,15 +55,21 @@ def visit_FunctionDef(self, node) -> None:
tensor_shape = [
ir.Expr(dim) for dim in self.inputs_signature[i].shape
]

# The computational logic of CINN is implemented through Tensor,
# so ir::_Tensor_ is stored in local variables
llir_value = ir._Tensor_.make(
arg_name,
self.inputs_signature[i].dtype,
tensor_shape,
tensor_shape,
)
# Obj of Argument is ir::Var
else:
llir_value = ir.Var(arg_name)
llir_args.append(ir.Argument(llir_value))
# The computational logic of CINN is implemented through Expr<Var>,
# so ir::Expr is stored in local variables
llir_value = ir.Expr(llir_value)
self.set_value(arg_name, llir_value)

Expand All @@ -87,14 +94,12 @@ def visit_compound_statement(self, stmts):

def visit_arguments(self, node):
arg_names = []
# Just get the name of the arg,
# the properties of the arg are already stored in JIT Function.
for arg in node.args:
arg_names += [self.visit(arg)]
arg_names += arg.arg
return arg_names

def visit_arg(self, node):
ast.NodeVisitor.generic_visit(self, node)
return node.arg

def visit_For(self, node) -> ir.Expr:
"""
parse CINN Low Level IR For.
Expand All @@ -105,6 +110,7 @@ def visit_For(self, node) -> ir.Expr:
Returns:
ir.Expr, Points to the Expr of ir::ExprNode<For>
"""
# 1. Parse the iter of the For loop
iter_args = [self.visit(arg) for arg in node.iter.args]
assert (
len(iter_args) <= 2
Expand All @@ -113,12 +119,15 @@ def visit_For(self, node) -> ir.Expr:
ast_extent = iter_args[1] if len(iter_args) > 1 else iter_args[0]

# TODO(6clc): support sub region's local variable
# AS code in `visit_FunctionDef`, store ir::Expr in local variables
llir_var = ir.Var(node.target.id)
llir_var_expr = ir.Expr(llir_var)
self.set_value(node.target.id, llir_var_expr)

llir_for_min = ir.Expr(ast_min)
llir_for_extent = ir.Expr(ast_extent)

# 2. Parse the body of the For loop
llir_for_body = self.visit_compound_statement(node.body)
llir_for_body = ir.Block.make(llir_for_body)
for_expr = ir.For.make(
Expand All @@ -127,30 +136,25 @@ def visit_For(self, node) -> ir.Expr:
return for_expr

def visit_Name(self, node):
# Store Node
if type(node.ctx) == ast.Store:
if node.id in self.local_variables:
return self.local_variables[node.id]
if node.id in self.variables_table:
return self.variables_table[node.id]
return node.id
# Load Node
assert (
node.id in self.local_variables
node.id in self.variables_table
), f"{node.id} is not defined in context"
return self.local_variables[node.id]

def visit_BinOp(self, node):
cinn_tensor_l, indexs_l = self.visit(node.left)
lhs = ir.Load.make(cinn_tensor_l, indexs_l)
cinn_tensor_r, indexs_r = self.visit(node.right)
rhs = ir.Load.make(cinn_tensor_r, indexs_r)
ast2cinn = {ast.Add: ir.Add}
return ast2cinn[ast.Add].make(lhs, rhs)
return self.variables_table[node.id]

def visit_Subscript(self, node):
lhs_tensor = self.visit(node.value)
idxs = [
expr_tensor = self.visit(node.value)
indices = [
self.visit(node.slice),
]
return lhs_tensor.Expr(), idxs
if type(node.ctx) == ast.Load:
return ir.Load.make(expr_tensor, indices)
return expr_tensor, indices

def visit_Tuple(self, node):
args = [self.visit(x) for x in node.elts]
Expand All @@ -170,21 +174,73 @@ def visit_Assign(self, node):
ir.Expr, Points to the Expr of ir::ExprNode<Store>
"""

_names = []
for target in node.targets:
_names += [self.visit(target)]
assert (
len(_names) == 1
len(node.targets) == 1
), "Unsupport targets is a \
list of nodes, like 'a, b = c'"
names = _names[0]
value = self.visit(node.value)
list of nodes, like 'a = b = c'"
lhs = node.targets[0]

return ir.Store.make(names[0], value, names[1])
# 1 parse RHS
rhs_expr = self.eval_expression(node.value)

def set_value(self, name, value):
self.left_value_scope[name] = value
self.local_variables[name] = value
# 2 parse LHS
assert isinstance(
lhs, ast.Subscript
), f'Currently only tensor assignment expressions are supported. {lhs.value} is not a Tensor'
expr_tensor, expr_indices = self.visit(lhs)
return ir.Store.make(expr_tensor, rhs_expr, expr_indices)

def visit_Constant(self, node):
return ir.Expr(node.value)

def eval_expression(self, node):
"""
Parse Expr expression composed of AST nodes
"""
args = []
if isinstance(node, ast.BinOp):
args = [node.left, node.right]
elif isinstance(node, ast.UnaryOp):
args = [node.operand]
elif isinstance(node, ast.Compare):
assert (
len(node.ops) == 1
), "Only binary comparison symbols are supported. Expressions such as '1 <= a < 10' are not supported."
args = [node.left, *node.comparators]
elif isinstance(node, ast.BoolOp):
args = node.values
elif isinstance(node, ast.Call):
args = node.args
else:
raise NotImplementedError(
f'The parse data type: {node} is not currently supported'
)
for i, arg in enumerate(args):
args[i] = self.visit(arg)

ast2cinn = {
# Binary Op
ast.Add: ir.Add,
ast.Sub: ir.Sub,
ast.Mult: ir.Mul,
ast.Div: ir.Div,
ast.Mod: ir.Mod,
ast.And: ir.And,
ast.Or: ir.Or,
# Comparator Op
ast.Eq: ir.EQ,
ast.NotEq: ir.NE,
ast.Lt: ir.LT,
ast.LtE: ir.LE,
ast.Gt: ir.GT,
ast.GtE: ir.GE,
# Unary Op
ast.USub: ir.Minus,
ast.Not: ir.Not,
}
return ast2cinn[type(node.op)].make(*args)

def set_value(self, name, value: Union[ir.Tensor, ir.Var]):
if isinstance(value, ir.Tensor):
value = value.Expr()
self.variables_table[name] = value

0 comments on commit da6b637

Please sign in to comment.