Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Script] Add pointer arithmetic #237

Merged
merged 5 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,14 @@ def visit_Constant(self, e: Constant):
return Text('"{}"'.format(e.value))
elif e.is_scalar():
return self.scalar_literal(e.value, e.type)
else:
assert isinstance(e.type, TensorType)
elif e.is_tensor():
dtype = e.type.dtype
items = [self.scalar_literal(v, dtype) for v in np.array(e.value).flatten()]
return '{' + doc_join(items, ', ') + '}'
elif isinstance(e.type, PointerType):
return '(' + self(e.type) + ')' + str(int(e.value))
else:
raise ValueError("invalid constant type: {}".format(e))

def visit_DeclareStmt(self, stmt: DeclareStmt):
doc = NewLine()
Expand Down
10 changes: 5 additions & 5 deletions python/hidet/ffi/packedfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ctypes import POINTER, Structure

import hidet.graph.frontend.torch
from hidet.ir.type import TypeNode, DataType, TensorType, PointerType, TensorPointerType
from hidet.ir.type import BaseType, DataType, TensorType, PointerType, TensorPointerType
from hidet.ir.expr import Constant
from .ffi import _LIB

Expand All @@ -34,7 +34,7 @@ class ArgTypeCode(Enum):
FLOAT16 = 4

@staticmethod
def from_type(type_node: TypeNode) -> ArgTypeCode:
def from_type(type_node: BaseType) -> ArgTypeCode:
if isinstance(type_node, DataType):
if type_node.name == 'int32':
return ArgTypeCode.INT32
Expand All @@ -54,7 +54,7 @@ class CPackedFunc(Structure):
_fields_ = [("num_args", c_int32), ("arg_types", c_int32_p), ("func_pointer", c_void_p)]


def make_c_packed_func(param_types: Sequence[TypeNode], c_func_pointer) -> CPackedFunc:
def make_c_packed_func(param_types: Sequence[BaseType], c_func_pointer) -> CPackedFunc:
type_codes = [ArgTypeCode.from_type(param_type).value for param_type in param_types]
n = len(type_codes)
num_args = c_int32(n)
Expand All @@ -64,8 +64,8 @@ def make_c_packed_func(param_types: Sequence[TypeNode], c_func_pointer) -> CPack


class PackedFunc:
def __init__(self, param_types: Sequence[TypeNode], c_func_pointer):
self.param_types: List[TypeNode] = list(param_types)
def __init__(self, param_types: Sequence[BaseType], c_func_pointer):
self.param_types: List[BaseType] = list(param_types)
self.c_packed_func: CPackedFunc = make_c_packed_func(param_types, c_func_pointer)

def convert_args(self, args: Sequence):
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/schedules/auto_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
from typing import Union, List, Dict, Sequence, Tuple, Set, Optional

from hidet.ir.type import DataType, tensor_pointer_type, void_pointer
from hidet.ir.type import DataType, tensor_pointer_type, void_p
from hidet.ir.expr import TensorElement, Expr, Var, SymbolVar, Constant, scalar_var, convert, cast
from hidet.ir.stmt import Stmt, AssignStmt, ForStmt, DeclareStmt, BufferStoreStmt, AssertStmt
from hidet.ir.task import Task
Expand Down Expand Up @@ -232,7 +232,7 @@ def schedule_task(self, task: Task, device: str) -> IRModule:
# packed function arguments, packed_func(num_args: int32, arg_types: *int32, args: **void)
num_args = scalar_var('num_args', 'int32')
arg_types = Var('arg_types', ~int32)
args = Var('args', ~void_pointer())
args = Var('args', ~void_p)
fb.extend_params([num_args, arg_types, args])

# extract the packed arguments
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .node import Node
from .func import IRModule, Function
from .type import TypeNode, TensorType, DataType, FuncType, VoidType, PointerType, TensorPointerType
from .type import BaseType, TensorType, DataType, FuncType, VoidType, PointerType, TensorPointerType
from .type import data_type, tensor_type, tensor_pointer_type

from .expr import Expr, Var, Constant
Expand Down
6 changes: 3 additions & 3 deletions python/hidet/ir/dialects/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from typing import Any, Tuple, Optional, Dict, ContextManager
from contextlib import ExitStack
from hidet.ir.node import Node
from hidet.ir.type import TypeNode
from hidet.ir.type import BaseType
from hidet.ir.expr import Expr, Constant, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, Equal, LessEqual
from hidet.ir.expr import BitwiseXor
from hidet.ir.expr import IfThenElse, LogicalAnd, LogicalOr, BinaryExpr


