diff --git a/python/hidet/graph/frontend/torch/interpreter.py b/python/hidet/graph/frontend/torch/interpreter.py index 121d837dd..29efb0d59 100644 --- a/python/hidet/graph/frontend/torch/interpreter.py +++ b/python/hidet/graph/frontend/torch/interpreter.py @@ -359,6 +359,10 @@ def load_arg(a, env): hidet_kwargs = load_arg(node.kwargs, hidet_env) try: hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs) + from .register_functions import setitem + + if exec_func.functions[0] is setitem: + hidet_env[str(node.args[0])] = hidet_env[node.name] except Exception as e: self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs) elif node.op == "call_method": @@ -448,6 +452,10 @@ def load_arg(a, env): try: hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs) + from .register_functions import setitem + + if hidet_func.functions[0] is setitem: + hidet_env[str(node.args[0])] = hidet_env[node.name] except Exception as e: self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 250a3deb5..0eddb6726 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -18,14 +18,13 @@ from hidet.graph import ops from hidet.utils import same_list from hidet.ir.type import DataType -from hidet.ir.expr import Expr from hidet.ir import expr from hidet.ir.dtypes import promote_type -from hidet.ir.expr import Int +from hidet.ir.expr import Expr, Int, is_constant from hidet.runtime.device import Device from .interpreter import register_function, register_method from .interpreter import warnings -from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar +from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar, convert_to_scalar_if_possible Number = Union[int, float, bool] @@ -97,6 +96,11 @@ def adaptive_avg_pool2d(x: Tensor, output_size): return ops.adaptive_avg_pool2d(x, output_size) +@register_function(torch.nn.functional.adaptive_avg_pool3d) +def adaptive_avg_pool3d(x: Tensor, output_size): + return ops.adaptive_avg_pool3d(x, output_size) + + @register_function(torch.nn.functional.relu) def relu(x: Tensor, inplace: bool): # if inplace: @@ -109,11 +113,9 @@ def relu(x: Tensor, inplace: bool): def max_pool2d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=False, return_indices=False): if dilation != 1 and not same_list(dilation, [1, 1]): raise NotImplementedError("dilation != 1") - if ceil_mode: - raise NotImplementedError("ceil_mode=True") if return_indices: raise NotImplementedError("return_indices=True") - y = ops.max_pool2d(x, kernel_size, stride, padding) + y = ops.max_pool2d(x, kernel_size, stride, padding, ceil_mode=ceil_mode) return y @@ -130,7 +132,9 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode= @register_function(torch.nn.functional.linear) -def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]): +def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False): + if len(weight.shape) > 1 and not weight_is_transposed: + weight = ops.transpose(weight, [1, 0]) y = ops.matmul(x, weight) if bias is not None: y = y + bias return y @@ -205,10 +209,18 @@ def batch_norm( ) y = ops.batch_norm_infer(x, running_mean, running_var, epsilon=eps) _ = momentum # unused + if len(x.shape) == 3: + dims = [0, 2] + elif len(x.shape) == 4: + dims = [0, 2, 3] + elif 
len(x.shape) == 5: + dims = [0, 2, 3, 4] + else: + raise NotImplementedError("batch_norm only accepts 3D, 4D, 5D input") if weight is not None: - y = y * weight.unsqueeze([0, 2, 3]) + y = y * weight.unsqueeze(dims) if bias is not None: - y = y + bias.unsqueeze([0, 2, 3]) + y = y + bias.unsqueeze(dims) return y @@ -222,6 +234,70 @@ def getitem(x: Tensor, index): return x[index] +@register_function(operator.setitem) +def setitem(x: Tensor, item, setvalue): + + if isinstance(item, list): + item = tuple(item) + if not isinstance(item, tuple): + item = tuple([item]) + + if not isinstance(setvalue, (int, float)): + raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values') + + # now, the item could have + # 1. integer index + # 2. slice + # 3. Ellipsis + # 4. None + # e.g., [1, 3:5, ..., None] + + # process Ellipsis + # e.g., x[1, ..., 2] -> x[1, :, :, 2] + if Ellipsis in item: + if item.count(Ellipsis) > 1: + raise ValueError('Only one ellipsis allowed in index.') + ellipsis_index = item.index(Ellipsis) + ellipsis_ndim = len(x.shape) - sum([1 if axis not in [None, Ellipsis] else 0 for axis in item]) + ellipsis_ndim = max(ellipsis_ndim, 0) + item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :] + + # normalize index + normalized_item = [] + for i, v in enumerate(item): + if isinstance(v, int): + if v < 0: + v = v + x.shape[i] + if is_constant(v, x.shape[i]) and (v < 0 or v >= x.shape[i]): + raise IndexError('index {} is out of bound for dimension {} with size {}'.format(v, i, x.shape[i])) + normalized_item.append(v) + elif v is not None: + # None affects getitem, but is ignored in setitem + normalized_item.append(v) + item = tuple(normalized_item) + + # process slice and integer index + rank = len(x.shape) + while len(item) < rank: + item = item + (slice(None),) + starts, ends, steps = [], [], [] + squeeze_dims = [] + for dim, v in enumerate(item): + if isinstance(v, (int, Expr)): + squeeze_dims.append(dim) + starts.append(v) + ends.append(v + 1) + steps.append(1) + else: + assert isinstance(v, slice) + starts.append(v.start) + ends.append(v.stop) + steps.append(v.step) + + out = ops.set_strided_slice(x, starts, ends, steps, setvalue) + return out + + @register_function(operator.mul) @register_function(torch.mul) @register_function(torch.ops.aten.mul.Tensor) @@ -516,6 +592,7 @@ def permute(x: Tensor, *args): return ops.transpose(x, dims) +@register_function(torch.swapaxes) @register_function(torch.transpose) @register_method(torch.Tensor.transpose) def transpose(x: Tensor, dim0: int, dim1: int): @@ -590,7 +667,7 @@ def addmm( @register_function(torch.where) -def where(condition: Tensor, x: Tensor, y: Tensor): +def where(condition: Tensor, x: Union[Tensor, Number], y: Union[Tensor, Number]): return ops.where(cond=condition, x=x, y=y) @@ -697,6 +774,7 @@ def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor: @register_function(torch.exp) +@register_method(torch.Tensor.exp) def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor: if out is not None: warnings.warn_once("hidet: does not support torch.exp(..., out=...)") @@ -931,6 +1009,13 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor @register_function(operator.eq) def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor: + if isinstance(a, Tensor) or isinstance(b, Tensor): + from hidet.graph.ops.utils import convert_to_tensor + + if isinstance(a, Tensor): + return ops.equal(a, convert_to_tensor(b, a)) 
+ else: + return ops.equal(b, convert_to_tensor(a, b)) return a == b @@ -981,6 +1066,28 @@ def torch_mean( return output +@register_function(torch.sum) +@register_method(torch.Tensor.sum) +def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor: + if dtype: + x = x.astype(dtype_from_torch(dtype)) + output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True) + return output + + +@register_function(torch.sum) +@register_method(torch.Tensor.sum) +def torch_sum( + x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None +) -> Tensor: + if out is not None: + raise NotImplementedError("hidet: does not support torch.sum(..., out=...)") + if dtype: + x = x.astype(dtype_from_torch(dtype)) + output = ops.sum(x, dims=dim, keep_dim=keepdim) + return output + + @register_function(torch.cumsum) def torch_cumsum(x: Tensor, dim, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None) -> Tensor: if out is not None: @@ -1069,3 +1176,96 @@ def zeros_like( hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype else x.dtype return ops.full(x.shape, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero) + + +@register_function(torch.clamp) +def clamp( + x: Tensor, + min: Optional[Union[Tensor, Number]] = None, + max: Optional[Union[Tensor, Number]] = None, + *, + out: Optional[Tensor] = None, +) -> Tensor: + if out is not None: + raise NotImplementedError("hidet: does not support torch.clamp(..., out=...)") + + min = convert_to_scalar_if_possible(min) + max = convert_to_scalar_if_possible(max) + + if min is None and max is None: + return x + elif min is None: + if not isinstance(max, Tensor): + assert isinstance(max, (int, float, complex)) + max = ops.full([], value=max, dtype=x.dtype, device=x.device) + return ops.minimum(x, max) + elif max is None: + if not isinstance(min, Tensor): + assert isinstance(min, (int, float, complex)) + min = ops.full([], value=min, dtype=x.dtype, device=x.device) + return ops.maximum(x, min) + else: + return ops.clamp(x, min, max) + + +@register_function(torch.isinf) +def isinf(x: Tensor) -> Tensor: + return ops.isinf(x) + + +@register_function(torch.nn.functional.pad) +def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0): + if isinstance(pad, tuple): + pad = list(pad) + # Torch's pad list has form [p2left, p2right, p1left, p1right, p0left, p0right] + # Hidet's pad list has form [p0left, p1left, p2left, p0right, p1right, p2right] + left = [] + right = [] + for i, p in enumerate(pad): + if i % 2 == 0: + left.append(p) + else: + right.append(p) + left.reverse() + right.reverse() + pad = [] + for p in left: + pad.append(p) + for p in right: + pad.append(p) + return ops.pad(x, pads=pad, mode=mode, value=value) + + +@register_function(torch.roll) +def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None): + return ops.roll(x, shifts, dims) + + +@register_function(torch.nn.functional.normalize) +def torch_normalize(x: Tensor, p=2.0, dim=1, eps=1e-12, out=None): + if out is not None: + raise NotImplementedError("out is not None") + return ops.lp_norm(x, p, dim, eps) + + +@register_function(torch.clone) +@register_method(torch.Tensor.clone) +def torch_clone(x: Tensor, *, memory_format=torch.preserve_format): + if memory_format is not torch.preserve_format: + warnings.warn_once( + "torch.clone got memory_format not torch.preserve_format, treating it as torch.preserve_format" + ) + if x.is_symbolic(): + return x + 
else: + return x.copy() + + +@register_function(torch.chunk) +def torch_chunk(x: Tensor, chunks: int, dim: int = 0): + return ops.split(x, parts_or_sections=chunks, axis=dim) + + +@register_function(torch.einsum) +def torch_einsum(equation, *operands): + return ops.einsum(equation, operands) diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index 07b25721a..a9b5c9d21 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -90,6 +90,11 @@ def tensor_to(self: Tensor, *args, **kwargs) -> Tensor: if self.is_symbolic() and instantiate_device(device_from_torch(arg)) != self.device: raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.') device = arg + elif isinstance(arg, Tensor): + dtype = arg.dtype + if self.is_symbolic() and arg.device != self.device: + raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.') + device = arg.device else: raise ValueError(f'Unsupported argument type: {type(arg)}') @@ -222,7 +227,10 @@ def tensor_type(self: Tensor, dtype: Union[str, torch.dtype], non_blocking: bool @register_method(torch.Tensor.expand) def tensor_expand(self: Tensor, *sizes: int) -> Tensor: - sizes: List[int] = list(sizes) + if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)): + sizes = sizes[0] + else: + sizes: List[int] = list(sizes) assert len(sizes) >= len(self.shape) for i in range(len(sizes)): if sizes[i] == -1: @@ -255,3 +263,42 @@ def tensor_repeat(self: Tensor, *sizes: int) -> Tensor: @register_method(torch.Tensor.detach) def tensor_detach(self: Tensor) -> Tensor: return self + + +@register_method(torch.Tensor.any) +def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor: + return ops.any(self, axis=dim, keepdims=keepdim) + + +@register_method(torch.Tensor.all) +def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor: + return ops.all(self, axis=dim, keepdims=keepdim) + + +@register_method(torch.Tensor.matmul) +def tensor_matmul(self: Tensor, other: Tensor) -> Tensor: + return ops.matmul(self, other) + + +@register_method(torch.Tensor.new_zeros) +def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False): + if layout is not None: + raise NotImplementedError("layout is not None") + if len(size) == 1: + if isinstance(size[0], (list, tuple)): + size = size[0] + shape = size + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + _ = pin_memory + _ = requires_grad + + return ops.full(shape, dtype=dtype, device=device, value=dtype.zero) + + +@register_method(torch.Tensor.zero_) +def tensor_zero_(self: Tensor): + return ops.full(self.shape, dtype=self.dtype, device=self.device, value=self.dtype.zero) diff --git a/python/hidet/graph/frontend/torch/register_modules.py b/python/hidet/graph/frontend/torch/register_modules.py index c4c58774f..df3055216 100644 --- a/python/hidet/graph/frontend/torch/register_modules.py +++ b/python/hidet/graph/frontend/torch/register_modules.py @@ -11,6 +11,7 @@ # limitations under the License. from __future__ import annotations import torch +from hidet.graph import ops from hidet.graph.tensor import Tensor from .interpreter import HidetModule, register_module from . 
import register_functions as regs @@ -117,6 +118,13 @@ def __call__(self, x: Tensor) -> Tensor: return regs.adaptive_avg_pool2d(x, self.mod.output_size) +@register_module(torch.nn.AdaptiveAvgPool3d) +class HidetAdaptiveAvgPool3d(HidetModule): + def __call__(self, x: Tensor) -> Tensor: + assert isinstance(self.mod, torch.nn.AdaptiveAvgPool3d) + return regs.adaptive_avg_pool3d(x, self.mod.output_size) + + @register_module(torch.nn.ReLU) class HidetReLU(HidetModule): def __call__(self, x: Tensor) -> Tensor: @@ -158,21 +166,21 @@ def __call__(self, x: Tensor) -> Tensor: class HidetLinear(HidetModule): def __init__(self, torch_module: torch.nn.Module): super().__init__(torch_module) - from hidet import ops - steal = dynamo_config['steal_weights'] - self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0]) def __call__(self, x: Tensor) -> Tensor: assert isinstance(self.mod, torch.nn.Linear) - return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True)) + return regs.linear( + x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True + ) @register_module(torch.nn.BatchNorm2d) +@register_module(torch.nn.BatchNorm3d) class HidetBatchNorm2d(HidetModule): def __call__(self, x: Tensor) -> Tensor: - assert isinstance(self.mod, torch.nn.BatchNorm2d) + assert isinstance(self.mod, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)) return regs.batch_norm( x=x, running_mean=self.param('running_mean'), @@ -404,3 +412,70 @@ def __call__(self, x: Tensor) -> Tensor: align_corners=self.mod.align_corners, recompute_scale_factor=self.mod.recompute_scale_factor, ) + + +@register_module(torch.nn.MultiheadAttention) +class HidetMultiheadAttention(HidetModule): + def __init__(self, torch_module: torch.nn.Module): + super().__init__(torch_module) + steal = dynamo_config['steal_weights'] + self.in_proj_weight_transposed = ops.transpose(self.param('in_proj_weight', steal=steal), [1, 0]) + self.out_proj_weight_transposed = ops.transpose(self.param('out_proj.weight', steal=steal), [1, 0]) + + def __call__( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask=None, + need_weights=True, + attn_mask=None, + average_attn_weights=True, + is_causal=False, + ) -> Tensor: + assert isinstance(self.mod, torch.nn.MultiheadAttention) + supported = ( + self.mod._qkv_same_embed_dim + and self.mod.bias_k is None + and self.mod.bias_v is None + and not self.mod.add_zero_attn + and self.mod.batch_first + and key_padding_mask is None + and not need_weights + ) + if not supported: + raise NotImplementedError( + "Hidet Multihead Attention currently only supports " + "kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False, " + "batch_first=True, forward(key_padding_mask=None, need_weights=False)." 
+ ) + + # Input feed forward + wq, wk, wv = ops.split(self.in_proj_weight_transposed, parts_or_sections=3, axis=1) + query = ops.matmul(query, wq) + key = ops.matmul(key, wk) + value = ops.matmul(value, wv) + if self.mod.in_proj_bias is not None: + bq, bk, bv = ops.split(self.param('in_proj_bias'), parts_or_sections=3, axis=0) + query = ops.add(query, bq) + key = ops.add(key, bk) + value = ops.add(value, bv) + + # Split heads + split_head_dims = [query.shape[0], query.shape[1], self.mod.num_heads, query.shape[2] // self.mod.num_heads] + query = ops.transpose(query.reshape(split_head_dims), [0, 2, 1, 3]) + key = ops.transpose(key.reshape(split_head_dims), [0, 2, 1, 3]) + value = ops.transpose(value.reshape(split_head_dims), [0, 2, 1, 3]) + + # fmha + out = regs.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=self.mod.dropout, is_causal=is_causal + ) + + # Output feed forward + merge_head_dims = [out.shape[0], out.shape[2], self.mod.embed_dim] + out = ops.transpose(out, [0, 2, 1, 3]).reshape(merge_head_dims) + out = ops.matmul(out, self.out_proj_weight_transposed) + if self.mod.out_proj.bias is not None: + out = ops.add(out, self.param('out_proj.bias')) + return out diff --git a/python/hidet/graph/frontend/torch/utils.py b/python/hidet/graph/frontend/torch/utils.py index 6d4bb72d7..9a2ef141f 100644 --- a/python/hidet/graph/frontend/torch/utils.py +++ b/python/hidet/graph/frontend/torch/utils.py @@ -285,3 +285,12 @@ def normalize_to_scalar(value: Union[Tensor, Expr, float, int, bool]) -> Union[E raise RuntimeError(f'Cannot convert tensor {value.signature()} to scalar') else: return value + + +def convert_to_scalar_if_possible(x: Union[Tensor, Expr, float, int, bool]) -> Optional[Union[Expr, float, int, bool]]: + if isinstance(x, Tensor): + if len(x.shape) == 0 and x.storage: + return x.item() + return None + else: + return x diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index f2644005e..8a2fcd5f3 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -13,15 +13,8 @@ from .matmul import batch_matmul, matmul, matmul_x86 from .conv1d import conv1d, conv1d_gemm from .conv1d_transpose import conv1d_transpose -from .conv2d import ( - conv2d, - conv2d_channel_last, - conv2d_winograd, - conv2d_gemm, - conv2d_gemm_fp16, - conv2d_gemm_fp16_channel_last, - conv2d_gemm_image_transform, -) +from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16 +from .conv2d import conv2d_gemm_fp16_channel_last, conv2d_gemm_image_transform from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm from .conv3d import conv3d, conv3d_gemm from .conv3d_transpose import conv3d_transpose @@ -31,15 +24,15 @@ from .activation import logsigmoid, celu, hardshrink, softplus, softsign, tanhshrink from .activation import softshrink, softmax, softmin, hardtanh from .attention import attention -from .normalize import batch_norm_infer, instance_norm, layer_norm, group_norm +from .normalize import batch_norm_infer, instance_norm, layer_norm, group_norm, lp_norm from .image import resize2d from .create import full, arange, linspace, tri from .arithmetic import add, subtract, multiply, divide, mod, remainder, negative, positive, square from .arithmetic import floor, ceil, round, trunc, sqrt, rsqrt, pow, abs from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, 
bitwise_invert, bitwise_or -from .arithmetic import bitwise_xor, maximum, minimum -from .arithmetic import isfinite, isinf, isnan, sign, where +from .arithmetic import bitwise_xor, maximum, minimum, clamp +from .arithmetic import isfinite, isinf, isnan, sign, where, set_strided_slice, roll from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2 from .complex import real, imag, conj, make_complex from .compare import equal, not_equal, less, greater, less_equal, greater_equal @@ -54,5 +47,6 @@ from .transfer import transfer from .special import barrier from .distributed import all_reduce, all_gather, reduce_scatter +from .linear import einsum from . import utils diff --git a/python/hidet/graph/ops/arithmetic.py b/python/hidet/graph/ops/arithmetic.py index 9acc397c3..7233dc197 100644 --- a/python/hidet/graph/ops/arithmetic.py +++ b/python/hidet/graph/ops/arithmetic.py @@ -10,16 +10,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # pylint: disable=redefined-builtin, unnecessary-lambda -from typing import List, Callable, Any, Union, Optional, Dict +from typing import List, Callable, Any, Union, Optional, Dict, Sequence from hidet.ir import primitives -from hidet.ir import expr, dtypes +from hidet.ir import Var, expr, dtypes from hidet.ir.type import DataType -from hidet.ir.expr import Expr, Var, if_then_else +from hidet.ir.expr import Expr, if_then_else, logical_or, is_constant, is_true from hidet.ir.tools import rewrite from hidet.utils import prod, same_list from .utils import Task, Operator, Tensor, TensorNode, InverseMap, compute, input_like from .utils import broadcast_shape, broadcast_shapes, broadcast_indices +from .utils import normalize_slice, normalize_dim + +PyScalar = Union[int, float, bool] + # In order for the subgraph rewrite of Composite Elementwise Operator to work, # we need to store the callable in an Operator object. 
But lambda cannot be pickled, @@ -117,7 +121,7 @@ def __init__(self, name: str, args: List[TensorNode], op: Callable[[Any], Any]): inverse_map={ v: InverseMap.identity(len(v_shape)) for v, v_shape in zip(args, shapes) - if prod(v_shape) == prod(out_shape) + if is_true(prod(v_shape) == prod(out_shape)) and len(v_shape) == len(out_shape) }, ) @@ -185,6 +189,78 @@ def __init__(self, cond: TensorNode, x: TensorNode, y: TensorNode): ) +class SetStridedSliceTask(Task): + def __init__( + self, + data: TensorNode, + starts: List[Optional[int]], + ends: List[Optional[int]], + axes: List[int], + strides: List[int], + setvalue: [Union[int, float]], + ): + assert len(starts) == len(ends) == len(axes) == len(strides) + if len(axes) != len(set(axes)): + raise ValueError('Duplicated axes in slice, axes: {}'.format(axes)) + output_shape = list(data.shape) + axis2info = {} + for axis, start, end, stride in zip(axes, starts, ends, strides): + if stride == 0: + raise NotImplementedError( + 'Stride can not be 0 in slicing: ' + 'starts {} ends {} axes {} strides {}.'.format(starts, ends, axes, strides) + ) + if is_constant(output_shape[axis]) and output_shape[axis] < 0: + raise NotImplementedError( + 'Slice result can not be: ' + 'starts {} ends {} axes {} strides {}'.format(starts, ends, axes, strides) + ) + axis2info[axis] = (start, end, stride) + + def fmap(indices): + ret = data.type.dtype(setvalue) + for axis, index in enumerate(indices): + start, end, stride = axis2info[axis] + ret = if_then_else( + logical_or(index < start, index >= end, (index - start) % stride != 0), data[indices], ret + ) + return ret + + out = compute('out', shape=output_shape, fcompute=lambda *indices: fmap(indices)) + super().__init__(name='set_slice', inputs=[data], outputs=[out]) + + +class RollTask(Task): + def __init__(self, x: TensorNode, shifts: Sequence[int], dims: Sequence[int]): + output_shape = list(x.shape) + + def fmap(indices): + data_indices = [] + for axis, index in enumerate(indices): + if axis in dims: + i = dims.index(axis) + if shifts[i] > 0: + data_indices.append( + if_then_else( + index - shifts[i] >= 0, index - shifts[i], index + output_shape[axis] - shifts[i] + ) + ) + else: + data_indices.append( + if_then_else( + index - shifts[i] < output_shape[axis], + index - shifts[i], + index - output_shape[axis] - shifts[i], + ) + ) + else: + data_indices.append(index) + return x[data_indices] + + out = compute('out', shape=output_shape, fcompute=lambda *indices: fmap(indices)) + super().__init__(name='roll', inputs=[x], outputs=[out]) + + class UnaryElementwiseOp(Operator): def __init__(self, x: Tensor, op, name: str, attributes: Optional[Dict[str, Any]] = None, task_attributes=None): if attributes is None: @@ -207,6 +283,15 @@ def __init__(self, x: Tensor, y: Tensor, op, name: str): ) +def get_dtype(scalar: Expr): + from hidet.ir.tools import infer_type + + inferred_type = infer_type(scalar) + if not isinstance(inferred_type, DataType): + raise TypeError(f'Expected scalar to be of type DataType, got {type(inferred_type)}') + return inferred_type + + class CompositeElementwiseOp(Operator): def __init__( self, @@ -238,37 +323,37 @@ def resolve_dtype(tensor_dtype: DataType, scalar_dtype: DataType) -> DataType: class AddScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v + dtype(scalar), attributes={'scalar': scalar}, name='adds') class 
SubScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v - dtype(scalar), attributes={'scalar': scalar}, name='subs') class RSubScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: dtype(scalar) - v, attributes={'scalar': scalar}, name='rsubs') class MultiplyScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v * dtype(scalar), attributes={'scalar': scalar}, name='muls') class DivideScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: v / dtype(scalar), attributes={'scalar': scalar}, name='divs') class RDivideScalarOp(UnaryElementwiseOp): def __init__(self, x: Tensor, scalar: Expr): - dtype = resolve_dtype(x.dtype, scalar.type) + dtype = resolve_dtype(x.dtype, get_dtype(scalar)) super().__init__(x, op=lambda v: dtype(scalar) / v, attributes={'scalar': scalar}, name='rdivs') @@ -478,6 +563,19 @@ def __init__(self, x: Tensor): ) +class ClampOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, min_value: Union[int, float], max_value: Union[int, float]): + assert isinstance(min_value, (int, float)) + assert isinstance(max_value, (int, float)) + min_value = x.dtype(min_value) + max_value = x.dtype(max_value) + super().__init__( + x, + op=lambda a: if_then_else(a < min_value, min_value, if_then_else(a > max_value, max_value, a)), + name='clamp', + ) + + class RightShiftOp(BinaryElementwiseOp): def __init__(self, x: Tensor, y: Tensor): super().__init__(x, y, op=lambda a, b: expr.RightShift(a, b), name='rightshift') @@ -522,6 +620,47 @@ def __init__(self, cond: Tensor, x: Tensor, y: Tensor): ) +class WhereScalarScalarOp(Operator): + def __init__(self, cond: Tensor, x: PyScalar, y: PyScalar): + if isinstance(x, int) and isinstance(y, int): + dtype = dtypes.default_int_dtype + elif isinstance(x, float) or isinstance(y, float): + dtype = dtypes.default_float_dtype + else: + raise ValueError(f'Unsupported scalar type: {type(x)}') + x, y = dtype(x), dtype(y) + super().__init__( + inputs=[cond], + attributes={'x': x, 'y': y}, + task=UnaryElementwiseTask(name='where', x=input_like(cond, 'cond'), op=lambda a: if_then_else(a, x, y)), + ) + + +class WhereScalarTensorOp(Operator): + def __init__(self, cond: Tensor, y: Tensor, x: PyScalar): + dtype = y.dtype + x = dtype(x) + super().__init__( + inputs=[cond, y], + attributes={'x': x}, + task=BinaryElementwiseTask( + name='where', x=input_like(cond, 'cond'), y=input_like(y, 'y'), op=lambda a, b: if_then_else(a, x, b) + ), + ) + + +class WhereTensorScalarOp(Operator): + def __init__(self, cond: Tensor, x: Tensor, y: PyScalar): + y = x.dtype(y) + super().__init__( + inputs=[cond, x], + attributes={'y': y}, + task=BinaryElementwiseTask( + name='where', x=input_like(cond, 'cond'), y=input_like(x, 'x'), op=lambda a, b: if_then_else(a, b, y) + ), + ) + + class MaxOp(Operator): def __init__(self, *tensors: Tensor): def scalar_max(args: List[expr.Expr]): @@ -560,6 +699,32 @@ def scalar_min(args: List[expr.Expr]): ) +class SetStridedSliceOp(Operator): + def 
__init__( + self, + data: Tensor, + starts: Sequence[Optional[int]], + ends: Sequence[Optional[int]], + strides: Optional[Sequence[Optional[int]]] = None, + setvalue: Optional[Union[int, float]] = 0.0, + ): + starts, ends, axes, strides = normalize_slice(data.shape, starts, ends, axes=None, strides=strides) + task = SetStridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides, setvalue) + super().__init__( + inputs=[data], + attributes={'starts': starts, 'ends': ends, 'strides': strides, 'setvalue': setvalue}, + task=task, + ) + + +class RollOp(Operator): + def __init__(self, x: Tensor, shifts: Sequence[int], dims: Sequence[int]): + if not len(shifts) == len(dims): + raise ValueError('Roll must have same size shifts and dims, got {} and {}'.format(len(shifts), len(dims))) + task = RollTask(input_like(x, 'x'), shifts, dims) + super().__init__(inputs=[x], attributes={'shifts': shifts, 'dims': dims}, task=task) + + Scalar = Union[Expr, float, int, complex] @@ -792,10 +957,25 @@ def sign(x: Tensor) -> Tensor: return SignOp(x).outputs[0] -def where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor: +def clamp(x: Tensor, min: Union[Tensor, float, int], max: Union[Tensor, float, int]) -> Tensor: + if isinstance(min, Tensor) or isinstance(max, Tensor): + raise NotImplementedError('clamp with tensor min/max is not implemented yet') + return ClampOp(x, min, max).outputs[0] + + +def where(cond: Tensor, x: Union[Tensor, PyScalar], y: Union[Tensor, PyScalar]) -> Tensor: if cond.dtype != dtypes.boolean: raise ValueError('The condition tensor must have dtype "bool", but got {}'.format(cond.dtype.name)) - return WhereOp(cond, x, y).outputs[0] + if isinstance(x, Tensor) and isinstance(y, Tensor): + return WhereOp(cond, x, y).outputs[0] + elif isinstance(x, Tensor) and isinstance(y, (int, float, complex)): + return WhereTensorScalarOp(cond, x=x, y=y).outputs[0] + elif isinstance(x, (int, float, complex)) and isinstance(y, Tensor): + return WhereScalarTensorOp(cond, x=x, y=y).outputs[0] + elif isinstance(x, (int, float, complex)) and isinstance(y, (int, float, complex)): + return WhereScalarScalarOp(cond, x=x, y=y).outputs[0] + else: + raise ValueError('Invalid arguments for where: x={}, y={}'.format(x, y)) def maximum(a: Tensor, b: Tensor, *others: Tensor) -> Tensor: @@ -812,7 +992,8 @@ def mod(x: Tensor, y: Tensor) -> Tensor: return ModOp(x, y).outputs[0] -remainder = mod +def remainder(x: Tensor, y: Tensor) -> Tensor: + return mod(x, y) def abs(x: Tensor) -> Tensor: @@ -864,6 +1045,20 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: return log(exp(x - max_val) + exp(y - max_val)) + max_val +def roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None) -> Tensor: + if isinstance(shifts, int): + shifts = [shifts] + if isinstance(dims, int): + dims = [dims] + if dims is None: + from .transform import flatten, reshape + + shape = x.shape + return reshape(RollOp(flatten(x), shifts, dims=[0]).outputs[0], shape) + dims = normalize_dim(dims, len(x.shape)) + return RollOp(x, shifts, dims).outputs[0] + + # out = binary_op(left_unary_op(x), right_unary_op(x)); This allows more fusion opportunity. 
def composite_elementwise( x: Tensor, @@ -872,3 +1067,13 @@ def composite_elementwise( binary_op: BinaryElementwiseOperation, ) -> Tensor: return CompositeElementwiseOp(x, left_unary_op, right_unary_op, binary_op).outputs[0] + + +def set_strided_slice( + data: Tensor, + starts: Sequence[Optional[int]], + ends: Sequence[Optional[int]], + strides: Optional[Sequence[Optional[int]]] = None, + setvalue: Optional[Union[int, float]] = 0.0, +) -> Tensor: + return SetStridedSliceOp(data, starts, ends, strides, setvalue).outputs[0] diff --git a/python/hidet/graph/ops/attention/attention.py b/python/hidet/graph/ops/attention/attention.py index 3e459b7d1..278deb49b 100644 --- a/python/hidet/graph/ops/attention/attention.py +++ b/python/hidet/graph/ops/attention/attention.py @@ -397,7 +397,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type): def copy_k_g2s_sm80( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j_seg in k_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -411,7 +411,7 @@ def copy_k_g2s_sm80( @hidet.script def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j_seg in v_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -421,7 +421,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j_seg in q_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -433,7 +433,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offs def copy_k_g2s_sm75( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j in k_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]: @@ -444,7 +444,7 @@ def copy_k_g2s_sm75( @hidet.script def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j in v_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]: @@ -455,7 +455,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j in q_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < q_g2s_layout_sm75.num_workers and i < 
smem_q_type.shape[0]: @@ -488,7 +488,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: @hidet.script def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_o = o[o_head_index][offset_i:, :] for k_round in range(warp_count_k): for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id): @@ -652,12 +652,12 @@ def attn_kernel( v: f16[v_head + [n_kv_size, d_size]], o: f16[o_head + [n_size, d_size]], ): - attrs.cuda.grid_dim = (i_split, bs) + attrs.cuda.grid_dim = i_split * bs attrs.cuda.block_dim = block_size attrs.cuda.min_blocks = 1 attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes - offset_i = blockIdx.x * i_rows_per_tb + offset_i = (blockIdx.x % i_split) * i_rows_per_tb smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout) smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout) @@ -702,7 +702,7 @@ def attn_kernel( j_tiles = cdiv(n_kv_size, block_j) if is_causal: - j_tiles = min(cdiv((blockIdx.x + 1) * block_i, block_j), j_tiles) + j_tiles = min(cdiv(((blockIdx.x % i_split) + 1) * block_i, block_j), j_tiles) for j in range(j_tiles): offset_j = block_j * j diff --git a/python/hidet/graph/ops/attention/attention_mask.py b/python/hidet/graph/ops/attention/attention_mask.py index 873c34de7..1a91c02d6 100644 --- a/python/hidet/graph/ops/attention/attention_mask.py +++ b/python/hidet/graph/ops/attention/attention_mask.py @@ -425,7 +425,7 @@ def init_lm_smem(smem_l: smem_l_type, smem_m: smem_m_type): def copy_k_g2s_sm80( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j_seg in k_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -439,7 +439,7 @@ def copy_k_g2s_sm80( @hidet.script def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j_seg in v_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -449,7 +449,7 @@ def copy_v_g2s_sm80(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j_seg in q_g2s_layout.on(threadIdx.x): j = j_seg * 8 @@ -461,7 +461,7 @@ def copy_q_g2s_sm80(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offs def copy_k_g2s_sm75( k: f16[k_head + [d_size, n_kv_size]], smem_k: smem_k_type, offset_j: i32, offset_k: i32 ): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_k = k[broadcast_indices(o_head_index, k_head, o_head)][offset_k:, offset_j:] for i, j in k_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < k_g2s_layout_sm75.num_workers and i < smem_k_type.shape[0]: @@ -472,7 +472,7 @@ def copy_k_g2s_sm75( @hidet.script def 
copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, offset_j: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_v = v[broadcast_indices(o_head_index, v_head, o_head)][offset_j:, :] for i, j in v_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < v_g2s_layout_sm75.num_workers and i < smem_v_type.shape[0]: @@ -483,7 +483,7 @@ def copy_v_g2s_sm75(v: f16[v_head + [n_kv_size, d_size]], smem_v: smem_v_type, o @hidet.script def copy_q_g2s_sm75(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: i32): - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_q = q[broadcast_indices(o_head_index, q_head, o_head)][offset_i:, :] for i, j in q_g2s_layout_sm75.on(threadIdx.x): if threadIdx.x < q_g2s_layout_sm75.num_workers and i < smem_q_type.shape[0]: @@ -516,7 +516,7 @@ def copy_q_g2s(q: f16[q_head + [n_size, d_size]], smem_q: smem_q_type, offset_i: @hidet.script def copy_o_r2g(o: f16[o_head + [n_size, d_size]], regs_o: regs_o_type, offset_i: i32): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 - o_head_index = spatial(*o_head).map(blockIdx.y) + o_head_index = spatial(*o_head).map(blockIdx.x // i_split) gmem_o = o[o_head_index][offset_i:, :] for k_round in range(warp_count_k): for wi, wj, wk in spatial(warp_count_m_o, warp_count_n_o, warp_count_k_o).on(warp_id): @@ -681,12 +681,12 @@ def attn_kernel( mask: f16[mask_shape], o: f16[o_head + [n_size, d_size]], ): - attrs.cuda.grid_dim = (i_split, bs) + attrs.cuda.grid_dim = i_split * bs attrs.cuda.block_dim = block_size attrs.cuda.min_blocks = 1 attrs.cuda.dynamic_smem_bytes = dynamic_smem_bytes - offset_i = blockIdx.x * i_rows_per_tb + offset_i = (blockIdx.x % i_split) * i_rows_per_tb smem_q = tensor_pointer('float16', shape=smem_q_type.shape, layout=smem_q_type.layout) smem_k = tensor_pointer('float16', shape=smem_k_db_type.shape, layout=smem_k_db_type.layout) @@ -773,7 +773,7 @@ def attn_kernel( copy_v_g2s(v, ~smem_v[0, 0, 0], offset_j) # Apply Masking - qk_head_index = list(spatial(*qk_head).map(blockIdx.y)) + qk_head_index = list(spatial(*qk_head).map(blockIdx.x // i_split)) for mma_i, mma_j in grid(mmas_per_warp_m, mmas_per_warp_n): warp_id, lane_id = threadIdx.x / 32, threadIdx.x % 32 wi, wj, wk = spatial(warp_count_m, warp_count_n, warp_count_k).map(warp_id) diff --git a/python/hidet/graph/ops/image.py b/python/hidet/graph/ops/image.py index b59a7f0af..2100d1be5 100644 --- a/python/hidet/graph/ops/image.py +++ b/python/hidet/graph/ops/image.py @@ -12,7 +12,7 @@ from typing import Optional, List, Sequence, Union from hidet.ir.dtypes import int32 -from hidet.ir.expr import Expr, Int, if_then_else, cast, logical_or, logical_and +from hidet.ir.expr import Expr, Int, if_then_else, cast, logical_or, logical_and, convert from hidet.ir import primitives as prim from .utils import Task, Operator, Tensor, TensorNode, compute, input_like @@ -59,8 +59,8 @@ def get_2d_pixel(data: TensorNode, n, c, h, w) -> Expr: return data[n, c, h, w] -def linear_interpolate(a, b, ratio): - return a * (1.0 - ratio) + b * ratio +def linear_interpolate(a, b, ratio, dtype='float32'): + return a * (convert(1.0, dtype) - ratio) + b * ratio def get_cubic_weights(s, a) -> List[int]: @@ -112,6 +112,7 @@ def resize2d_nchw_compute( scale_factor = _normalize(scale_factor, 2) size = _normalize(size, 2) + dtype = data.type.dtype if size is not None and scale_factor is None: target_size = size @@ -146,12 +147,12 @@ 
def fmap(n, c, h, w): elif method == 'linear': h_int = cast(prim.floor(h), 'int32') w_int = cast(prim.floor(w), 'int32') - h_ratio = h - h_int - w_ratio = w - w_int + h_ratio = cast(h - h_int, dtype) + w_ratio = cast(w - w_int, dtype) pixels = [[get_2d_pixel(data, n, c, h_int + i, w_int + j) for j in range(2)] for i in range(2)] - top = linear_interpolate(*pixels[0], w_ratio) - bottom = linear_interpolate(*pixels[1], w_ratio) - value = linear_interpolate(top, bottom, h_ratio) + top = linear_interpolate(*pixels[0], w_ratio, dtype) + bottom = linear_interpolate(*pixels[1], w_ratio, dtype) + value = linear_interpolate(top, bottom, h_ratio, dtype) elif method == 'cubic': h_int = cast(prim.floor(h), 'int32') w_int = cast(prim.floor(w), 'int32') diff --git a/python/hidet/graph/ops/linear.py b/python/hidet/graph/ops/linear.py new file mode 100644 index 000000000..99e9db893 --- /dev/null +++ b/python/hidet/graph/ops/linear.py @@ -0,0 +1,59 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence +from .utils import Tensor +from .matmul import matmul + +# ToDo: Actually fully implement einsum, supporting same usage as Numpy and Torch + +# Do ad-hoc pattern matching: only support simple cases such as matrix multiply +def einsum(equation: str, operands: Sequence[Tensor]): + if '...' in equation: + raise NotImplementedError('einsum currently does not support ellipsis') + if len(operands) != 2: + raise NotImplementedError('einsum currently only supports 2 operands') + + a = operands[0] + b = operands[1] + equation = equation.replace(' ', '') + lhs, rhs = equation.split('->') + a_subs, b_subs = lhs.split(',') + + if len(rhs) != len(a_subs) or len(a_subs) != len(b_subs): + raise NotImplementedError('einsum currently only supports inputs and output of same rank') + + a_batch, a_dims = a_subs[:-2], a_subs[-2:] + b_batch, b_dims = b_subs[:-2], b_subs[-2:] + c_batch, c_dims = rhs[:-2], rhs[-2:] + + if a_batch != b_batch or a_batch != c_batch: + raise NotImplementedError('einsum currently only supports batched matmul') + + if a_dims[1] == b_dims[0]: + c = matmul(a, b) + elif a_dims[1] == b_dims[1]: + c = matmul(a, b.transpose(-1, -2)) + elif a_dims[0] == b_dims[0]: + c = matmul(a.transpose(-1, -2), b) + elif a_dims[0] == b_dims[1]: + c = matmul(a.transpose(-1, -2), b.transpose(-1, -2)) + else: + raise NotImplementedError('einsum currently only supports batched matmul') + + transpose_c = (c_dims[0] == b_dims[0] or c_dims[0] == b_dims[1]) and ( + c_dims[1] == a_dims[0] or c_dims[1] == a_dims[1] + ) + + if transpose_c: + return c.transpose(-1, -2) + else: + return c diff --git a/python/hidet/graph/ops/normalize/__init__.py b/python/hidet/graph/ops/normalize/__init__.py index 29b24ce29..a3b05d9be 100644 --- a/python/hidet/graph/ops/normalize/__init__.py +++ b/python/hidet/graph/ops/normalize/__init__.py @@ -10,4 +10,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from .layers import batch_norm_infer, layer_norm, instance_norm, group_norm +from .lp import lp_norm from . import resolve diff --git a/python/hidet/graph/ops/normalize/lp.py b/python/hidet/graph/ops/normalize/lp.py new file mode 100644 index 000000000..e0fe68d13 --- /dev/null +++ b/python/hidet/graph/ops/normalize/lp.py @@ -0,0 +1,98 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir import primitives as prim +from hidet.ir.compute import reduce +from hidet.ir.expr import cast +from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode +from hidet.graph.ops.utils import compute, input_like, normalize_dim + + +class LpNormTask(Task): + """ + Performs LP normalization along specified dimension. + For a tensor input of shape [n0, n1, ..., ndim, ..., nk], each + ndim-element vector v along dimension dim is transformed as: + v = v / max(pow(sum(pow(abs(v), p)), 1 / p), eps) + """ + + def __init__(self, x: TensorNode, p: float, dim: int, eps: float): + x_shape = x.const_shape + y_shape = x_shape + dtype = x.type.dtype + reduce_shape = [] + other_shape = [] + for idx, size in enumerate(x_shape): + if idx == dim: + reduce_shape.append(size) + else: + other_shape.append(size) + + def sum_compute(*indices): + def sum_reduce(*reduction_axis): + x_indices = [] + pi = 0 + qi = 0 + for i in range(len(x.shape)): + if i != dim: + x_indices.append(indices[pi]) + pi += 1 + else: + x_indices.append(reduction_axis[qi]) + qi += 1 + assert pi == len(indices) and qi == len(reduction_axis) + # Force fp32 reduction for accuracy + return prim.pow(cast(prim.abs(x[x_indices]), 'float32'), p) + + return reduce(shape=reduce_shape, fcompute=sum_reduce, reduce_type='sum') + + sum_ = compute(name='sum', shape=other_shape, fcompute=sum_compute) + + p_norm = compute(name='p_norm', shape=other_shape, fcompute=lambda *indices: prim.pow(sum_[indices], 1.0 / p)) + + def y_compute(*indices): + norm_indices = [index for i, index in enumerate(indices) if i != dim] + return cast(x[indices] / prim.max(p_norm[norm_indices], eps), dtype) + + y = compute(name='y', shape=y_shape, fcompute=y_compute) + + super().__init__(name='lp_norm', inputs=[x], outputs=[y], attributes={'p': p, 'dim': dim, 'eps': eps}) + + +class LpNormOp(Operator): + def __init__(self, x: Tensor, p: float, dim: int, eps: float): + super().__init__( + inputs=[x], attributes={'p': p, 'dim': dim, 'eps': eps}, task=LpNormTask(input_like(x, 'x'), p, dim, eps) + ) + + +def lp_norm(x: Tensor, p=2.0, dim=1, eps=1e-12): + """LP norm. + + Parameters + ---------- + x: Tensor + The data to be normalized. + p: float + The exponent value in the norm formulation. + dim: int + The dimension to reduce. + eps: float + Small value to avoid division by zero. + + Returns + ------- + ret: Tensor + The normalized tensor. 
+ """ + # Normalize dim + dim = normalize_dim(dim, rank=len(x.shape)) + return LpNormOp(x, p, dim, eps).outputs[0] diff --git a/python/hidet/graph/ops/pool.py b/python/hidet/graph/ops/pool.py index 4e34e281c..42544fbf0 100644 --- a/python/hidet/graph/ops/pool.py +++ b/python/hidet/graph/ops/pool.py @@ -19,14 +19,18 @@ class Pool2dTask(Task): - def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str): + def __init__(self, x: TensorNode, kernel, strides, padding, ceil_mode: bool, reduce_type: str): assert reduce_type in ['max', 'avg'] kernel = normalize_kernel(kernel) strides = normalize_stride(strides) padding = normalize_padding(padding) batch_size, channels, height, width = x.shape - out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 - out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 + if ceil_mode: + out_height = (height + padding[0] + padding[2] - kernel[0] + strides[0] - 1) // strides[0] + 1 + out_width = (width + padding[1] + padding[3] - kernel[1] + strides[1] - 1) // strides[1] + 1 + else: + out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 + out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.type.dtype) pad = compute( name='pad', @@ -145,11 +149,12 @@ def __init__( kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], + ceil_mode: bool, ): super().__init__( inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='max'), + attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, + task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='max'), ) @@ -175,11 +180,12 @@ def __init__( kernel: Union[int, Sequence[int]], stride: Union[int, Sequence[int]], padding: Union[int, Sequence[int]], + ceil_mode: bool, ): super().__init__( inputs=[x], - attributes={'kernel': kernel, 'stride': stride, 'padding': padding}, - task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, reduce_type='avg'), + attributes={'kernel': kernel, 'stride': stride, 'padding': padding, 'ceil_mode': ceil_mode}, + task=Pool2dTask(input_like(x, 'x'), kernel, stride, padding, ceil_mode, reduce_type='avg'), ) @@ -245,16 +251,16 @@ def __init__(self, x: Tensor, output_size: Union[int, Sequence[int]]): super().__init__(x, output_size, reduce_type='max', attrs={'output_size': output_size}, spatial_ndim=3) -def max_pool2d(x: Tensor, kernel, stride, padding) -> Tensor: - return MaxPool2dOp(x, kernel, stride, padding).outputs[0] +def max_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return MaxPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] def max_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: return MaxPool3dOp(x, kernel, stride, padding).outputs[0] -def avg_pool2d(x: Tensor, kernel, stride, padding) -> Tensor: - return AvgPool2dOp(x, kernel, stride, padding).outputs[0] +def avg_pool2d(x: Tensor, kernel, stride, padding, ceil_mode=False) -> Tensor: + return AvgPool2dOp(x, kernel, stride, padding, ceil_mode).outputs[0] def avg_pool3d(x: Tensor, kernel, stride, padding) -> Tensor: diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py index 769de4cc7..ca669dadb 100644 --- a/python/hidet/graph/ops/reduce/reduce.py +++ 
b/python/hidet/graph/ops/reduce/reduce.py @@ -56,7 +56,7 @@ def allow_epilogue(self) -> bool: def allow_prologue(self) -> bool: return False - def implement_cuda(self, working_dir: str) -> IRModule: + def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]: rank = len(self.inputs[0].shape) if rank - 1 in self.dims: return tune.extract_ir_modules(self.cuda_schedule_reduce_by_warp) @@ -80,7 +80,7 @@ def cuda_schedule_reduce_by_warp(self, use_atomic=True) -> IRModule: xdtype = x.type.dtype shape: List[Int] = list(x.shape) lanes = 1 - vtype: DataType = xdtype + vtype: Union[DataType, VectorType] = xdtype if xdtype.nbytes < 4: num_eles: int = 4 // xdtype.nbytes if is_constant(shape[-1]) and shape[-1] % num_eles == 0: @@ -204,7 +204,7 @@ def cuda_schedule_reduce_by_default(self) -> IRModule: shape: List[Int] = list(x.shape) lanes = 1 - vtype: DataType = xdtype + vtype: Union[VectorType, DataType] = xdtype if xdtype.nbytes < 4: num_eles: int = 4 // xdtype.nbytes if shape[-1] % num_eles == 0: diff --git a/python/hidet/graph/ops/transform.py b/python/hidet/graph/ops/transform.py index 90af495c1..e0c5c114d 100644 --- a/python/hidet/graph/ops/transform.py +++ b/python/hidet/graph/ops/transform.py @@ -17,7 +17,7 @@ from hidet.ir.utils import index_deserialize, index_serialize from hidet.utils import prod from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast -from .utils import TensorInput +from .utils import TensorInput, normalize_slice def is_true(x: Union[Expr, bool]) -> bool: @@ -444,53 +444,12 @@ def __init__( axes: Optional[Sequence[Optional[int]]] = None, strides: Optional[Sequence[Optional[int]]] = None, ): - starts, ends, axes, strides = self.normalize(data.shape, starts, ends, axes, strides) + starts, ends, axes, strides = normalize_slice(data.shape, starts, ends, axes, strides) task = StridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides) super().__init__( inputs=[data], attributes={'starts': starts, 'ends': ends, 'axes': axes, 'strides': strides}, task=task ) - @staticmethod - def normalize(data_shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[Optional[int]]]): - # follow: https://data-apis.org/array-api/latest/API_specification/indexing.html - if axes is None: - axes = [i for i in range(len(starts))] - axes = normalize_dim(axes, len(data_shape)) - if strides is None: - strides = [1 for _ in range(len(starts))] - shape = [data_shape[i] for i in axes] - assert len(shape) == len(starts) == len(ends) == len(axes) == len(strides) - - ii, jj, kk = [], [], [] - for i, j, k, n in zip(starts, ends, strides, shape): - if k is None: - k = 1 - if k > 0: - i = i if i is not None else 0 - j = j if j is not None else n - if is_constant(i, j, n) and not (-n <= i <= n and -n <= j): - raise IndexError('Invalid slice') - j = if_then_else(j < n, j, n) - if is_constant(i) and i < 0: - i = i + n - if is_constant(j) and j < 0: - j = j + n - elif k < 0: - i = i if i is not None else n - 1 - j = j if j is not None else -n - 1 - if is_constant(i) and i < 0: - i += n - if is_constant(j) and j < -1: - j += n - if is_constant(i, j, n) and not (-n <= i <= n and -n - 1 <= j <= max(0, n - 1)): - raise IndexError('Invalid slice') - else: - raise IndexError('slice step cannot be zero') - ii.append(i) - jj.append(j) - kk.append(k) - return ii, jj, axes, kk - class BroadcastOp(Operator): def __init__(self, data: Tensor, shape: List[int]): diff --git a/python/hidet/graph/ops/utils/tensor_utils.py 
diff --git a/python/hidet/graph/ops/utils/tensor_utils.py b/python/hidet/graph/ops/utils/tensor_utils.py
index 3d7a6fde5..93f13a712 100644
--- a/python/hidet/graph/ops/utils/tensor_utils.py
+++ b/python/hidet/graph/ops/utils/tensor_utils.py
@@ -14,7 +14,7 @@
 import builtins
 from hidet.ir.layout import DataLayout
 from hidet.ir.type import Int
-from hidet.ir.expr import Var, Expr, Constant, is_constant
+from hidet.ir.expr import Var, Expr, Constant, is_constant, if_then_else
 from hidet.ir.type import TensorType, tensor_type, DataType
 from hidet.ir.task import Task, InverseMap
 from hidet.ir.module import IRModule
@@ -159,3 +159,44 @@ def convert_to_tensor(value: Union[int, float, bool, complex, Tensor], involved_
         return full_like(involved_tensor, fill_value=value, shape=[], dtype='float32')
     else:
         raise ValueError('Can not recognize dtype {}'.format(involved_tensor.dtype))
+
+
+def normalize_slice(data_shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[Optional[int]]]):
+    # follow: https://data-apis.org/array-api/latest/API_specification/indexing.html
+    if axes is None:
+        axes = [i for i in range(len(starts))]
+    axes = normalize_dim(axes, len(data_shape))
+    if strides is None:
+        strides = [1 for _ in range(len(starts))]
+    shape = [data_shape[i] for i in axes]
+    assert len(shape) == len(starts) == len(ends) == len(axes) == len(strides)
+
+    ii, jj, kk = [], [], []
+    for i, j, k, n in zip(starts, ends, strides, shape):
+        if k is None:
+            k = 1
+        if k > 0:
+            i = i if i is not None else 0
+            j = j if j is not None else n
+            if is_constant(i, j, n) and not (-n <= i <= n and -n <= j):
+                raise IndexError('Invalid slice')
+            j = if_then_else(j < n, j, n)
+            if is_constant(i) and i < 0:
+                i = i + n
+            if is_constant(j) and j < 0:
+                j = j + n
+        elif k < 0:
+            i = i if i is not None else n - 1
+            j = j if j is not None else -n - 1
+            if is_constant(i) and i < 0:
+                i += n
+            if is_constant(j) and j < -1:
+                j += n
+            if is_constant(i, j, n) and not (-n <= i <= n and -n - 1 <= j <= max(0, n - 1)):
+                raise IndexError('Invalid slice')
+        else:
+            raise IndexError('slice step cannot be zero')
+        ii.append(i)
+        jj.append(j)
+        kk.append(k)
+    return ii, jj, axes, kk
diff --git a/python/hidet/graph/tensor.py b/python/hidet/graph/tensor.py
index c45f73eee..d7896aa15 100644
--- a/python/hidet/graph/tensor.py
+++ b/python/hidet/graph/tensor.py
@@ -342,6 +342,15 @@ def __str__(self):
     def __getitem__(self, item):
         from hidet.graph.ops import strided_slice
 
+        if isinstance(item, Tensor):
+            if len(item.shape) > 1:
+                raise NotImplementedError("Tensor indexing via Tensor currently only supports 1D index tensor")
+            if not item.dtype.is_integer():
+                raise TypeError("Tensor indexing via Tensor requires integer index tensor")
+            from .ops import take
+
+            return take(self, item, axis=0)
+
         if isinstance(item, list):
             item = tuple(item)
diff --git a/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py b/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py
index d097dcb97..e8aa18d4d 100644
--- a/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py
+++ b/python/hidet/graph/transforms/graph_patterns/conv2d_patterns.py
@@ -125,7 +125,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
             x,
             w,
             stride=op1.attrs['stride'],
-            dilations=op1.attrs['dilation'],
+            dilations=op1.attrs['dilations'],
             groups=1,
             padding=op1.attrs['padding'],
         )
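A brief, hedged example of the new tensor-index branch added to Tensor.__getitem__ above: a 1-D integer index tensor is dispatched to take along axis 0 (shapes and default dtypes below are assumptions):

import hidet

x = hidet.randn([4, 3])
idx = hidet.asarray([2, 0])   # 1-D integer index tensor
y = x[idx]                    # routed to take(x, idx, axis=0) by the new branch
print(y.shape)                # expected: (2, 3)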
diff --git a/python/hidet/ir/primitives/__init__.py b/python/hidet/ir/primitives/__init__.py
index 21c1ae524..49866a0af 100644
--- a/python/hidet/ir/primitives/__init__.py
+++ b/python/hidet/ir/primitives/__init__.py
@@ -15,7 +15,7 @@
 # pylint: disable=redefined-builtin
 from .math import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, expm1, abs
 from .math import max, min, exp, pow, sqrt, rsqrt, erf, ceil, log, log2, log10, log1p, round, floor, trunc
-from .math import isfinite, isinf, isnan, make_vector
+from .math import isfinite, isinf, isnan, make_vector, atan2, mod
 from .complex import real, imag, conj, make_complex
diff --git a/python/hidet/ir/primitives/cuda/math/float16.py b/python/hidet/ir/primitives/cuda/math/float16.py
index 344875cfd..ed64d0a1c 100644
--- a/python/hidet/ir/primitives/cuda/math/float16.py
+++ b/python/hidet/ir/primitives/cuda/math/float16.py
@@ -62,6 +62,7 @@ class CUDAFloat16MathFunctionSet(MathFunctionSet):
     # pylint: disable=abstract-method
     def register(self):
         entries = {
+            'abs': ['__habs', 1],
             'sin': ['hsin', 1],
             'cos': ['hcos', 1],
             'exp': ['hexp', 1],
@@ -142,6 +143,9 @@ def call(self, name: str, *args) -> Expr:
         entry = primitive_func_pool.lookup_by_name(name)
         return entry.var(*args)
 
+    def abs(self, a: Expr) -> Expr:
+        return self.call('cuda_f16_abs', a)
+
     def sin(self, a: Expr) -> Expr:
         return self.call('cuda_f16_sin', a)
diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py
index dd86f0e03..2553bcdbe 100644
--- a/python/hidet/runtime/compiled_graph.py
+++ b/python/hidet/runtime/compiled_graph.py
@@ -9,7 +9,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple, Dict, Any, Callable, Union
+from typing import List, Optional, Tuple, Dict, Any, Callable
 import zipfile
 import os
 import json
@@ -95,7 +95,6 @@ def __init__(
         # derived properties (will be initialized in _init_compiled_graph at the end of this constructor)
         self.dynamic_dims: List[Tuple[str, Tuple[int, int]]] = []  # [(name, (tensor_index, dim_index))]
         self.is_dynamic: bool = False
-        self.constant_outputs: List[Union[None, Tensor]] = []
 
         # runtime state
         self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash)
@@ -145,11 +144,6 @@ def _init_compiled_graph(self):
             self.is_dynamic = True
         else:
             self.is_dynamic = False
-        for out_idx in self.graph_execution.outputs_index:
-            if out_idx in self.graph_execution.weights_index:
-                self.constant_outputs.append(self.weights[self.graph_execution.weights_index.index(out_idx)])
-            else:
-                self.constant_outputs.append(None)
 
         # initialize weights
         weights_buffer = Array(void_p, len(self.weights))
@@ -209,13 +203,15 @@ def _update_symbol_dims(self, inputs) -> Tuple[int, ...]:
             runtime_api.set_symbol_value(name, symbol_dims[-1])
         return tuple(symbol_dims)
 
-    def _create_outputs(self):
+    def _create_outputs(self, inputs):
         from hidet.graph.tensor import empty
 
         outputs = []
-        for output_index, (sig, const_out) in enumerate(zip(self.meta.outputs, self.constant_outputs)):
-            if const_out is not None:
-                outputs.append(const_out)
+        for output_index, (exec_idx, sig) in enumerate(zip(self.graph_execution.outputs_index, self.meta.outputs)):
+            if exec_idx in self.graph_execution.inputs_index:
+                outputs.append(inputs[self.graph_execution.inputs_index.index(exec_idx)])
+            elif exec_idx in self.graph_execution.weights_index:
+                outputs.append(self.weights[self.graph_execution.weights_index.index(exec_idx)])
             else:
                 if self.is_dynamic:
                     shape_buffer = Array(i32, len(sig.shape))
@@ -240,7 +236,7 @@ def _prepare_workspace(self):
     def _run_fast_path(self, inputs, symbol_dims: Tuple[int, ...]):
         # create output tensors
-        outputs = self._create_outputs()
+        outputs = self._create_outputs(inputs)
 
         # prepare workspace
         self._prepare_workspace()
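Finally, a standalone sketch (plain Python, hypothetical names, not the hidet implementation) of the output-aliasing lookup that the reworked _create_outputs performs: an output slot that coincides with a graph input reuses the caller's tensor, one that coincides with a weight reuses the stored weight, and anything else gets a freshly allocated buffer:

def create_outputs(outputs_index, inputs_index, weights_index, inputs, weights, allocate):
    # mirrors the lookup order in the _create_outputs hunk above
    outputs = []
    for exec_idx in outputs_index:
        if exec_idx in inputs_index:
            outputs.append(inputs[inputs_index.index(exec_idx)])    # alias a graph input
        elif exec_idx in weights_index:
            outputs.append(weights[weights_index.index(exec_idx)])  # alias a weight
        else:
            outputs.append(allocate(exec_idx))                      # fresh buffer
    return outputs

# a graph whose outputs are (a fresh tensor, its first input):
print(create_outputs([5, 0], [0, 1], [3], ['x0', 'x1'], ['w0'], lambda i: f'buf{i}'))  # ['buf5', 'x0']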