[Complex Type] Add more support for complex data type (#271)
yaoyaoding authored Jun 5, 2023
1 parent 9057b7d commit a1bd18c
Showing 10 changed files with 178 additions and 33 deletions.
22 changes: 22 additions & 0 deletions include/hidet/runtime/cuda/complex.h
@@ -74,6 +74,28 @@ HIDET_HOST_DEVICE Complex<T> conj(Complex<T> a) {
     return {a.real, -a.imag};
 }
 
+template<typename T>
+HIDET_HOST_DEVICE Complex<T> sin(Complex<T> a) {
+    T real = sin(a.real) * cosh(a.imag);
+    T imag = cos(a.real) * sinh(a.imag);
+    return {real, imag};
+}
+
+template<typename T>
+HIDET_HOST_DEVICE Complex<T> cos(Complex<T> a) {
+    T real = cos(a.real) * cosh(a.imag);
+    T imag = -sin(a.real) * sinh(a.imag);
+    return {real, imag};
+}
+
+template<typename T>
+HIDET_HOST_DEVICE Complex<T> exp(Complex<T> a) {
+    T magnitude = exp(a.real);
+    T real = magnitude * cos(a.imag);
+    T imag = magnitude * sin(a.imag);
+    return {real, imag};
+}
+
 template<typename T>
 HIDET_HOST_DEVICE T abs(Complex<T> a);
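
For reference, the three new device functions implement the standard complex identities (stated here in LaTeX; not part of the diff):

    \sin(x + iy) = \sin x \cosh y + i \cos x \sinh y
    \cos(x + iy) = \cos x \cosh y - i \sin x \sinh y
    e^{x + iy} = e^x (\cos y + i \sin y)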
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/interpreter.py
@@ -278,7 +278,7 @@ def _raise_exception(exception: Exception, target, caused_callable: Any, args, k
         caused_callable = dispatched
 
     callable_name, filename, lineno = Interpreter._callable_info(caused_callable)
-    raise type(exception)(
+    raise RuntimeError(
         f'{exception}, occurred when interpreting {target_name} with\n'
         f' {callable_name}({", ".join(argument_strings)})\n'
         f'{callable_name} is defined at\n'
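
A plausible reason for switching from type(exception) to RuntimeError (an assumption; the commit does not state its motivation): reconstructing an arbitrary exception class from a single message string can itself fail. A hypothetical sketch of that failure mode (CustomError is illustrative, not from the codebase):

    # Exception subclasses are free to require extra constructor arguments:
    class CustomError(Exception):
        def __init__(self, code: int, message: str):
            super().__init__(message)
            self.code = code

    exc = CustomError(42, "boom")
    # type(exc)("wrapped message") raises TypeError: __init__() missing 1
    # required positional argument, while RuntimeError("...") always works.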
27 changes: 26 additions & 1 deletion python/hidet/graph/frontend/torch/register_functions.py
@@ -11,11 +11,13 @@
 # limitations under the License.
 from typing import Optional, Union, Sequence, Any, Tuple, List
 import operator
+import functools
 import torch
 from hidet.graph.tensor import Tensor, full_like, from_torch
 from hidet.graph import ops
 from hidet.utils import same_list
 from hidet.ir.type import DataType
+from hidet.ir.dtypes import promote_type
 from hidet.ir.expr import Int
 from hidet.runtime.device import Device
 from .interpreter import register_function, register_method
@@ -209,6 +211,8 @@ def mul(x: Tensor, y: Tensor):
 
 @register_function(torch.cat)
 def cat(tensors: List[Tensor], dim: int):
+    dtype = functools.reduce(promote_type, [t.dtype for t in tensors])
+    tensors = [ops.cast(t, dtype) for t in tensors]
     return ops.concat(tensors, dim)
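
The cat registration now promotes mixed input dtypes to a common type before concatenation. A minimal sketch of the promotion step, assuming hidet's rules promote float16 with float32 to float32:

    import functools
    from hidet.ir.dtypes import promote_type, float16, float32

    dtype = functools.reduce(promote_type, [float16, float32, float16])
    assert dtype == float32  # every input is cast to this dtype before ops.concat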


@@ -218,13 +222,17 @@ def unsqueeze(x: Tensor, dim: int):
 
 
 @register_function(torch.nn.functional.avg_pool2d)
-def avg_pool2d(x: Tensor, kernel_size, stride, padding, ceil_mode=False, count_include_pad=True, divisor_override=None):
+def avg_pool2d(
+    x: Tensor, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None
+):
     if ceil_mode:
         raise NotImplementedError("ceil_mode=True")
     if not count_include_pad:
         raise NotImplementedError("count_include_pad=False")
     if divisor_override is not None:
         raise NotImplementedError("divisor_override is not None")
+    if stride is None:
+        stride = kernel_size
     y = ops.avg_pool2d(x, kernel_size, stride, padding)
     return y
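
The stride=None default matches torch.nn.functional.avg_pool2d, where a missing stride defaults to the kernel size; previously the frontend signature had no default for stride or padding. For example, in plain PyTorch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    y1 = F.avg_pool2d(x, kernel_size=2)            # stride defaults to kernel_size
    y2 = F.avg_pool2d(x, kernel_size=2, stride=2)
    assert torch.equal(y1, y2)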

@@ -965,3 +973,20 @@ def torch_ne(x: Tensor, y: Union[Tensor, float, int], out: Optional[Tensor] = No
     y = ops.full(x.shape, y, dtype=x.dtype, device=x.device)
     output = ops.not_equal(x, y)
     return output
+
+
+@register_function(torch.stack)
+def torch_stack(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int = 0, *, out: Optional[Tensor] = None):
+    if out is not None:
+        raise NotImplementedError("hidet: does not support torch.stack(..., out=...)")
+    tensors = [ops.unsqueeze(t, dims=dim) for t in tensors]
+    return ops.concat(tensors, axis=dim)
+
+
+@register_function(torch.conj)
+@register_method(torch.Tensor.conj)
+def torch_conj(x: Tensor) -> Tensor:
+    if x.dtype.is_complex():
+        return ops.conj(x)
+    else:
+        return x
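
torch.stack is lowered to an unsqueeze of every input followed by a concat along the new axis. A sketch of the equivalence in plain PyTorch:

    import torch

    tensors = [torch.randn(2, 3) for _ in range(4)]
    stacked = torch.stack(tensors, dim=1)                          # shape (2, 4, 3)
    via_cat = torch.cat([t.unsqueeze(1) for t in tensors], dim=1)  # same values
    assert torch.equal(stacked, via_cat)

torch_conj mirrors torch.conj, which conjugates complex tensors and is the identity on real ones, hence the dtype check above. A small PyTorch illustration:

    z = torch.tensor([1 + 2j], dtype=torch.complex64)
    assert torch.equal(torch.conj(z).resolve_conj(),
                       torch.tensor([1 - 2j], dtype=torch.complex64))
    r = torch.tensor([1.0, 2.0])
    assert torch.equal(torch.conj(r), r)  # real dtype: conj is a no-op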
10 changes: 9 additions & 1 deletion python/hidet/graph/frontend/torch/register_methods.py
@@ -187,7 +187,10 @@ def tensor_squeeze(self: Tensor, dim=None) -> Tensor:
 
 @register_method(torch.Tensor.unsqueeze)
 def tensor_unsqueeze(self: Tensor, dim) -> Tensor:
-    return ops.unsqueeze(self, [int(dim)])
+    dim = int(dim)
+    if dim < 0:
+        dim = len(self.shape) + dim + 1
+    return ops.unsqueeze(self, [dim])
 
 
 @register_method(torch.Tensor.type)
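
The negative-dim handling follows torch semantics: for Tensor.unsqueeze, a negative dim refers to a position in the output tensor, so it is offset by rank + 1. A quick check in plain PyTorch:

    import torch

    x = torch.randn(2, 3, 4)  # rank 3
    # dim == -1 maps to len(x.shape) + (-1) + 1 == 3: append a trailing axis
    assert x.unsqueeze(-1).shape == (2, 3, 4, 1)
    assert x.unsqueeze(-1).shape == x.unsqueeze(3).shape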
@@ -229,3 +232,8 @@ def tensor_masked_fill_(self: Tensor, mask: Tensor, value: float) -> Tensor:
 @register_method(torch.Tensor.repeat)
 def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
     return ops.tile(self, sizes)
+
+
+@register_method(torch.Tensor.detach)
+def tensor_detach(self: Tensor) -> Tensor:
+    return self
2 changes: 2 additions & 0 deletions python/hidet/graph/frontend/torch/utils.py
@@ -47,6 +47,8 @@ def dtype_from_torch(torch_dtype) -> Optional[DataType]:
         torch.uint8: dtypes.uint8,
         torch.bool: dtypes.boolean,
         torch.double: dtypes.float64,
+        torch.complex64: dtypes.complex64,
+        torch.complex128: dtypes.complex128,
     }
     return mapping[torch_dtype]
13 changes: 10 additions & 3 deletions python/hidet/graph/ops/arithmetic.py
@@ -463,13 +463,16 @@ def scalar_min(args: List[expr.Expr]):
 
 
 def binary_arithmetic(
-    x: Union[Tensor, Constant, float, int],
-    y: Union[Tensor, Constant, float, int],
+    x: Union[Tensor, Constant, complex, float, int],
+    y: Union[Tensor, Constant, complex, float, int],
     tensor_scalar_op: Callable[[Tensor, Constant], Tensor],
     scalar_tensor_op: Callable[[Constant, Tensor], Tensor],
    tensor_tensor_op: Callable[[Tensor, Tensor], Tensor],
 ) -> Union[Tensor, float, int]:
-    if not (isinstance(x, (Tensor, float, int, Constant)) and isinstance(y, (Tensor, float, int, Constant))):
+    if not (
+        isinstance(x, (Tensor, complex, float, int, Constant))
+        and isinstance(y, (Tensor, complex, float, int, Constant))
+    ):
         raise ValueError(
             'Only support add/sub/mul/div between hidet.Tensor, float, int, and Constant. got {} and {}'.format(
                 type(x), type(y)
@@ -485,6 +488,8 @@ def binary_arithmetic(
         x = dtypes.int32(x)
     elif isinstance(x, float):
         x = dtypes.float32(x)
+    elif isinstance(x, complex):
+        x = dtypes.complex64(x)
     elif isinstance(x, Tensor) and len(x.shape) == 0:
         if x.trace is None and x.storage is not None:
             x = x.dtype(x.item())
@@ -493,6 +498,8 @@ def binary_arithmetic(
         y = dtypes.int32(y)
     elif isinstance(y, float):
         y = dtypes.float32(y)
+    elif isinstance(y, complex):
+        y = dtypes.complex64(y)
     elif isinstance(y, Tensor) and len(y.shape) == 0:
         if y.trace is None and y.storage is not None:
             y = y.dtype(y.item())
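
Python complex scalars now take the same path as int and float scalars, wrapped as complex64 constants. A hedged sketch of the new path (assumes the complex kernels added by this commit cover multiplication; the torch-to-hidet dtype mapping is the one added in utils.py above):

    import torch
    import hidet

    x = hidet.from_torch(torch.ones(2, 2, dtype=torch.complex64, device='cuda'))
    y = x * (1 + 2j)  # the scalar is converted via dtypes.complex64 before dispatch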
30 changes: 3 additions & 27 deletions python/hidet/graph/ops/transform.py
@@ -355,14 +355,6 @@ def __init__(self, x: Tensor, dims: List[int]):
             task=RearrangeTask(input_like(x, 'x'), plan=[[i] for i in range(len(x.shape)) if i not in dims]),
         )
 
-    # def imperative_run(self, inputs: Optional[List[Tensor]] = None) -> List[Tensor]:
-    #     x = inputs[0] if inputs else self.inputs[0]
-    #     if isinstance(x.layout, RowMajorLayout):
-    #         shape = self.task.outputs[0].const_shape()
-    #         return [Tensor(shape=shape, dtype=x.dtype, device=x.device, storage=x.storage, trace=None)]
-    #     else:
-    #         return Operator.imperative_run(self, inputs)
-
 
 class UnsqueezeOp(Operator):
     def __init__(self, x: Tensor, dims: List[int]):
@@ -375,18 +367,10 @@ def __init__(self, x: Tensor, dims: List[int]):
         else:
             plan.append([c])
             c += 1
-        assert c == len(x.shape)
+        if c != len(x.shape):
+            raise ValueError('Invalid unsqueeze dims: {} for shape: {}'.format(dims, x.shape))
         super().__init__(inputs=[x], attributes={'dims': dims}, task=RearrangeTask(input_like(x, 'x'), plan=plan))
 
-    # def imperative_run(self, inputs: Optional[List[Tensor]] = None) -> List[Tensor]:
-    #     x = inputs[0] if inputs else self.inputs[0]
-    #     if isinstance(x.layout, (RowMajorLayout, ColumnMajorLayout)):
-    #         shape = self.task.outputs[0].const_shape()
-    #         layout = x.layout.__class__(shape)
-    #         return [Tensor(shape=shape, dtype=x.dtype, device=x.device, storage=x.storage, layout=layout, trace=None)]
-    #     else:
-    #         return Operator.imperative_run(self, inputs)
-
 
 class FlattenOp(Operator):
     def __init__(self, x: Tensor, start_dim: int, end_dim: int):
@@ -402,15 +386,6 @@ def __init__(self, x: Tensor, start_dim: int, end_dim: int):
             task=RearrangeTask(input_like(x, 'x'), plan=plan),
         )
 
-    # def imperative_run(self, inputs: List[Tensor]) -> List[Tensor]:
-    #     x = inputs[0] if inputs else self.inputs[0]
-    #     if isinstance(x.layout, (RowMajorLayout, ColumnMajorLayout)):
-    #         shape = self.task.outputs[0].const_shape()
-    #         layout = x.layout.__class__(shape)
-    #         return [Tensor(shape=shape, dtype=x.dtype, device=x.device, storage=x.storage, layout=layout, trace=None)]
-    #     else:
-    #         return Operator.imperative_run(self, inputs)
-
 
 class PermuteDimsOp(Operator):
     def __init__(self, x: Tensor, axes: Optional[List[int]] = None):
@@ -604,6 +579,7 @@ def squeeze(x: Tensor, dims: Union[int, Sequence[int]]) -> Tensor:
 def unsqueeze(x: Tensor, dims: Union[int, Sequence[int]]) -> Tensor:
     if isinstance(dims, int):
         dims = [dims]
+    dims = [normalize_dim(dim, len(x.shape) + len(dims)) for dim in dims]
     if len(dims) == 0:
         return x
     return UnsqueezeOp(x, dims).get_output(0)
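
normalize_dim resolves each entry against the output rank len(x.shape) + len(dims), since that is the rank the new axes live in. For example, with a rank-2 input and dims=[0, -1], the output rank is 4 and -1 normalizes to 3. A sketch, assuming the usual hidet API:

    import hidet
    from hidet.graph import ops

    x = hidet.ones([2, 3])
    y = ops.unsqueeze(x, dims=[0, -1])  # -1 -> 3 under output rank 4
    assert tuple(y.shape) == (1, 2, 3, 1)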
2 changes: 2 additions & 0 deletions python/hidet/ir/primitives/cuda/math/__init__.py
@@ -15,3 +15,5 @@
 from . import float64
 from . import int64
 from . import int32
+from . import complex64
+from . import complex128
52 changes: 52 additions & 0 deletions python/hidet/ir/primitives/cuda/math/complex128.py
@@ -0,0 +1,52 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from hidet.ir.expr import Expr
+from hidet.ir.type import FuncType
+from hidet.ir.primitives.func import register_primitive_function, primitive_func_pool
+from hidet.ir.primitives.math import MathFunctionSet, register_math_function_set
+
+
+class CUDAComplex128MathFunctionSet(MathFunctionSet):
+    # pylint: disable=abstract-method
+    @staticmethod
+    def register():
+        unary_funcs = {'sin': 'sin', 'cos': 'cos', 'abs': 'abs', 'exp': 'exp'}
+
+        for name_map, num_args in zip([unary_funcs], [1]):
+            for name, codegen_name in name_map.items():
+                register_primitive_function(
+                    name='cuda_c128_{}'.format(name),
+                    codegen_name=codegen_name,
+                    func_or_type=FuncType(
+                        param_types=['complex128'] * num_args,
+                        ret_type='complex128' if name not in ['abs'] else 'float64',
+                    ),
+                )
+
+    @staticmethod
+    def call(name: str, *args) -> Expr:
+        entry = primitive_func_pool.lookup_by_name('cuda_c128_{}'.format(name))
+        return entry.var(*args)
+
+    def sin(self, a: Expr) -> Expr:
+        return self.call('sin', a)
+
+    def cos(self, a: Expr) -> Expr:
+        return self.call('cos', a)
+
+    def exp(self, a: Expr) -> Expr:
+        return self.call('exp', a)
+
+
+cuda_c128_math_function_set = CUDAComplex128MathFunctionSet()
+cuda_c128_math_function_set.register()
+register_math_function_set('cuda', 'complex128', cuda_c128_math_function_set)
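
Note the special-cased return type for abs in both new function sets: the magnitude of a complex number is real,

    |x + iy| = \sqrt{x^2 + y^2}

so cuda_c128_abs is registered to return float64 here, and cuda_c64_abs below returns float32, while sin, cos, and exp stay complex-valued.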
51 changes: 51 additions & 0 deletions python/hidet/ir/primitives/cuda/math/complex64.py
@@ -0,0 +1,51 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from hidet.ir.expr import Expr
+from hidet.ir.type import FuncType
+from hidet.ir.primitives.func import register_primitive_function, primitive_func_pool
+from hidet.ir.primitives.math import MathFunctionSet, register_math_function_set
+
+
+class CUDAComplex64MathFunctionSet(MathFunctionSet):
+    # pylint: disable=abstract-method
+    @staticmethod
+    def register():
+        unary_funcs = {'sin': 'sin', 'cos': 'cos', 'abs': 'abs', 'exp': 'exp'}
+
+        for name_map, num_args in zip([unary_funcs], [1]):
+            for name, codegen_name in name_map.items():
+                register_primitive_function(
+                    name='cuda_c64_{}'.format(name),
+                    codegen_name=codegen_name,
+                    func_or_type=FuncType(
+                        param_types=['complex64'] * num_args, ret_type='complex64' if name not in ['abs'] else 'float32'
+                    ),
+                )
+
+    @staticmethod
+    def call(name: str, *args) -> Expr:
+        entry = primitive_func_pool.lookup_by_name('cuda_c64_{}'.format(name))
+        return entry.var(*args)
+
+    def sin(self, a: Expr) -> Expr:
+        return self.call('sin', a)
+
+    def cos(self, a: Expr) -> Expr:
+        return self.call('cos', a)
+
+    def exp(self, a: Expr) -> Expr:
+        return self.call('exp', a)
+
+
+cuda_c64_math_function_set = CUDAComplex64MathFunctionSet()
+cuda_c64_math_function_set.register()
+register_math_function_set('cuda', 'complex64', cuda_c64_math_function_set)