class PlaceholderExpr(Expr):
def __init__(self, required_type: Optional[TypeNode] = None, require_const=False, require_non_const=False):
def __init__(self, required_type: Optional[BaseType] = None, require_const=False, require_non_const=False):
super().__init__()
self.required_type: Optional[TypeNode] = required_type
self.required_type: Optional[BaseType] = required_type
self.require_const: bool = require_const
self.require_non_const: bool = require_non_const
assert not (self.require_const and self.require_non_const), 'require placeholder to be both const & non-const'
Expand Down
75 changes: 51 additions & 24 deletions python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import operator
import numpy as np
from .node import Node
from .type import TypeNode, TensorType, DataType, TensorPointerType, PointerType, FuncType, StringType
from .type import BaseType, TensorType, DataType, TensorPointerType, PointerType, FuncType, StringType
from .type import tensor_pointer_type, string_type, tensor_type, data_type

PyScalar = Union[bool, int, float, complex, str]
Expand Down Expand Up @@ -201,22 +201,43 @@ def _unary(cls, a): # pylint: disable=bad-staticmethod-argument
return cls(a)

@staticmethod
def _binary(cls, a, b): # pylint: disable=bad-staticmethod-argument
def _binary(cls, a: Expr, b: Expr): # pylint: disable=bad-staticmethod-argument
if not isinstance(a, Expr):
a = convert(a)
if not isinstance(b, Expr):
b = convert(b)
if isinstance(a, Constant) and isinstance(b, Constant):
from hidet.ir.dtypes import promote_type

value = operator_dict[cls](a.value, b.value)
if cls in [Equal, NotEqual, LessThan, LessEqual, LogicalAnd, LogicalOr]:
return constant(value, const_type='bool')
elif cls in [LeftShift, RightShift]:
return constant(value, a.type)
if a.type.is_data_type() and b.type.is_data_type():
value = operator_dict[cls](a.value, b.value)
if cls in [Equal, NotEqual, LessThan, LessEqual, LogicalAnd, LogicalOr]:
return constant(value, const_type='bool')
elif cls in [LeftShift, RightShift]:
return constant(value, a.type)
else:
return constant(value, promote_type(a.type, b.type))
elif a.type.is_pointer() and b.type.is_pointer():
if cls is Sub:
return constant(a.value - b.value, 'uint64')
elif cls in [Equal, NotEqual]:
return constant(operator_dict[cls](a.value, b.value), 'bool')
else:
raise ValueError('unknown binary operator {}'.format(cls))
elif a.type.is_pointer() and b.type.is_data_type():
return constant(a.value + b.value, a.type)
elif a.type.is_data_type() and b.type.is_pointer():
return constant(a.value + b.value, b.type)
elif a.type.is_string_type() and b.type.is_string_type():
if cls is Add:
return constant(a.value + b.value, a.type)
elif cls in [Equal, NotEqual]:
return constant(operator_dict[cls](a.value, b.value), 'bool')
else:
raise ValueError('unknown binary operator {}'.format(cls))
else:
return constant(value, promote_type(a.type, b.type))
elif isinstance(b, Constant):
raise ValueError('unknown binary operator {}'.format(cls))
elif isinstance(b, Constant) and b.type.is_data_type():
from hidet.ir.dtypes import promote_type
from hidet.ir.tools import infer_type

Expand Down Expand Up @@ -413,22 +434,24 @@ def __init__(self, var, value, body):


class Cast(Expr):
def __init__(self, expr, target_type: TypeNode):
def __init__(self, expr, target_type: BaseType):
self.expr: Expr = expr
self.target_type: TypeNode = target_type
self.target_type: BaseType = target_type

assert isinstance(target_type, TypeNode)
assert isinstance(target_type, BaseType)


class Constant(Expr):
# reuse commonly-used constant objects
_constant_pool: Dict[Tuple[Union[int, float, bool], str], Constant] = {}

def __init__(
self, value: Union[np.ndarray, float, int, complex, str], const_type: Union[DataType, StringType, TensorType]
self,
value: Union[np.ndarray, float, int, complex, str],
const_type: Union[DataType, StringType, TensorType, PointerType],
):
self.value: Union[np.ndarray, float, int, complex, str] = value
self.type: Union[DataType, StringType, TensorType] = const_type
self.type: Union[DataType, StringType, TensorType, PointerType] = const_type

