From 8d755f73de0ebc109eb7dbccb825a9471a0e8811 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Sat, 5 Aug 2023 12:07:10 +0800 Subject: [PATCH 1/6] [Operator] Add clamp/isinf/any/all op, enhance where op (#343) --- .../frontend/torch/register_functions.py | 39 ++++++- .../graph/frontend/torch/register_methods.py | 10 ++ python/hidet/graph/frontend/torch/utils.py | 9 ++ python/hidet/graph/ops/__init__.py | 13 +-- python/hidet/graph/ops/arithmetic.py | 106 ++++++++++++++++-- python/hidet/graph/ops/reduce/reduce.py | 6 +- python/hidet/ir/primitives/__init__.py | 2 +- 7 files changed, 157 insertions(+), 28 deletions(-) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 250a3deb5..8df6af553 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -25,7 +25,7 @@ from hidet.runtime.device import Device from .interpreter import register_function, register_method from .interpreter import warnings -from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar +from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar, convert_to_scalar_if_possible Number = Union[int, float, bool] @@ -590,7 +590,7 @@ def addmm( @register_function(torch.where) -def where(condition: Tensor, x: Tensor, y: Tensor): +def where(condition: Tensor, x: Union[Tensor, Number], y: Union[Tensor, Number]): return ops.where(cond=condition, x=x, y=y) @@ -1069,3 +1069,38 @@ def zeros_like( hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype else x.dtype return ops.full(x.shape, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero) + + +@register_function(torch.clamp) +def clamp( + x: Tensor, + min: Optional[Union[Tensor, Number]] = None, + max: Optional[Union[Tensor, Number]] = None, + *, + out: Optional[Tensor] = None, +) -> Tensor: + if out is not None: + raise NotImplementedError("hidet: does not support torch.clamp(..., out=...)") + + min = convert_to_scalar_if_possible(min) + max = convert_to_scalar_if_possible(max) + + if min is None and max is None: + return x + elif min is None: + if not isinstance(max, Tensor): + assert isinstance(max, (int, float, complex)) + max = ops.full([], value=max, dtype=x.dtype, device=x.device) + return ops.minimum(x, max) + elif max is None: + if not isinstance(min, Tensor): + assert isinstance(min, (int, float, complex)) + min = ops.full([], value=min, dtype=x.dtype, device=x.device) + return ops.maximum(x, min) + else: + return ops.clamp(x, min, max) + + +@register_function(torch.isinf) +def isinf(x: Tensor) -> Tensor: + return ops.isinf(x) diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index 07b25721a..a260bbe83 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -255,3 +255,13 @@ def tensor_repeat(self: Tensor, *sizes: int) -> Tensor: @register_method(torch.Tensor.detach) def tensor_detach(self: Tensor) -> Tensor: return self + + +@register_method(torch.Tensor.any) +def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor: + return ops.any(self, axis=dim, keepdims=keepdim) + + +@register_method(torch.Tensor.all) +def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor: + return ops.all(self, axis=dim, keepdims=keepdim) diff --git a/python/hidet/graph/frontend/torch/utils.py 
b/python/hidet/graph/frontend/torch/utils.py index 6d4bb72d7..9a2ef141f 100644 --- a/python/hidet/graph/frontend/torch/utils.py +++ b/python/hidet/graph/frontend/torch/utils.py @@ -285,3 +285,12 @@ def normalize_to_scalar(value: Union[Tensor, Expr, float, int, bool]) -> Union[E raise RuntimeError(f'Cannot convert tensor {value.signature()} to scalar') else: return value + + +def convert_to_scalar_if_possible(x: Union[Tensor, Expr, float, int, bool]) -> Optional[Union[Expr, float, int, bool]]: + if isinstance(x, Tensor): + if len(x.shape) == 0 and x.storage: + return x.item() + return None + else: + return x diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index f2644005e..a10466a6a 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -13,15 +13,8 @@ from .matmul import batch_matmul, matmul, matmul_x86 from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose -from .conv2d import ( - conv2d, - conv2d_channel_last, - conv2d_winograd, - conv2d_gemm, - conv2d_gemm_fp16, - conv2d_gemm_fp16_channel_last, - conv2d_gemm_image_transform, -) +from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 +from .conv2d import conv2d_gemm_fp16_channel_last, conv2d_gemm_image_transform from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm from .conv3d import conv3d, conv3d_gemm from .conv3d_transpose import conv3d_transpose @@ -38,7 +31,7 @@ from .arithmetic import floor, ceil, round, trunc, sqrt, rsqrt, pow, abs from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, bitwise_invert, bitwise_or -from .arithmetic import bitwise_xor, maximum, minimum +from .arithmetic import bitwise_xor, maximum, minimum, clamp from .arithmetic import isfinite, isinf, isnan, sign, where from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2 from .complex import real, imag, conj, make_complex diff --git a/python/hidet/graph/ops/arithmetic.py b/python/hidet/graph/ops/arithmetic.py index 9acc397c3..3e36c7bd4 100644 --- a/python/hidet/graph/ops/arithmetic.py +++ b/python/hidet/graph/ops/arithmetic.py @@ -13,14 +13,17 @@ from typing import List, Callable, Any, Union, Optional, Dict from hidet.ir import primitives -from hidet.ir import expr, dtypes +from hidet.ir import Var, expr, dtypes from hidet.ir.type import DataType -from hidet.ir.expr import Expr, Var, if_then_else from hidet.ir.tools import rewrite +from hidet.ir.expr import Expr, if_then_else, is_true from hidet.utils import prod, same_list from .utils import Task, Operator, Tensor, TensorNode, InverseMap, compute, input_like from .utils import broadcast_shape, broadcast_shapes, broadcast_indices +PyScalar = Union[int, float, bool] + + # In order for the subgraph rewrite of Composite Elementwise Operator to work, # we need to store the callable in an Operator object. 
But lambda cannot be pickled, # so we define auxiliary classes UnaryElementwiseOperation and BinaryElementwiseOperation @@ -117,7 +120,7 @@ def __init__(self, name: str, args: List[TensorNode], op: Callable[[Any], Any]): inverse_map={ v: InverseMap.identity(len(v_shape)) for v, v_shape in zip(args, shapes) - if prod(v_shape) == prod(out_shape) + if is_true(prod(v_shape) == prod(out_shape)) and len(v_shape) == len(out_shape) }, ) @@ -207,6 +210,15 @@ def __init__(self, x: Tensor, y: Tensor, op, name: str): ) +def get_dtype(scalar: Expr): + from hidet.ir.tools import infer_type + + inferred_type = infer_type(scalar) + if not isinstance(inferred_type, DataType): + raise TypeError(f'Expected scalar to be of type DataType, got {type(inferred_type)}') + return inferred_type + + class CompositeElementwiseOp(Operator): def __init__( self, @@ -238,37 +250,37 @@ def resolve_dtype(tensor_dtype: DataType, scalar_dtype: DataType) -> DataType: class AddScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v + dtype(scalar), attributes={'scalar': scalar}, name='adds') class SubScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v - dtype(scalar), attributes={'scalar': scalar}, name='subs') class RSubScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: dtype(scalar) - v, attributes={'scalar': scalar}, name='rsubs') class MultiplyScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v * dtype(scalar), attributes={'scalar': scalar}, name='muls') class DivideScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v / dtype(scalar), attributes={'scalar': scalar}, name='divs') class RDivideScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: dtype(scalar) / v, attributes={'scalar': scalar}, name='rdivs') @@ -478,6 +490,19 @@ def __init__(self, x: Tensor): ) +class ClampOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, min_value: Union[int, float], max_value: Union[int, float]): + assert isinstance(min_value, (int, float)) + assert isinstance(max_value, (int, float)) + min_value = x.dtype(min_value) + max_value = x.dtype(max_value) + super().__init__( + x, + op=lambda a: if_then_else(a < min_value, min_value, if_then_else(a > max_value, max_value, a)), + name='clamp', + ) + + class RightShiftOp(BinaryElementwiseOp): def __init__(self, x: Tensor, y: Tensor): super().__init__(x, y, op=lambda a, b: expr.RightShift(a, b), name='rightshift') @@ -522,6 +547,47 @@ def __init__(self, cond: Tensor, x: Tensor, y: Tensor): ) +class WhereScalarScalarOp(Operator): + def __init__(self, cond: Tensor, x: PyScalar, y: PyScalar): + if isinstance(x, int) and isinstance(y, int): + dtype = dtypes.default_int_dtype + elif 
isinstance(x, float) or isinstance(y, float): + dtype = dtypes.default_float_dtype + else: + raise ValueError(f'Unsupported scalar type: {type(x)}') + x, y = dtype(x), dtype(y) + super().__init__( + inputs=[cond], + attributes={'x': x, 'y': y}, + task=UnaryElementwiseTask(name='where', x=input_like(cond, 'cond'), op=lambda a: if_then_else(a, x, y)), + ) + + +class WhereScalarTensorOp(Operator): + def __init__(self, cond: Tensor, y: Tensor, x: PyScalar): + dtype = y.dtype + x = dtype(x) + super().__init__( + inputs=[cond, y], + attributes={'x': x}, + task=BinaryElementwiseTask( + name='where', x=input_like(cond, 'cond'), y=input_like(y, 'y'), op=lambda a, b: if_then_else(a, x, b) + ), + ) + + +class WhereTensorScalarOp(Operator): + def __init__(self, cond: Tensor, x: Tensor, y: PyScalar): + y = x.dtype(y) + super().__init__( + inputs=[cond, x], + attributes={'y': y}, + task=BinaryElementwiseTask( + name='where', x=input_like(cond, 'cond'), y=input_like(x, 'x'), op=lambda a, b: if_then_else(a, b, y) + ), + ) + + class MaxOp(Operator): def __init__(self, *tensors: Tensor): def scalar_max(args: List[expr.Expr]): @@ -792,10 +858,25 @@ def sign(x: Tensor) -> Tensor: return SignOp(x).outputs[0] -def where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor: +def clamp(x: Tensor, min: Union[Tensor, float, int], max: Union[Tensor, float, int]) -> Tensor: + if isinstance(min, Tensor) or isinstance(max, Tensor): + raise NotImplementedError('clamp with tensor min/max is not implemented yet') + return ClampOp(x, min, max).outputs[0] + + +def where(cond: Tensor, x: Union[Tensor, PyScalar], y: Union[Tensor, PyScalar]) -> Tensor: if cond.dtype != dtypes.boolean: raise ValueError('The condition tensor must have dtype "bool", but got {}'.format(cond.dtype.name)) - return WhereOp(cond, x, y).outputs[0] + if isinstance(x, Tensor) and isinstance(y, Tensor): + return WhereOp(cond, x, y).outputs[0] + elif isinstance(x, Tensor) and isinstance(y, (int, float, complex)): + return WhereTensorScalarOp(cond, x=x, y=y).outputs[0] + elif isinstance(x, (int, float, complex)) and isinstance(y, Tensor): + return WhereScalarTensorOp(cond, x=x, y=y).outputs[0] + elif isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)): + return WhereScalarScalarOp(cond, x=x, y=y).outputs[0] + else: + raise ValueError('Invalid arguments for where: x={}, y={}'.format(x, y)) def maximum(a: Tensor, b: Tensor, *others: Tensor) -> Tensor: @@ -812,7 +893,8 @@ def mod(x: Tensor, y: Tensor) -> Tensor: return ModOp(x, y).outputs[0] -remainder = mod +def remainder(x: Tensor, y: Tensor) -> Tensor: + return mod(x, y) def abs(x: Tensor) -> Tensor: diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py index 769de4cc7..ca669dadb 100644 --- a/python/hidet/graph/ops/reduce/reduce.py +++ b/python/hidet/graph/ops/reduce/reduce.py @@ -56,7 +56,7 @@ def allow_epilogue(self) -> bool: def allow_prologue(self) -> bool: return False - def implement_cuda(self, working_dir: str) -> IRModule: + def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]: rank = len(self.inputs[0].shape) if rank - 1 in self.dims: return tune.extract_ir_modules(self.cuda_schedule_reduce_by_warp) @@ -80,7 +80,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule: xdtype = x.type.dtype shape: List[Int] = list(x.shape) lanes = 1 - vtype: DataType = xdtype + vtype: Union[DataType, VectorType] = xdtype if xdtype.nbytes < 4: num_eles: int = 4 // xdtype.nbytes if is_constant(shape[-1]) and shape[-1] 
% num_eles == 0: @@ -204,7 +204,7 @@ def cuda_schedule_reduce_by_default(self) -> IRModule: shape: List[Int] = list(x.shape) lanes = 1 - vtype: DataType = xdtype + vtype: Union[VectorType, DataType] = xdtype if xdtype.nbytes < 4: num_eles: int = 4 // xdtype.nbytes if shape[-1] % num_eles == 0: diff --git a/python/hidet/ir/primitives/__init__.py b/python/hidet/ir/primitives/__init__.py index 21c1ae524..49866a0af 100644 --- a/python/hidet/ir/primitives/__init__.py +++ b/python/hidet/ir/primitives/__init__.py @@ -15,7 +15,7 @@ # pylint: disable=redefined-builtin from .math import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, expm1, abs from .math import max, min, exp, pow, sqrt, rsqrt, erf, ceil, log, log2, log10, log1p, round, floor, trunc -from .math import isfinite, isinf, isnan, make_vector +from .math import isfinite, isinf, isnan, make_vector, atan2, mod from .complex import real, imag, conj, make_complex From 740ff3ce04cd86058eb53612a14cbaecad590194 Mon Sep 17 00:00:00 2001 From: Hanjie <50634613+hjjq@users.noreply.github.com> Date: Sat, 12 Aug 2023 03:28:08 -0400 Subject: [PATCH 2/6] [Torch][Graph][Operator] Add and fix various items for torchvision model support (#347) 1. Enhance support for `__setitem__` and` __getitem__` of Tensor; Add SetStridedSlice Op, Roll Op. 2. Add/Update torch mapping for adaptive_avg_pool3d, eq, pad, roll, matmul, new_zeros, batch_norm, MultiHeadAttention. 3. Update torch Linear mapping to optionally accept transposed weights. 4. Fix a bug where a empty graph will output a zero tensor instead of the input/weight. --- .../hidet/graph/frontend/torch/interpreter.py | 8 ++ .../frontend/torch/register_functions.py | 107 ++++++++++++++- .../graph/frontend/torch/register_methods.py | 24 ++++ .../graph/frontend/torch/register_modules.py | 85 +++++++++++- python/hidet/graph/ops/__init__.py | 2 +- python/hidet/graph/ops/arithmetic.py | 127 +++++++++++++++++- python/hidet/graph/ops/transform.py | 45 +------ python/hidet/graph/ops/utils/tensor_utils.py | 43 +++++- python/hidet/graph/tensor.py | 9 ++ python/hidet/runtime/compiled_graph.py | 20 ++- 10 files changed, 401 insertions(+), 69 deletions(-) diff --git a/python/hidet/graph/frontend/torch/interpreter.py b/python/hidet/graph/frontend/torch/interpreter.py index 121d837dd..29efb0d59 100644 --- a/python/hidet/graph/frontend/torch/interpreter.py +++ b/python/hidet/graph/frontend/torch/interpreter.py @@ -359,6 +359,10 @@ def load_arg(a, env): hidet_kwargs = load_arg(node.kwargs, hidet_env) try: hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs) + from .register_functions import setitem + + if exec_func.functions[0] is setitem: + hidet_env[str(node.args[0])] = hidet_env[node.name] except Exception as e: self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs) elif node.op == "call_method": @@ -448,6 +452,10 @@ def load_arg(a, env): try: hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs) + from .register_functions import setitem + + if hidet_func.functions[0] is setitem: + hidet_env[str(node.args[0])] = hidet_env[node.name] except Exception as e: self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 8df6af553..3d0e50a55 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -18,10 +18,9 @@ from hidet.graph import ops from 
hidet.utils import same_list from hidet.ir.type import DataType -from hidet.ir.expr import Expr from hidet.ir import expr from hidet.ir.dtypes import promote_type -from hidet.ir.expr import Int +from hidet.ir.expr import Expr, Int, is_constant from hidet.runtime.device import Device from .interpreter import register_function, register_method from .interpreter import warnings @@ -97,6 +96,11 @@ def adaptive_avg_pool2d(x: Tensor, output_size): return ops.adaptive_avg_pool2d(x, output_size) +@register_function(torch.nn.functional.adaptive_avg_pool3d) +def adaptive_avg_pool3d(x: Tensor, output_size): + return ops.adaptive_avg_pool3d(x, output_size) + + @register_function(torch.nn.functional.relu) def relu(x: Tensor, inplace: bool): # if inplace: @@ -130,7 +134,9 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode= @register_function(torch.nn.functional.linear) -def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]): +def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False): + if len(weight.shape) > 1 and not weight_is_transposed: + weight = ops.transpose(weight, [1, 0]) y = ops.matmul(x, weight) if bias is not None: y = y + bias @@ -205,10 +211,18 @@ def batch_norm( ) y = ops.batch_norm_infer(x, running_mean, running_var, epsilon=eps) _ = momentum # unused + if len(x.shape) == 3: + dims = [0, 2] + if len(x.shape) == 4: + dims = [0, 2, 3] + elif len(x.shape) == 5: + dims = [0, 2, 3, 4] + else: + raise NotImplementedError("batch_norm only accepts 3D, 4D, 5D input") if weight is not None: - y = y * weight.unsqueeze([0, 2, 3]) + y = y * weight.unsqueeze(dims) if bias is not None: - y = y + bias.unsqueeze([0, 2, 3]) + y = y + bias.unsqueeze(dims) return y @@ -222,6 +236,70 @@ def getitem(x: Tensor, index): return x[index] +@register_function(operator.setitem) +def setitem(x: Tensor, item, setvalue): + + if isinstance(item, list): + item = tuple(item) + if not isinstance(item, tuple): + item = tuple([item]) + + if not isinstance(setvalue, (int, float)): + raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values') + + # now, the item could have + # 1. integer index + # 2. slice + # 3. Ellipsis + # 4. 
None + # e.g., [1, 3:5, ..., None] + + # process Ellipsis + # e.g., x[1, ..., 2] -> x[1, :, :, 2] + if Ellipsis in item: + if item.count(Ellipsis) > 1: + raise ValueError('Only one ellipsis allowed in index.') + ellipsis_index = item.index(Ellipsis) + ellipsis_ndim = len(x.shape) - sum([1 if axis not in [None, Ellipsis] else 0 for axis in item]) + ellipsis_ndim = max(ellipsis_ndim, 0) + item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :] + + # normalize index + normalized_item = [] + for i, v in enumerate(item): + if isinstance(v, int): + if v < 0: + v = v + x.shape[i] + if is_constant(v, x.shape[i]) and (v < 0 or v >= x.shape[i]): + raise IndexError('index {} is out of bound for dimension {} with size {}'.format(v, i, x.shape[i])) + normalized_item.append(v) + elif v is not None: + # None affects getitem, but is ignored in setitem + normalized_item.append(v) + item = tuple(normalized_item) + + # process slice and integer index + rank = len(x.shape) + while len(item) < rank: + item = item + (slice(None),) + starts, ends, steps = [], [], [] + squeeze_dims = [] + for dim, v in enumerate(item): + if isinstance(v, (int, Expr)): + squeeze_dims.append(dim) + starts.append(v) + ends.append(v + 1) + steps.append(1) + else: + assert isinstance(v, slice) + starts.append(v.start) + ends.append(v.stop) + steps.append(v.step) + + out = ops.set_strided_slice(x, starts, ends, steps, setvalue) + return out + + @register_function(operator.mul) @register_function(torch.mul) @register_function(torch.ops.aten.mul.Tensor) @@ -931,6 +1009,13 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor @register_function(operator.eq) def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor: + if isinstance(a, Tensor) or isinstance(b, Tensor): + from hidet.graph.ops.utils import convert_to_tensor + + if isinstance(a, Tensor): + return ops.equal(a, convert_to_tensor(b, a)) + else: + return ops.equal(b, convert_to_tensor(a, b)) return a == b @@ -1104,3 +1189,15 @@ def clamp( @register_function(torch.isinf) def isinf(x: Tensor) -> Tensor: return ops.isinf(x) + + +@register_function(torch.nn.functional.pad) +def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0): + if isinstance(pad, tuple): + pad = list(pad) + return ops.pad(x, pads=pad, mode=mode, value=value) + + +@register_function(torch.roll) +def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None): + return ops.roll(x, shifts, dims) diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index a260bbe83..022bdbd23 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -265,3 +265,27 @@ def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor: @register_method(torch.Tensor.all) def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor: return ops.all(self, axis=dim, keepdims=keepdim) + + +@register_method(torch.Tensor.matmul) +def tensor_matmul(self: Tensor, other: Tensor) -> Tensor: + return ops.matmul(self, other) + + +@register_method(torch.Tensor.new_zeros) +def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False): + if layout is not None: + raise NotImplementedError("layout is not None") + if len(size) == 1: + if isinstance(size[0], (list, tuple)): + size = size[0] + shape = 
size + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + _ = pin_memory + _ = requires_grad + + return ops.full(shape, dtype=dtype, device=device, value=dtype.zero) diff --git a/python/hidet/graph/frontend/torch/register_modules.py b/python/hidet/graph/frontend/torch/register_modules.py index c4c58774f..df3055216 100644 --- a/python/hidet/graph/frontend/torch/register_modules.py +++ b/python/hidet/graph/frontend/torch/register_modules.py @@ -11,6 +11,7 @@ # limitations under the License. from __future__ import annotations import torch +from hidet.graph import ops from hidet.graph.tensor import Tensor from .interpreter import HidetModule, register_module from . import register_functions as regs @@ -117,6 +118,13 @@ def __call__(self, x: Tensor) -> Tensor: return regs.adaptive_avg_pool2d(x, self.mod.output_size) +@register_module(torch.nn.AdaptiveAvgPool3d) +class HidetAdaptiveAvgPool3d(HidetModule): + def __call__(self, x: Tensor) -> Tensor: + assert isinstance(self.mod, torch.nn.AdaptiveAvgPool3d) + return regs.adaptive_avg_pool3d(x, self.mod.output_size) + + @register_module(torch.nn.ReLU) class HidetReLU(HidetModule): def __call__(self, x: Tensor) -> Tensor: @@ -158,21 +166,21 @@ def __call__(self, x: Tensor) -> Tensor: class HidetLinear(HidetModule): def __init__(self, torch_module: torch.nn.Module): super().__init__(torch_module) - from hidet import ops - steal = dynamo_config['steal_weights'] - self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0]) def __call__(self, x: Tensor) -> Tensor: assert isinstance(self.mod, torch.nn.Linear) - return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True)) + return regs.linear( + x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True + ) @register_module(torch.nn.BatchNorm2d) +@register_module(torch.nn.BatchNorm3d) class HidetBatchNorm2d(HidetModule): def __call__(self, x: Tensor) -> Tensor: - assert isinstance(self.mod, torch.nn.BatchNorm2d) + assert isinstance(self.mod, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)) return regs.batch_norm( x=x, running_mean=self.param('running_mean'), @@ -404,3 +412,70 @@ def __call__(self, x: Tensor) -> Tensor: align_corners=self.mod.align_corners, recompute_scale_factor=self.mod.recompute_scale_factor, ) + + +@register_module(torch.nn.MultiheadAttention) +class HidetMultiheadAttention(HidetModule): + def __init__(self, torch_module: torch.nn.Module): + super().__init__(torch_module) + steal = dynamo_config['steal_weights'] + self.in_proj_weight_transposed = ops.transpose(self.param('in_proj_weight', steal=steal), [1, 0]) + self.out_proj_weight_transposed = ops.transpose(self.param('out_proj.weight', steal=steal), [1, 0]) + + def __call__( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + is_causal=False, + ) -> Tensor: + assert isinstance(self.mod, torch.nn.MultiheadAttention) + supported = ( + self.mod._qkv_same_embed_dim + and self.mod.bias_k is None + and self.mod.bias_v is None + and not self.mod.add_zero_attn + and self.mod.batch_first + and key_padding_mask is None + and not need_weights + ) + if not supported: + raise NotImplementedError( + "Hidet Multihead Attention currently only supports " + "kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False, " + "batch_first=True, forward(key_padding_mask=None, need_weights=False)." 
+ ) + + # Input feed forward + wq, wk, wv = ops.split(self.in_proj_weight_transposed, parts_or_sections=3, axis=1) + query = ops.matmul(query, wq) + key = ops.matmul(key, wk) + value = ops.matmul(value, wv) + if self.mod.in_proj_bias is not None: + bq, bk, bv = ops.split(self.param('in_proj_bias'), parts_or_sections=3, axis=0) + query = ops.add(query, bq) + key = ops.add(key, bk) + value = ops.add(value, bv) + + # Split heads + split_head_dims = [query.shape[0], query.shape[1], self.mod.num_heads, query.shape[2] // self.mod.num_heads] + query = ops.transpose(query.reshape(split_head_dims), [0, 2, 1, 3]) + key = ops.transpose(key.reshape(split_head_dims), [0, 2, 1, 3]) + value = ops.transpose(value.reshape(split_head_dims), [0, 2, 1, 3]) + + # fmha + out = regs.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=self.mod.dropout, is_causal=is_causal + ) + + # Output feed forward + merge_head_dims = [out.shape[0], out.shape[2], self.mod.embed_dim] + out = ops.transpose(out, [0, 2, 1, 3]).reshape(merge_head_dims) + out = ops.matmul(out, self.out_proj_weight_transposed) + if self.mod.out_proj.bias is not None: + out = ops.add(out, self.param('out_proj.bias')) + return out diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index a10466a6a..4303cd820 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -32,7 +32,7 @@ from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, bitwise_invert, bitwise_or from .arithmetic import bitwise_xor, maximum, minimum, clamp -from .arithmetic import isfinite, isinf, isnan, sign, where +from .arithmetic import isfinite, isinf, isnan, sign, where, set_strided_slice, roll from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2 from .complex import real, imag, conj, make_complex from .compare import equal, not_equal, less, greater, less_equal, greater_equal diff --git a/python/hidet/graph/ops/arithmetic.py b/python/hidet/graph/ops/arithmetic.py index 3e36c7bd4..9fbf807ba 100644 --- a/python/hidet/graph/ops/arithmetic.py +++ b/python/hidet/graph/ops/arithmetic.py @@ -10,16 +10,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# pylint: disable=redefined-builtin, unnecessary-lambda -from typing import List, Callable, Any, Union, Optional, Dict +from typing import List, Callable, Any, Union, Optional, Dict, Sequence from hidet.ir import primitives from hidet.ir import Var, expr, dtypes from hidet.ir.type import DataType +from hidet.ir.expr import Expr, if_then_else, logical_or, is_constant, is_true from hidet.ir.tools import rewrite -from hidet.ir.expr import Expr, if_then_else, is_true from hidet.utils import prod, same_list from .utils import Task, Operator, Tensor, TensorNode, InverseMap, compute, input_like from .utils import broadcast_shape, broadcast_shapes, broadcast_indices +from .utils import normalize_slice, normalize_dim PyScalar = Union[int, float, bool] @@ -188,6 +189,78 @@ def __init__(self, cond: TensorNode, x: TensorNode, y: TensorNode): ) +class SetStridedSliceTask(Task): + def __init__( + self, + data: TensorNode, + starts: List[Optional[int]], + ends: List[Optional[int]], + axes: List[int], + strides: List[int], + setvalue: [Union[int, float]], + ): + assert len(starts) == len(ends) == len(axes) == len(strides) + if len(axes) != len(set(axes)): + raise ValueError('Duplicated axes in slice, axes: {}'.format(axes)) + output_shape = list(data.shape) + axis2info = {} + for axis, start, end, stride in zip(axes, starts, ends, strides): + if stride == 0: + raise NotImplementedError( + 'Stride can not be 0 in slicing: ' + 'starts {} ends {} axes {} strides {}.'.format(starts, ends, axes, strides) + ) + if is_constant(output_shape[axis]) and output_shape[axis] < 0: + raise NotImplementedError( + 'Slice result can not be: ' + 'starts {} ends {} axes {} strides {}'.format(starts, ends, axes, strides) + ) + axis2info[axis] = (start, end, stride) + + def fmap(indices): + ret = data.type.dtype(setvalue) + for axis, index in enumerate(indices): + start, end, stride = axis2info[axis] + ret = if_then_else( + logical_or(index < start, index >= end, (index - start) % stride != 0), data[indices], ret + ) + return ret + + out = compute('out', shape=output_shape, fcompute=lambda *indices: fmap(indices)) + super().__init__(name='set_slice', inputs=[data], outputs=[out]) + + +class RollTask(Task): + def __init__(self, x: TensorNode, shifts: Sequence[int], dims: Sequence[int]): + output_shape = list(x.shape) + + def fmap(indices): + data_indices = [] + for axis, index in enumerate(indices): + if axis in dims: + i = dims.index(axis) + if shifts[i] > 0: + data_indices.append( + if_then_else( + index - shifts[i] >= 0, index - shifts[i], index + output_shape[axis] - shifts[i] + ) + ) + else: + data_indices.append( + if_then_else( + index - shifts[i] < output_shape[axis], + index - shifts[i], + index - output_shape[axis] - shifts[i], + ) + ) + else: + data_indices.append(index) + return x[data_indices] + + out = compute('out', shape=output_shape, fcompute=lambda *indices: fmap(indices)) + super().__init__(name='roll', inputs=[x], outputs=[out]) + + class UnaryElementwiseOp(Operator): def __init__(self, x: Tensor, op, name: str, attributes: Optional[Dict[str, Any]] = None, task_attributes=None): if attributes is None: @@ -626,6 +699,32 @@ def scalar_min(args: List[expr.Expr]): ) +class SetStridedSliceOp(Operator): + def __init__( + self, + data: Tensor, + starts: Sequence[Optional[int]], + ends: Sequence[Optional[int]], + strides: Optional[Sequence[Optional[int]]] = None, + setvalue: Optional[Union[int, float]] = 0.0, + ): + starts, ends, axes, strides = normalize_slice(data.shape, starts, ends, axes=None, strides=strides) + 
task = SetStridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides, setvalue) + super().__init__( + inputs=[data], + attributes={'starts': starts, 'ends': ends, 'strides': strides, 'setvalue': setvalue}, + task=task, + ) + + +class RollOp(Operator): + def __init__(self, x: Tensor, shifts: Sequence[int], dims: Sequence[int]) -> Tensor: + if not len(shifts) == len(dims): + raise ValueError('Roll must have same size shifts and dims, got {} and {}'.format(len(shifts), len(dims))) + task = RollTask(input_like(x, 'x'), shifts, dims) + super().__init__(inputs=[x], attributes={'shifts': shifts, 'dims': dims}, task=task) + + Scalar = Union[Expr, float, int, complex] @@ -946,6 +1045,20 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: return log(exp(x - max_val) + exp(y - max_val)) + max_val +def roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None) -> Tensor: + if isinstance(shifts, int): + shifts = [shifts] + if isinstance(dims, int): + dims = [dims] + if dims is None: + from .transform import flatten, reshape + + shape = x.shape + return reshape(RollOp(flatten(x), shifts, dims=[0]).outputs[0], shape) + dims = normalize_dim(dims, len(x.shape)) + return RollOp(x, shifts, dims).outputs[0] + + # out = binary_op(left_unary_op(x), right_unary_op(x)); This allows more fusion opportunity. def composite_elementwise( x: Tensor, @@ -954,3 +1067,13 @@ def composite_elementwise( binary_op: BinaryElementwiseOperation, ) -> Tensor: return CompositeElementwiseOp(x, left_unary_op, right_unary_op, binary_op).outputs[0] + + +def set_strided_slice( + data: Tensor, + starts: Sequence[Optional[int]], + ends: Sequence[Optional[int]], + strides: Optional[Sequence[Optional[int]]] = None, + setvalue: Optional[Union[int, float]] = 0.0, +) -> Tensor: + return SetStridedSliceOp(data, starts, ends, strides, setvalue).outputs[0] diff --git a/python/hidet/graph/ops/transform.py b/python/hidet/graph/ops/transform.py index 90af495c1..e0c5c114d 100644 --- a/python/hidet/graph/ops/transform.py +++ b/python/hidet/graph/ops/transform.py @@ -17,7 +17,7 @@ from hidet.ir.utils import index_deserialize, index_serialize from hidet.utils import prod from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast -from .utils import TensorInput +from .utils import TensorInput, normalize_slice def is_true(x: Union[Expr, bool]) -> bool: @@ -444,53 +444,12 @@ def __init__( axes: Optional[Sequence[Optional[int]]] = None, strides: Optional[Sequence[Optional[int]]] = None, ): - starts, ends, axes, strides = self.normalize(data.shape, starts, ends, axes, strides) + starts, ends, axes, strides = normalize_slice(data.shape, starts, ends, axes, strides) task = StridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides) super().__init__( inputs=[data], attributes={'starts': starts, 'ends': ends, 'axes': axes, 'strides': strides}, task=task ) - @staticmethod - def normalize(data_shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[Optional[int]]]): - # follow: https://data-apis.org/array-api/latest/API_specification/indexing.html - if axes is None: - axes = [i for i in range(len(starts))] - axes = normalize_dim(axes, len(data_shape)) - if strides is None: - strides = [1 for _ in range(len(starts))] - shape = [data_shape[i] for i in axes] - assert len(shape) == len(starts) == len(ends) == len(axes) == len(strides) - - ii, jj, kk = [], [], [] - for i, j, k, n in zip(starts, ends, strides, shape): - if k is None: - k = 1 - 
if k > 0: - i = i if i is not None else 0 - j = j if j is not None else n - if is_constant(i, j, n) and not (-n <= i <= n and -n <= j): - raise IndexError('Invalid slice') - j = if_then_else(j < n, j, n) - if is_constant(i) and i < 0: - i = i + n - if is_constant(j) and j < 0: - j = j + n - elif k < 0: - i = i if i is not None else n - 1 - j = j if j is not None else -n - 1 - if is_constant(i) and i < 0: - i += n - if is_constant(j) and j < -1: - j += n - if is_constant(i, j, n) and not (-n <= i <= n and -n - 1 <= j <= max(0, n - 1)): - raise IndexError('Invalid slice') - else: - raise IndexError('slice step cannot be zero') - ii.append(i) - jj.append(j) - kk.append(k) - return ii, jj, axes, kk - class BroadcastOp(Operator): def __init__(self, data: Tensor, shape: List[int]): diff --git a/python/hidet/graph/ops/utils/tensor_utils.py b/python/hidet/graph/ops/utils/tensor_utils.py index 3d7a6fde5..93f13a712 100644 --- a/python/hidet/graph/ops/utils/tensor_utils.py +++ b/python/hidet/graph/ops/utils/tensor_utils.py @@ -14,7 +14,7 @@ import builtins from hidet.ir.layout import DataLayout from hidet.ir.type import Int -from hidet.ir.expr import Var, Expr, Constant, is_constant +from hidet.ir.expr import Var, Expr, Constant, is_constant, if_then_else from hidet.ir.type import TensorType, tensor_type, DataType from hidet.ir.task import Task, InverseMap from hidet.ir.module import IRModule @@ -159,3 +159,44 @@ def convert_to_tensor(value: Union[int, float, bool, complex, Tensor], involved_ return full_like(involved_tensor, fill_value=value, shape=[], dtype='float32') else: raise ValueError('Can not recognize dtype {}'.format(involved_tensor.dtype)) + + +def normalize_slice(data_shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[Optional[int]]]): + # follow: https://data-apis.org/array-api/latest/API_specification/indexing.html + if axes is None: + axes = [i for i in range(len(starts))] + axes = normalize_dim(axes, len(data_shape)) + if strides is None: + strides = [1 for _ in range(len(starts))] + shape = [data_shape[i] for i in axes] + assert len(shape) == len(starts) == len(ends) == len(axes) == len(strides) + + ii, jj, kk = [], [], [] + for i, j, k, n in zip(starts, ends, strides, shape): + if k is None: + k = 1 + if k > 0: + i = i if i is not None else 0 + j = j if j is not None else n + if is_constant(i, j, n) and not (-n <= i <= n and -n <= j): + raise IndexError('Invalid slice') + j = if_then_else(j < n, j, n) + if is_constant(i) and i < 0: + i = i + n + if is_constant(j) and j < 0: + j = j + n + elif k < 0: + i = i if i is not None else n - 1 + j = j if j is not None else -n - 1 + if is_constant(i) and i < 0: + i += n + if is_constant(j) and j < -1: + j += n + if is_constant(i, j, n) and not (-n <= i <= n and -n - 1 <= j <= max(0, n - 1)): + raise IndexError('Invalid slice') + else: + raise IndexError('slice step cannot be zero') + ii.append(i) + jj.append(j) + kk.append(k) + return ii, jj, axes, kk diff --git a/python/hidet/graph/tensor.py b/python/hidet/graph/tensor.py index c45f73eee..d7896aa15 100644 --- a/python/hidet/graph/tensor.py +++ b/python/hidet/graph/tensor.py @@ -342,6 +342,15 @@ def __str__(self): def __getitem__(self, item): from hidet.graph.ops import strided_slice + if isinstance(item, Tensor): + if len(item.shape) > 1: + raise NotImplementedError("Tensor indexing via Tensor currently only supports 1D index tensor") + if not item.dtype.is_integer(): + raise TypeError("Tensor indexing via Tensor requires integer index tensor") + from .ops import take + 
+ return take(self, item, axis=0) + if isinstance(item, list): item = tuple(item) diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index dd86f0e03..2553bcdbe 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Dict, Any, Callable, Union +from typing import List, Optional, Tuple, Dict, Any, Callable import zipfile import os import json @@ -95,7 +95,6 @@ def __init__( # derived properties (will be initialized in _init_compiled_graph at the end of this constructor) self.dynamic_dims: List[Tuple[str, Tuple[int, int]]] = [] # [(name, (tensor_index, dim_index))] self.is_dynamic: bool = False - self.constant_outputs: List[Union[None, Tensor]] = [] # runtime state self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash) @@ -145,11 +144,6 @@ def _init_compiled_graph(self): self.is_dynamic = True else: self.is_dynamic = False - for out_idx in self.graph_execution.outputs_index: - if out_idx in self.graph_execution.weights_index: - self.constant_outputs.append(self.weights[self.graph_execution.weights_index.index(out_idx)]) - else: - self.constant_outputs.append(None) # initialize weights weights_buffer = Array(void_p, len(self.weights)) @@ -209,13 +203,15 @@ def _update_symbol_dims(self, inputs) -> Tuple[int, ...]: runtime_api.set_symbol_value(name, symbol_dims[-1]) return tuple(symbol_dims) - def _create_outputs(self): + def _create_outputs(self, inputs): from hidet.graph.tensor import empty outputs = [] - for output_index, (sig, const_out) in enumerate(zip(self.meta.outputs, self.constant_outputs)): - if const_out is not None: - outputs.append(const_out) + for output_index, (exec_idx, sig) in enumerate(zip(self.graph_execution.outputs_index, self.meta.outputs)): + if exec_idx in self.graph_execution.inputs_index: + outputs.append(inputs[self.graph_execution.inputs_index.index(exec_idx)]) + elif exec_idx in self.graph_execution.weights_index: + outputs.append(self.weights[self.graph_execution.weights_index.index(exec_idx)]) else: if self.is_dynamic: shape_buffer = Array(i32, len(sig.shape)) @@ -240,7 +236,7 @@ def _prepare_workspace(self): def _run_fast_path(self, inputs, symbol_dims: Tuple[int, ...]): # create output tensors - outputs = self._create_outputs() + outputs = self._create_outputs(inputs) # prepare workspace self._prepare_workspace() From edb6503de22ad6eb7975a8da2019ff28867f06a4 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Mon, 14 Aug 2023 17:06:12 -0400 Subject: [PATCH 3/6] [Dynamo] minor enhancements to attention and register a few functions (#345) Encountered a few minor issues when compiling a transformer-based model using torch.compile with very large batch sizes, submitting the fix here. 
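The main kernel-side change below flattens the attention launch grid from (i_split, bs) to a single dimension of i_split * bs blocks and recovers the two indices with integer division and modulo; CUDA caps gridDim.y at 65535 while gridDim.x allows up to 2^31 - 1 blocks, so the flattened form tolerates much larger batch/head counts. A minimal, illustrative sketch of that index mapping (not part of the patch; i_split and bs here are stand-ins for the kernel's values):

def unflatten(block_idx_x: int, i_split: int):
    # tile index over the query rows, and index into the flattened batch/head dimensions
    return block_idx_x % i_split, block_idx_x // i_split

i_split, bs = 3, 5
pairs = [unflatten(b, i_split) for b in range(i_split * bs)]
# every (tile, head) pair is produced exactly once, matching the old 2-D grid (i_split, bs)
assert sorted(pairs) == sorted((i, j) for j in range(bs) for i in range(i_split))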
--- .../frontend/torch/register_functions.py | 22 +++++++++++++++++++ .../graph/frontend/torch/register_methods.py | 10 ++++++++- python/hidet/graph/ops/attention/attention.py | 20 ++++++++--------- .../graph/ops/attention/attention_mask.py | 20 ++++++++--------- 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 3d0e50a55..436d35796 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -1066,6 +1066,28 @@ def torch_mean( return output +@register_function(torch.sum) +@register_method(torch.Tensor.sum) +def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor: + if dtype: + x = x.astype(dtype_from_torch(dtype)) + output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True) + return output + + +@register_function(torch.sum) +@register_method(torch.Tensor.sum) +def torch_sum( + x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None +) -> Tensor: + if out is not None: + raise NotImplementedError("hidet: does not support torch.sum(..., out=...)") + if dtype: + x = x.astype(dtype_from_torch(dtype)) + output = ops.sum(x, dims=dim, keep_dim=keepdim) + return output + + @register_function(torch.cumsum) def torch_cumsum(x: Tensor, dim, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None) -> Tensor: if out is not None: diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index 022bdbd23..6123f2101 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -90,6 +90,11 @@ def tensor_to(self: Tensor, *args, **kwargs) -> Tensor: if self.is_symbolic() and instantiate_device(device_from_torch(arg)) != self.device: raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.') device = arg + elif isinstance(arg, Tensor): + dtype = arg.dtype + if self.is_symbolic() and arg.device != self.device: + raise NotImplementedError('hidet: Tensor.to(..., device=...) 
is not supported for symbolic tensors.') + device = arg.device else: raise ValueError(f'Unsupported argument type: {type(arg)}') @@ -222,7 +227,10 @@ def tensor_type(self: Tensor, dtype: Union[str, torch.dtype], non_blocking: bool @register_method(torch.Tensor.expand) def tensor_expand(self: Tensor, *sizes: int) -> Tensor: - sizes: List[int] = list(sizes) + if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)): + sizes = sizes[0] + else: + sizes: List[int] = list(sizes) assert len(sizes) >= len(self.shape) for i in range(len(sizes)): if sizes[i] == -1: diff --git a/python/hidet/graph/ops/attention/attention.py b/python/hidet/graph/ops/attention/attention.py index 3e459b7d1..278deb49b 100644 --- a/python/hidet/graph/ops/attention/attention.py +++ b/python/hidet/graph/ops/attention/attention.py @@ -397,7 +397,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type): def copy_k_g2s_sm80( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j_seg in k_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -411,7 +411,7 @@ def copy_k_g2s_sm80( @hidet.script def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j_seg in v_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -421,7 +421,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j_seg in q_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -433,7 +433,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offs def copy_k_g2s_sm75( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j in k_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]: @@ -444,7 +444,7 @@ def copy_k_g2s_sm75( @hidet.script def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j in v_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]: @@ -455,7 +455,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j in q_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < 
q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]: @@ -488,7 +488,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: @hidet.script def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_o = o[o_head_index][offset_i:, :] for k_round in range(warp_count_k): for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id): @@ -652,12 +652,12 @@ def attn_kernel( v: f16[v_head + [n_kv_size, d_size]], o: f16[o_head + [n_size, d_size]], ): - attrs.cuda.grid_dim = (i_split, bs) + attrs.cuda.grid_dim = i_split * bs attrs.cuda.block_dim = block_size attrs.cuda.min_blocks = 1 attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes - offset_i = blockIdx.x * i_rows_per_tb + offset_i = (blockIdx.x % i_split) * i_rows_per_tb smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout) smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout) @@ -702,7 +702,7 @@ def attn_kernel( j_tiles = cdiv(n_kv_size, block_j) if is_causal: - j_tiles = min(cdiv((blockIdx.x + 1) * block_i, block_j), j_tiles) + j_tiles = min(cdiv(((blockIdx.x % i_split) + 1) * block_i, block_j), j_tiles) for j in range(j_tiles): offset_j = block_j * j diff --git a/python/hidet/graph/ops/attention/attention_mask.py b/python/hidet/graph/ops/attention/attention_mask.py index 873c34de7..1a91c02d6 100644 --- a/python/hidet/graph/ops/attention/attention_mask.py +++ b/python/hidet/graph/ops/attention/attention_mask.py @@ -425,7 +425,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type): def copy_k_g2s_sm80( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j_seg in k_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -439,7 +439,7 @@ def copy_k_g2s_sm80( @hidet.script def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j_seg in v_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -449,7 +449,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j_seg in q_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -461,7 +461,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offs def copy_k_g2s_sm75( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j in k_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]: @@ -472,7 +472,7 @@ def 
copy_k_g2s_sm75( @hidet.script def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j in v_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]: @@ -483,7 +483,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j in q_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]: @@ -516,7 +516,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: @hidet.script def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_o = o[o_head_index][offset_i:, :] for k_round in range(warp_count_k): for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id): @@ -681,12 +681,12 @@ def attn_kernel( mask: f16[mask_shape], o: f16[o_head + [n_size, d_size]], ): - attrs.cuda.grid_dim = (i_split, bs) + attrs.cuda.grid_dim = i_split * bs attrs.cuda.block_dim = block_size attrs.cuda.min_blocks = 1 attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes - offset_i = blockIdx.x * i_rows_per_tb + offset_i = (blockIdx.x % i_split) * i_rows_per_tb smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout) smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout) @@ -773,7 +773,7 @@ def attn_kernel( copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j) # Apply Masking - qk_head_index = list(spatial(*qk_head).map(blockIdx.y)) + qk_head_index = list(spatial(*qk_head).map(blockIdx.x // i_split)) for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 wi, wj, wk = spatial(warp_count_m, warp_count_n, warp_count_k).map(warp_id) From 1fb45f72b31be0b8be451d428ff7fae9ed5e2758 Mon Sep 17 00:00:00 2001 From: Hanjie <50634613+hjjq@users.noreply.github.com> Date: Tue, 15 Aug 2023 17:08:49 -0400 Subject: [PATCH 4/6] [Torch][Operator] More torchvision model support (#348) This is a continuation of #347. 1. Add LP normalization task (ToDo: schedule template) 2. Add torch mappings for normalize, clone, zero_, exp, chunk 3. Add ceil_mode=True support for pool2d 4. Fix dtype issue in resize 5. 
Fix other bugs in pad, conv2d_pattern --- .../frontend/torch/register_functions.py | 47 ++++++++- .../graph/frontend/torch/register_methods.py | 5 + python/hidet/graph/ops/__init__.py | 2 +- python/hidet/graph/ops/arithmetic.py | 2 +- python/hidet/graph/ops/image.py | 17 ++-- python/hidet/graph/ops/normalize/__init__.py | 1 + python/hidet/graph/ops/normalize/lp.py | 98 +++++++++++++++++++ python/hidet/graph/ops/pool.py | 28 +++--- .../graph_patterns/conv2d_patterns.py | 2 +- .../hidet/ir/primitives/cuda/math/float16.py | 4 + 10 files changed, 181 insertions(+), 25 deletions(-) create mode 100644 python/hidet/graph/ops/normalize/lp.py diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 436d35796..27958f4b6 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -113,11 +113,9 @@ def relu(x: Tensor, inplace: bool): def max_pool2d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=False, return_indices=False): if dilation != 1 and not same_list(dilation, [1, 1]): raise NotImplementedError("dilation != 1") - if ceil_mode: - raise NotImplementedError("ceil_mode=True") if return_indices: raise NotImplementedError("return_indices=True") - y = ops.max_pool2d(x, kernel_size, stride, padding) + y = ops.max_pool2d(x, kernel_size, stride, padding, ceil_mode=ceil_mode) return y @@ -594,6 +592,7 @@ def permute(x: Tensor, *args): return ops.transpose(x, dims) +@register_function(torch.swapaxes) @register_function(torch.transpose) @register_method(torch.Tensor.transpose) def transpose(x: Tensor, dim0: int, dim1: int): @@ -775,6 +774,7 @@ def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor: @register_function(torch.exp) +@register_method(torch.Tensor.exp) def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor: if out is not None: warnings.warn_once("hidet: does not support torch.exp(..., out=...)") @@ -1217,9 +1217,50 @@ def isinf(x: Tensor) -> Tensor: def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0): if isinstance(pad, tuple): pad = list(pad) + # Torch's pad list has form [p2left, p2right, p1left, p1right, p0left, p0right] + # Hidet's pad list has form [p0left, p1left, p2left, p0right, p1right, p2right] + left = [] + right = [] + for i, p in enumerate(pad): + if i % 2 == 0: + left.append(p) + else: + right.append(p) + left.reverse() + right.reverse() + pad = [] + for p in left: + pad.append(p) + for p in right: + pad.append(p) return ops.pad(x, pads=pad, mode=mode, value=value) @register_function(torch.roll) def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None): return ops.roll(x, shifts, dims) + + +@register_function(torch.nn.functional.normalize) +def torch_normalize(x: Tensor, p=2.0, dim=1, eps=1e-12, out=None): + if out is not None: + raise NotImplementedError("out is not None") + return ops.lp_norm(x, p, dim, eps) + + +@register_function(torch.clone) +@register_method(torch.Tensor.clone) +def torch_clone(x: Tensor, *, memory_format=torch.preserve_format): + if memory_format is not torch.preserve_format: + warnings.warn_once( + "torch.clone got memory_format not torch.preserve_format, treating it as torch.preserve_format" + ) + if x.is_symbolic(): + return x + else: + return x.copy() + + +@register_function(torch.chunk) +def torch_chunk(x: Tensor, chunks: int, dim: int = 0): + return ops.split(x, parts_or_sections=chunks, 
axis=dim) diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index 6123f2101..a9b5c9d21 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -297,3 +297,8 @@ def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, _ = requires_grad return ops.full(shape, dtype=dtype, device=device, value=dtype.zero) + + +@register_method(torch.Tensor.zero_) +def tensor_zero_(self: Tensor): + return ops.full(self.shape, dtype=self.dtype, device=self.device, value=self.dtype.zero) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 4303cd820..64f539725 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -24,7 +24,7 @@ from .activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink from .activation import softshrink, softmax, softmin, hardtanh from .attention import attention -from .normalize import batch_norm_infer, instance_norm, layer_norm, group_norm +from .normalize import batch_norm_infer, instance_norm, layer_norm, group_norm, lp_norm from .image import resize2d from .create import full, arange, linspace, tri from .arithmetic import add, subtract, multiply, divide, mod, remainder, negative, positive, square diff --git a/python/hidet/graph/ops/arithmetic.py b/python/hidet/graph/ops/arithmetic.py index 9fbf807ba..7233dc197 100644 --- a/python/hidet/graph/ops/arithmetic.py +++ b/python/hidet/graph/ops/arithmetic.py @@ -718,7 +718,7 @@ def __init__( class RollOp(Operator): - def __init__(self, x: Tensor, shifts: Sequence[int], dims: Sequence[int]) -> Tensor: + def __init__(self, x: Tensor, shifts: Sequence[int], dims: Sequence[int]): if not len(shifts) == len(dims): raise ValueError('Roll must have same size shifts and dims, got {} and {}'.format(len(shifts), len(dims))) task = RollTask(input_like(x, 'x'), shifts, dims) diff --git a/python/hidet/graph/ops/image.py b/python/hidet/graph/ops/image.py index b59a7f0af..2100d1be5 100644 --- a/python/hidet/graph/ops/image.py +++ b/python/hidet/graph/ops/image.py @@ -12,7 +12,7 @@ from typing import Optional, List, Sequence, Union from hidet.ir.dtypes import int32 -from hidet.ir.expr import Expr, Int, if_then_else, cast, logical_or, logical_and +from hidet.ir.expr import Expr, Int, if_then_else, cast, logical_or, logical_and, convert from hidet.ir import primitives as prim from .utils import Task, Operator, Tensor, TensorNode, compute, input_like @@ -59,8 +59,8 @@ def get_2d_pixel(data: TensorNode, n, c, h, w) -> Expr: return data[n, c, h, w] -def linear_interpolate(a, b, ratio): - return a * (1.0 - ratio) + b * ratio +def linear_interpolate(a, b, ratio, dtype='float32'): + return a * (convert(1.0, dtype) - ratio) + b * ratio def get_cubic_weights(s, a) -> List[int]: @@ -112,6 +112,7 @@ def resize2d_nchw_compute( scale_factor = _normalize(scale_factor, 2) size = _normalize(size, 2) + dtype = data.type.dtype if size is not None and scale_factor is None: target_size = size @@ -146,12 +147,12 @@ def fmap(n, c, h, w): elif method == 'linear': h_int = cast(prim.floor(h), 'int32') w_int = cast(prim.floor(w), 'int32') - h_ratio = h - h_int - w_ratio = w - w_int + h_ratio = cast(h - h_int, dtype) + w_ratio = cast(w - w_int, dtype) pixels = [[get_2d_pixel(data, n, c, h_int + i, w_int + j) for j in range(2)] for i in range(2)] - top = linear_interpolate(*pixels[0], w_ratio) - bottom = 
linear_interpolate(*pixels[1], w_ratio) - value = linear_interpolate(top, bottom, h_ratio) + top = linear_interpolate(*pixels[0], w_ratio, dtype) + bottom = linear_interpolate(*pixels[1], w_ratio, dtype) + value = linear_interpolate(top, bottom, h_ratio, dtype) elif method == 'cubic': h_int = cast(prim.floor(h), 'int32') w_int = cast(prim.floor(w), 'int32') diff --git a/python/hidet/graph/ops/normalize/__init__.py b/python/hidet/graph/ops/normalize/__init__.py index 29b24ce29..a3b05d9be 100644 --- a/python/hidet/graph/ops/normalize/__init__.py +++ b/python/hidet/graph/ops/normalize/__init__.py @@ -10,4 +10,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from .layers import batch_norm_infer, layer_norm, instance_norm, group_norm +from .lp import lp_norm from . import resolve diff --git a/python/hidet/graph/ops/normalize/lp.py b/python/hidet/graph/ops/normalize/lp.py new file mode 100644 index 000000000..e0fe68d13 --- /dev/null +++ b/python/hidet/graph/ops/normalize/lp.py @@ -0,0 +1,98 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir import primitives as prim +from hidet.ir.compute import reduce +from hidet.ir.expr import cast +from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode +from hidet.graph.ops.utils import compute, input_like, normalize_dim + + +class LpNormTask(Task): + """ + Performs LP normalization along specified dimension. 
+    For a tensor input of shape [n0, n1, ..., ndim, ..., nk], each
+    ndim-element vector v along dimension dim is transformed as:
+    v = v / max(pow(sum(pow(abs(v), p)), 1 / p), eps)
+    """
+
+    def __init__(self, x: TensorNode, p: float, dim: int, eps: float):
+        x_shape = x.const_shape
+        y_shape = x_shape
+        dtype = x.type.dtype
+        reduce_shape = []
+        other_shape = []
+        for idx, size in enumerate(x_shape):
+            if idx == dim:
+                reduce_shape.append(size)
+            else:
+                other_shape.append(size)
+
+        def sum_compute(*indices):
+            def sum_reduce(*reduction_axis):
+                x_indices = []
+                out_idx = 0  # do not name this `p`: that would shadow the norm exponent below
+                red_idx = 0
+                for i in range(len(x.shape)):
+                    if i != dim:
+                        x_indices.append(indices[out_idx])
+                        out_idx += 1
+                    else:
+                        x_indices.append(reduction_axis[red_idx])
+                        red_idx += 1
+                assert out_idx == len(indices) and red_idx == len(reduction_axis)
+                # Force fp32 reduction for accuracy
+                return prim.pow(cast(prim.abs(x[x_indices]), 'float32'), p)
+
+            return reduce(shape=reduce_shape, fcompute=sum_reduce, reduce_type='sum')
+
+        sum_ = compute(name='sum', shape=other_shape, fcompute=sum_compute)
+
+        p_norm = compute(name='p_norm', shape=other_shape, fcompute=lambda *indices: prim.pow(sum_[indices], 1.0 / p))
+
+        def y_compute(*indices):
+            norm_indices = [index for i, index in enumerate(indices) if i != dim]
+            return cast(x[indices] / prim.max(p_norm[norm_indices], eps), dtype)
+
+        y = compute(name='y', shape=y_shape, fcompute=y_compute)
+
+        super().__init__(name='lp_norm', inputs=[x], outputs=[y], attributes={'p': p, 'dim': dim, 'eps': eps})
+
+
+class LpNormOp(Operator):
+    def __init__(self, x: Tensor, p: float, dim: int, eps: float):
+        super().__init__(
+            inputs=[x], attributes={'p': p, 'dim': dim, 'eps': eps}, task=LpNormTask(input_like(x, 'x'), p, dim, eps)
+        )
+
+
+def lp_norm(x: Tensor, p=2.0, dim=1, eps=1e-12):
+    """Lp normalization.
+
+    Parameters
+    ----------
+    x: Tensor
+        The data to be normalized.
+    p: float
+        The exponent value in the norm formulation.
+    dim: int
+        The dimension to reduce.
+    eps: float
+        Small value to avoid division by zero.
+
+    Returns
+    -------
+    ret: Tensor
+        The normalized tensor.
+ """ + # Normalize dim + dim = normalize_dim(dim, rank=len(x.shape)) + return LpNormOp(x, p, dim, eps).outputs[0] diff --git a/python/hidet/graph/ops/pool.py b/python/hidet/graph/ops/pool.py index 4e34e281c..42544fbf0 100644 --- a/python/hidet/graph/ops/pool.py +++ b/python/hidet/graph/ops/pool.py @@ -19,14 +19,18 @@ class Pool2dTask(Task): - def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str): + def __init__(self, x: TensorNode, kernel, strides, padding, ceil_mode: bool, reduce_type: str): assert reduce_type in ['max', 'avg'] kernel = normalize_kernel(kernel) strides = normalize_stride(strides) padding = normalize_padding(padding) batch_size, channels, height, width = x.shape - out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 + if ceil_mode: + out_height = (height + padding[0] + padding[2] - kernel[0] + strides[0] - 1) // strides[0] + 1 + out_width = (width + padding[1] + padding[3] - kernel[1] + strides[1] - 1) // strides[1] + 1 + else: + out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 + out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.type.dtype) pad = compute( name='pad', @@ -145,11 +149,12 @@ def __init__( kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], + ceil_mode: bool, ): super().__init__( inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='max'), + attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, + task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='max'), ) @@ -175,11 +180,12 @@ def __init__( kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], + ceil_mode: bool, ): super().__init__( inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='avg'), + attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, + task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='avg'), ) @@ -245,16 +251,16 @@ def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): super().__init__(x, output_size, reduce_type='max', attrs={'output_size': output_size}, spatial_ndim=3) -def max_pool2d(x: Tensor, kernel, stride, padding) -> Tensor: - return MaxPool2dOp(x, kernel, stride, padding).outputs[0] +def max_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] def max_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: return MaxPool3dOp(x, kernel, stride, padding).outputs[0] -def avg_pool2d(x: Tensor, kernel, stride, padding) -> Tensor: - return AvgPool2dOp(x, kernel, stride, padding).outputs[0] +def avg_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return AvgPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] def avg_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: diff --git a/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py b/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py index d097dcb97..e8aa18d4d 100644 --- 
a/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py +++ b/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py @@ -125,7 +125,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]: x, w, stride=op1.attrs['stride'], - dilations=op1.attrs['dilation'], + dilations=op1.attrs['dilations'], groups=1, padding=op1.attrs['padding'], ) diff --git a/python/hidet/ir/primitives/cuda/math/float16.py b/python/hidet/ir/primitives/cuda/math/float16.py index 344875cfd..ed64d0a1c 100644 --- a/python/hidet/ir/primitives/cuda/math/float16.py +++ b/python/hidet/ir/primitives/cuda/math/float16.py @@ -62,6 +62,7 @@ class CUDAFloat16MathFunctionSet(MathFunctionSet): # pylint: disable=abstract-method def register(self): entries = { + 'abs': ['__habs', 1], 'sin': ['hsin', 1], 'cos': ['hcos', 1], 'exp': ['hexp', 1], @@ -142,6 +143,9 @@ def call(self, name: str, *args) -> Expr: entry = primitive_func_pool.lookup_by_name(name) return entry.var(*args) + def abs(self, a: Expr) -> Expr: + return self.call('cuda_f16_abs', a) + def sin(self, a: Expr) -> Expr: return self.call('cuda_f16_sin', a) From d4eadcc932fed38f217512c574bcb61b1d78461e Mon Sep 17 00:00:00 2001 From: Hanjie <50634613+hjjq@users.noreply.github.com> Date: Wed, 16 Aug 2023 11:22:45 -0400 Subject: [PATCH 5/6] [Operator] Add einsum (#349) Add an ad-hoc implementation of einsum based on pattern matching. Only supports batched matmul. --- .../frontend/torch/register_functions.py | 5 ++ python/hidet/graph/ops/__init__.py | 1 + python/hidet/graph/ops/linear.py | 59 +++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 python/hidet/graph/ops/linear.py diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 27958f4b6..0eddb6726 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -1264,3 +1264,8 @@ def torch_clone(x: Tensor, *, memory_format=torch.preserve_format): @register_function(torch.chunk) def torch_chunk(x: Tensor, chunks: int, dim: int = 0): return ops.split(x, parts_or_sections=chunks, axis=dim) + + +@register_function(torch.einsum) +def torch_einsum(equation, *operands): + return ops.einsum(equation, operands) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 64f539725..8a2fcd5f3 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -47,5 +47,6 @@ from .transfer import transfer from .special import barrier from .distributed import all_reduce, all_gather, reduce_scatter +from .linear import einsum from . import utils diff --git a/python/hidet/graph/ops/linear.py b/python/hidet/graph/ops/linear.py new file mode 100644 index 000000000..99e9db893 --- /dev/null +++ b/python/hidet/graph/ops/linear.py @@ -0,0 +1,59 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Sequence +from .utils import Tensor +from .matmul import matmul + +# ToDo: Actually fully implement einsum, supporting same usage as Numpy and Torch + +# Do ad-hoc pattern matching: only support simple cases such as matrix multiply +def einsum(equation: str, operands: Sequence[Tensor]): + if '...' in equation: + raise NotImplementedError('einsum currently does not support ellipsis') + if len(operands) != 2: + raise NotImplementedError('einsum currently only supports 2 operands') + + a = operands[0] + b = operands[1] + equation = equation.replace(' ', '') + lhs, rhs = equation.split('->') + a_subs, b_subs = lhs.split(',') + + if len(rhs) != len(a_subs) or len(a_subs) != len(b_subs): + raise NotImplementedError('einsum currently only supports inputs and output of same rank') + + a_batch, a_dims = a_subs[:-2], a_subs[-2:] + b_batch, b_dims = b_subs[:-2], b_subs[-2:] + c_batch, c_dims = rhs[:-2], rhs[-2:] + + if a_batch != b_batch or a_batch != c_batch: + raise NotImplementedError('einsum currently only supports batched matmul') + + if a_dims[1] == b_dims[0]: + c = matmul(a, b) + elif a_dims[1] == b_dims[1]: + c = matmul(a, b.transpose(-1, -2)) + elif a_dims[0] == b_dims[0]: + c = matmul(a.transpose(-1, -2), b) + elif a_dims[0] == b_dims[1]: + c = matmul(a.transpose(-1, -2), b.transpose(-1, -2)) + else: + raise NotImplementedError('einsum currently only supports batched matmul') + + transpose_c = (c_dims[0] == b_dims[0] or c_dims[0] == b_dims[1]) and ( + c_dims[1] == a_dims[0] or c_dims[1] == a_dims[1] + ) + + if transpose_c: + return c.transpose(-1, -2) + else: + return c From c110fc39043c94261a549e53b9cdbd0e5cbffa92 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Sun, 20 Aug 2023 22:51:46 -0400 Subject: [PATCH 6/6] add compile lock to build_task, use context to manage imap so hidet build can work with spawned processes --- python/hidet/drivers/build_task.py | 63 +++++++++++++++--------------- python/hidet/drivers/utils.py | 26 ++++++++++++ python/hidet/utils/multiprocess.py | 6 ++- tests/distributed/test_runtime.py | 27 +++++++++++++ 4 files changed, 89 insertions(+), 33 deletions(-) diff --git a/python/hidet/drivers/build_task.py b/python/hidet/drivers/build_task.py index 8aa0cd2a6..4d44de5f5 100644 --- a/python/hidet/drivers/build_task.py +++ b/python/hidet/drivers/build_task.py @@ -24,7 +24,7 @@ from hidet.ir.module import IRModule from hidet.ir.task import Task from hidet.drivers.build_module import build_ir_module, build_ir_module_batch -from hidet.drivers.utils import lazy_initialize_cuda +from hidet.drivers.utils import lazy_initialize_cuda, CompileLock from hidet.runtime.compiled_module import compiled_module_exists from hidet.runtime.compiled_task import CompiledTask, TensorSignature, load_compiled_task, compiled_task_cache from hidet.runtime.device import Device @@ -250,42 +250,43 @@ def build_task(task: Task, target='cuda', load=True) -> Optional[CompiledTask]: lib_path = os.path.join(task_dir, 'lib.so') version_path = os.path.join(task_dir, 'version.txt') - version_matched = False - if os.path.exists(version_path): - with open(version_path, 'r') as f: - version = f.read() - if version.strip() == hidet.__version__: - version_matched = True - - # use previously generated library when available - if use_cache and version_matched and compiled_module_exists(task_dir): - logger.debug(f"Load cached task binary {green(task.name)} from path: \n{cyan(lib_path)}") - if load: - compiled_task = load_compiled_task(task_dir) - compiled_task_cache.add(target, space_level, 
task_string, compiled_task) - else: - logger.info(f"Compiling {target} task {green(task.signature())}...") + with CompileLock(os.path.join(task_dir, "compile.lock"), enabled=use_cache): + version_matched = False + if os.path.exists(version_path): + with open(version_path, 'r') as f: + version = f.read() + if version.strip() == hidet.__version__: + version_matched = True + + # use previously generated library when available + if use_cache and version_matched and compiled_module_exists(task_dir): + logger.debug(f"Load cached task binary {green(task.name)} from path: \n{cyan(lib_path)}") + if load: + compiled_task = load_compiled_task(task_dir) + compiled_task_cache.add(target, space_level, task_string, compiled_task) + else: + logger.info(f"Compiling {target} task {green(task.signature())}...") - # build from scratch - os.makedirs(task_dir, exist_ok=True) + # build from scratch + os.makedirs(task_dir, exist_ok=True) - # write task - with open(os.path.join(task_dir, 'task.txt'), 'w') as f: - f.write(task_string) + # write task + with open(os.path.join(task_dir, 'task.txt'), 'w') as f: + f.write(task_string) - # write version - with open(version_path, 'w') as f: - f.write(hidet.__version__) + # write version + with open(version_path, 'w') as f: + f.write(hidet.__version__) - # implement task to IRModule, each task may produce multiple IRModules (candidates) - # they have the same functionality but different performance - candidates = task.implement(target=target, working_dir=task_dir) + # implement task to IRModule, each task may produce multiple IRModules (candidates) + # they have the same functionality but different performance + candidates = task.implement(target=target, working_dir=task_dir) - # generate meta data - generate_meta_data(task, task_dir, target, len(candidates)) + # generate meta data + generate_meta_data(task, task_dir, target, len(candidates)) - # construct the ir module for the task - build_task_module(task, candidates, task_dir, target) + # construct the ir module for the task + build_task_module(task, candidates, task_dir, target) if load: compiled_task = load_compiled_task(task_dir) diff --git a/python/hidet/drivers/utils.py b/python/hidet/drivers/utils.py index 167fc9456..661c61af5 100644 --- a/python/hidet/drivers/utils.py +++ b/python/hidet/drivers/utils.py @@ -10,6 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import os +from filelock import FileLock import hidet.cuda @@ -36,3 +38,27 @@ def lazy_initialize_cuda(): for i in range(hidet.cuda.device_count()): hidet.cuda.properties(i) hidet.cuda.compute_capability(i) + + +class CompileLock: + # There are cases where multiple instances of the same task or IRModule will be built at the same time, especially + # in the cases of distributed inference. We use this lock to make sure that a single task is not built multiple + # times. 
+ + def __init__(self, lock_path: str, enabled=True): + self.enabled = enabled + self.lock_path = lock_path + self.lock = None + + def __enter__(self): + if not self.enabled: + return None + + os.makedirs(os.path.dirname(self.lock_path), exist_ok=True) + self.lock = FileLock(self.lock_path) + self.lock.acquire() + return self.lock + + def __exit__(self, exc_type, exc_value, exc_tb): + if self.lock: + self.lock.release() diff --git a/python/hidet/utils/multiprocess.py b/python/hidet/utils/multiprocess.py index fd79d506d..b30c343b5 100644 --- a/python/hidet/utils/multiprocess.py +++ b/python/hidet/utils/multiprocess.py @@ -48,7 +48,8 @@ def parallel_imap(func: Callable, jobs: Sequence[Any], num_workers: Optional[int if num_workers is None: num_workers = os.cpu_count() - with multiprocessing.Pool(num_workers) as pool: + ctx = multiprocessing.get_context('fork') + with ctx.Pool(num_workers) as pool: yield from pool.imap(_wrapped_func, range(len(jobs))) _job_queue = None @@ -65,7 +66,8 @@ def parallel_map(func: Callable, jobs: Sequence[Any], num_workers: Optional[int] if num_workers is None: num_workers = os.cpu_count() - with multiprocessing.Pool(num_workers) as pool: + ctx = multiprocessing.get_context('fork') + with ctx.Pool(num_workers) as pool: ret = pool.map(_wrapped_func, range(len(jobs))) _job_queue = None diff --git a/tests/distributed/test_runtime.py b/tests/distributed/test_runtime.py index 7ce21e773..4bc9656d8 100644 --- a/tests/distributed/test_runtime.py +++ b/tests/distributed/test_runtime.py @@ -163,6 +163,33 @@ def test_send_recv(rank): assert numpy.array_equal(x.cpu().numpy(), [[1, 2], [3, 4]]) +def build_job(): + import hidet + + a = hidet.symbol([1, 2, 4], dtype='float16', device='cuda') + b = hidet.symbol([1, 4, 8], dtype='float16', device='cuda') + c = hidet.ops.matmul(a, b) + d = 2 * (c + c).reshape([1, 16]) + + flow_graph = hidet.trace_from(d, inputs=[a, b]) + flow_graph = hidet.graph.optimize(flow_graph) + compiled_graph = flow_graph.build() + + +@pytest.mark.parametrize("world_size", [2]) +def test_parallel_build(world_size): + import multiprocessing as mp + + ctx = mp.get_context('spawn') + processes = [ctx.Process(target=build_job) for _ in range(world_size)] + + for p in processes: + p.start() + for p in processes: + p.join(timeout=100) + assert p.exitcode == 0 + + if __name__ == '__main__': test_all_reduce() test_broadcast()
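
A quick sanity check of the ceil_mode output-size arithmetic used by Pool2dTask in this series. This is plain Python integer arithmetic; the concrete sizes (H=8, kernel=3, stride=2, no padding) are illustrative only:

    # ceil_mode changes only the rounding of the output extent; a minimal sketch
    # of the two formulas used in Pool2dTask.
    H, pad_before, pad_after, K, S = 8, 0, 0, 3, 2
    floor_out = (H + pad_before + pad_after - K) // S + 1            # 3 outputs
    ceil_out = (H + pad_before + pad_after - K + S - 1) // S + 1     # 4 outputs (last window is partially padded)
    assert (floor_out, ceil_out) == (3, 4)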
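
A minimal usage sketch of the new lp_norm operator, assuming a CUDA-capable hidet build; the tensor shape and device are illustrative:

    import hidet

    x = hidet.randn([4, 8], device='cuda')
    # Divide each row x[i, :] by max(||x[i, :]||_2, 1e-12), matching the
    # torch.nn.functional.normalize defaults mapped by torch_normalize in this series.
    y = hidet.ops.lp_norm(x, p=2.0, dim=1)
    print(y.shape)  # same shape as x: 4 x 8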
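
A minimal sketch of the ad-hoc einsum, again assuming a CUDA build and illustrative shapes. Only batched-matmul-style equations are matched: other operand orderings are handled by transposing inputs/outputs, and anything else raises NotImplementedError:

    import hidet

    a = hidet.randn([2, 3, 4], device='cuda')
    b = hidet.randn([2, 4, 5], device='cuda')
    # 'bij,bjk->bik' is recognized as a plain batched matmul and lowered to matmul(a, b)
    c = hidet.ops.einsum('bij,bjk->bik', [a, b])
    print(c.shape)  # batch=2, m=3, n=5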