def is_scalar(self) -> bool:
return isinstance(self.type, DataType)
Expand All @@ -448,9 +471,6 @@ def __float__(self):
def __bool__(self):
return bool(self.value)

def __str__(self):
return str(self.value)

def __complex__(self):
return complex(self.value)

Expand Down Expand Up @@ -504,7 +524,7 @@ def __init__(self, expr: Expr):
class Var(Expr):
id_clock = 0

def __init__(self, hint: Optional[str], type: TypeNode, name: Optional[str] = None):
def __init__(self, hint: Optional[str], type: BaseType, name: Optional[str] = None):
"""
A variable may have a hint, name, and id.

Expand All @@ -520,7 +540,7 @@ def __init__(self, hint: Optional[str], type: TypeNode, name: Optional[str] = No
"""
self.hint: Optional[str] = hint
self.name: Optional[str] = name
self.type: Union[TypeNode, TensorType, TensorPointerType, FuncType] = type
self.type: Union[BaseType, TensorType, TensorPointerType, FuncType] = type
self.id: int = self.new_id()

@staticmethod
Expand Down Expand Up @@ -556,7 +576,6 @@ def __init__(self, name: str, dtype: DataType):
ExprFloat16 = ExprFloat32 = ExprFloat64 = ExprBFloat16 = ExprTFloat32 = Expr
Int = Union[int, Expr]


"""
The following are the mapping from hidet expression class to corresponding python operator.
Used by compilation-time constant folding.
Expand Down Expand Up @@ -807,7 +826,7 @@ def bitwise_xor(*args: Union[Expr, int]) -> BitwiseXor:
return _chain_binary_op(BitwiseXor, args, 0)


def cast(v: Expr, dtype: Union[str, DataType, TypeNode]):
def cast(v: Expr, dtype: Union[str, DataType, BaseType]):
if not isinstance(v, Expr):
raise ValueError('Expect an expression, got {}'.format(type(v).__name__))

Expand Down Expand Up @@ -859,7 +878,7 @@ def is_constant(e: Union[Expr, PyScalar], *other: Union[Expr, PyScalar]) -> bool
return True


def constant(value, const_type) -> Constant:
def constant(value, const_type: Union[str, BaseType]) -> Constant:
from hidet.ir.dtypes import boolean

if const_type and isinstance(const_type, str):
Expand All @@ -880,10 +899,18 @@ def constant(value, const_type) -> Constant:
value = tuple(value)
else:
raise ValueError(f"Invalid data const_type {const_type}")
else:
elif isinstance(const_type, TensorType):
value = np.array(value)
elif isinstance(const_type, PointerType):
value = int(value)
elif isinstance(const_type, StringType):
value = str(value)
else:
raise ValueError(f"Invalid const_type {const_type}")

if (isinstance(value, int) and -128 <= value <= 128) or (isinstance(value, float) and value in [-1.0, 0.0, 1.0]):
if const_type.is_data_type() and (
(isinstance(value, int) and -128 <= value <= 128) or (isinstance(value, float) and value in [-1.0, 0.0, 1.0])
):
# pylint: disable=protected-access
if (value, const_type.name) not in Constant._constant_pool:
Constant._constant_pool[(value, const_type.name)] = Constant(value, const_type)
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/ir/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Dict, List, Union, Optional
import string
from hidet.ir.node import Node
from hidet.ir.type import TypeNode, FuncType
from hidet.ir.type import BaseType, FuncType
from hidet.ir.expr import Var, Call
from hidet.ir.stmt import Stmt

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self, name: str, params, body, ret_type, kind: str, attrs=None):
assert isinstance(kind, str) and kind in ['cuda_device', 'cuda_kernel', 'host_kernel', 'packed_func']
self.params: List[Var] = params
self.body: Stmt = body
self.ret_type: TypeNode = ret_type
self.ret_type: BaseType = ret_type
# self.extern_vars: List[Var] = extern_vars if extern_vars else []
self.attrs: Dict[str, Union[int, float, str, Node]] = attrs if attrs else {}

Expand Down
4 changes: 2 additions & 2 deletions python/hidet/ir/primitives/cuda/mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def resolve_ldmatrix_func_name(num: int, shared_space_addr: bool = False, trans=

@initialize()
def register_ldmatrix_instructions():
from hidet.lang import script, u32, void_pointer, attrs, ref_u32
from hidet.lang import script, u32, void_p, attrs, ref_u32

for num in [1, 2, 4]:
for trans in [False, True]:
Expand All @@ -265,7 +265,7 @@ def register_ldmatrix_instructions():
inst_name = 'ldmatrix.sync.aligned.m8n8{num}{trans}{ss}.b16'.format(
num=f'.x{num}', trans='.trans' if trans else '', ss='.shared' if shared_space_addr else ''
)
smem_type = u32 if shared_space_addr else void_pointer
smem_type = u32 if shared_space_addr else void_p
if num == 1:
template = '{inst_name} {{%0}}, [%1];'.format(inst_name=inst_name)

Expand Down
6 changes: 5 additions & 1 deletion python/hidet/ir/tools/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def visit_Constant(self, e: Constant):
return 'ConstTensor({}, {})'.format(e.value.shape, e.type)
elif e.is_string():
return Text('"{}"'.format(str(e.value)))
else:
elif e.is_scalar():
dtype = e.type.name
if dtype == 'float32':
ret = '{}f'.format(float(e.value))
Expand All @@ -229,6 +229,10 @@ def visit_Constant(self, e: Constant):
else:
ret = '{}({})'.format(dtype, e.value)
return Text(ret)
elif isinstance(e.type, PointerType):
return Text('({}){}'.format(self(e.type), self(e.value)))
else:
raise NotImplementedError("Unknown constant type: {}".format(e.type))

def visit_DeclareStmt(self, stmt: DeclareStmt):
doc = NewLine() + Text('declare ') + self(stmt.var) + Text(': ') + self(stmt.var.type)
Expand Down
36 changes: 28 additions & 8 deletions python/hidet/ir/tools/type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.ir.type import DataType, TensorType, FuncType, PointerType, TensorPointerType, data_type, tensor_pointer_type
from hidet.ir.type import tensor_type
from hidet.ir.type import tensor_type, BaseType
from hidet.ir.expr import BinaryExpr, Add, Sub, Multiply, Div, Mod, FloorDiv, Condition, LessThan, Equal, IfThenElse
from hidet.ir.expr import TensorSlice, LogicalNot, LogicalOr, LogicalAnd, LessEqual, Let, RightShift, LeftShift
from hidet.ir.expr import BitwiseAnd, Neg, NotEqual, BitwiseXor, Dereference, Reference, Address, BitwiseNot, BitwiseOr
Expand All @@ -35,15 +35,35 @@ def visit_Reference(self, e: Reference):
def visit_Binary(self, e: BinaryExpr):
from hidet.ir.utils.type_utils import numeric_promotion

a_dtype: DataType = self.visit(e.a)
b_dtype: DataType = self.visit(e.b)
a_type: BaseType = self.visit(e.a)
b_type: BaseType = self.visit(e.b)
if a_type.is_data_type() and b_type.is_data_type():
a_dtype: DataType = a_type.as_data_type()
b_dtype: DataType = b_type.as_data_type()

if isinstance(e, (Add, Sub, Multiply, Div, Mod, FloorDiv)):
return numeric_promotion(a_dtype, b_dtype)
elif isinstance(e, Condition):
return data_type('bool')
if isinstance(e, (Add, Sub, Multiply, Div, Mod, FloorDiv)):
return numeric_promotion(a_dtype, b_dtype)
elif isinstance(e, Condition):
return data_type('bool')
else:
raise NotImplementedError('Binary operator type infer {}'.format(type(e)))
elif a_type.is_pointer() and b_type.is_pointer():
if isinstance(e, Sub):
return data_type('int32')
else:
raise TypeError("Can only do pointer subtraction, got {}".format(type(e)))
elif a_type.is_pointer() and b_type.is_data_type():
if isinstance(e, (Sub, Add)):
return a_type
else:
raise TypeError("Can only pointer +/- integer, got {} {} {}".format(a_type, type(e), b_type))
elif a_type.is_data_type() and b_type.is_pointer():
if isinstance(e, Add):
return b_type
else:
raise TypeError("Can only integer + pointer, got {} {} {}".format(a_type, type(e), b_type))
else:
raise NotImplementedError('Binary op type infer {}'.format(type(e)))
raise NotImplementedError('Binary operator type infer {} {} {}'.format(a_type, type(e), b_type))

def visit_Neg(self, e: Neg):
return self(e.a)
Expand Down
Loading