[Dynamo] Added Operator Support #131

Merged: 14 commits, Mar 16, 2023
40 changes: 34 additions & 6 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -23,8 +23,6 @@
from .utils import dtype_from_torch, device_from_torch

Number = Union[int, float, bool]
TorchDtype = torch.dtype
TorchDevice = torch.device


@register_function(torch.nn.functional.conv2d)
@@ -253,7 +251,9 @@ def ones(


@register_function(torch.nn.functional.gelu)
def gelu(x: Tensor):
def gelu(x: Tensor, approximate: Optional[str] = "none"):
if approximate is not None and approximate != "none":
warnings.warn_once('hidet: gelu with approximate={} is not supported. Treat as approximate="none".'.format(approximate))
return ops.gelu(x)


@@ -352,9 +352,9 @@ def arange(
step: Number = 1,
*,
out: Optional[Tensor] = None,
dtype: Optional[TorchDtype] = None,
layout: Optional = None,
device: Optional[Union[TorchDevice, str, None]] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[Union[torch.device, str, None]] = None,
pin_memory: Optional[bool] = False,
requires_grad: Optional[bool] = False,
):
@@ -428,3 +428,31 @@ def torch_tensor(
else:
tt = torch.tensor(data, dtype=dtype, device=device)
return from_torch(tt)


@register_function(torch.sigmoid)
def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
warnings.warn_once("hidet: does not support torch.sigmoid(..., out=...)")
return ops.sigmoid(x)


@register_function(torch.nn.functional.hardsigmoid)
def hardsigmoid(x: Tensor, inplace: bool):
if inplace:
warnings.warn_once('hidet: hardsigmoid with inplace=True is not supported. Treat as inplace=False.')
return ops.hardsigmoid(x)


@register_function(torch.nn.functional.silu)
def silu(x: Tensor, inplace: bool):
if inplace:
warnings.warn_once('hidet: silu with inplace=True is not supported. Treat as inplace=False.')
return ops.silu(x)


@register_function(torch.nn.functional.hardswish)
def hardswish(x: Tensor, inplace: bool):
if inplace:
warnings.warn_once('hidet: hardswish with inplace=True is not supported. Treat as inplace=False.')
return ops.hardswish(x)
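
For reference, a minimal end-to-end sketch of how these registrations get exercised (not part of this diff). It assumes that importing hidet registers the 'hidet' dynamo backend and that a CUDA device is available, and it only calls functions mapped above:

```python
import torch
import hidet  # noqa: F401  -- assumed to register the 'hidet' dynamo backend on import


class SmallNet(torch.nn.Module):
    def forward(self, x):
        x = torch.nn.functional.gelu(x)                  # dispatches to the gelu registration above
        x = torch.nn.functional.silu(x, inplace=False)   # dispatches to the silu registration above
        return torch.sigmoid(x)                          # dispatches to the sigmoid registration above


model = SmallNet().cuda().eval()
x = torch.randn(8, 16, device='cuda')
compiled = torch.compile(model, backend='hidet')
torch.testing.assert_close(compiled(x), model(x), atol=1e-4, rtol=1e-4)
```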
57 changes: 57 additions & 0 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -169,3 +169,60 @@ class HidetReLU6(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ReLU6)
return regs.relu6(x, self.mod.inplace)


@register_module(torch.nn.Sigmoid)
class HidetSigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Sigmoid)
return regs.sigmoid(x)


@register_module(torch.nn.Hardsigmoid)
class HidetHardsigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardsigmoid)
return regs.hardsigmoid(x, self.mod.inplace)


@register_module(torch.nn.AvgPool2d)
class HidetAvgPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AvgPool2d)
return regs.avg_pool2d(
x=x,
kernel_size=self.mod.kernel_size,
stride=self.mod.stride,
padding=self.mod.padding,
ceil_mode=self.mod.ceil_mode,
count_include_pad=self.mod.count_include_pad,
divisor_override=self.mod.divisor_override,
)


@register_module(torch.nn.Flatten)
class HidetFlatten(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Flatten)
return regs.flatten(x, self.mod.start_dim, self.mod.end_dim)


@register_module(torch.nn.Hardswish)
class HidetHardswish(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardswish)
return regs.hardswish(x, self.mod.inplace)


@register_module(torch.nn.GELU)
class HidetGELU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.GELU)
return regs.gelu(x, self.mod.approximate)


@register_module(torch.nn.SiLU)
class HidetSiLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.SiLU)
return regs.silu(x, self.mod.inplace)
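
The module registrations above all follow the same shape: subclass HidetModule, assert the type of the wrapped self.mod, and delegate to the matching functional registration. Below is a hedged sketch of that pattern applied to a module that is not part of this PR (torch.nn.Tanh with a hypothetical regs.tanh registration; the import paths are assumptions):

```python
import torch
from hidet.graph.tensor import Tensor  # assumed import path
from hidet.graph.frontend.torch.registry import register_module, HidetModule  # assumed import path
from hidet.graph.frontend.torch import register_functions as regs


@register_module(torch.nn.Tanh)
class HidetTanh(HidetModule):
    def __call__(self, x: Tensor) -> Tensor:
        # self.mod is the torch.nn.Module instance captured from the traced graph
        assert isinstance(self.mod, torch.nn.Tanh)
        return regs.tanh(x)  # hypothetical: would need a matching @register_function(torch.tanh)
```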
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -19,7 +19,7 @@
from .definitions.pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .definitions.pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
from .definitions.softmax import softmax
from .definitions.activation import relu, leaky_relu, sigmoid, clip, relu6, prelu, gelu
from .definitions.activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
from .definitions.norm import batch_norm_infer, instance_norm, layer_norm
from .definitions.image import resize2d
from .definitions.create import full, arange, linspace
46 changes: 43 additions & 3 deletions python/hidet/graph/ops/definitions/activation.py
@@ -31,7 +31,20 @@ def __init__(self, x: Tensor, alpha):

class SigmoidOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(x, op=lambda v: x.dtype(1.0) / (x.dtype.one + prim.exp(-v)), name='sigmoid')
super().__init__(x, op=lambda v: x.dtype(1.0) / (x.dtype(1.0) + prim.exp(-v)), name='sigmoid')


class HardSigmoidOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(
x,
op=lambda v: if_then_else(
v <= x.dtype(-3),
x.dtype.zero,
if_then_else(v >= x.dtype(3), x.dtype.one, v / x.dtype(6) + x.dtype(0.5)),
),
name='hardsigmoid',
)


class ClipOp(UnaryElementwiseOp):
@@ -48,17 +61,32 @@ def op(v):

class GeluOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
dtype = x.dtype
super().__init__(
x=x, op=lambda v: dtype(0.5) * v * (dtype.one + prim.erf(v * dtype(1 / math.sqrt(2)))), name='gelu'
x, op=lambda v: x.dtype(0.5) * v * (x.dtype.one + prim.erf(v * x.dtype(1 / math.sqrt(2)))), name='gelu'
)


class SiluOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(x, op=lambda v: v * (x.dtype(1.0) / (x.dtype(1.0) + prim.exp(-v))), name='silu')


class PReluOp(BinaryElementwiseOp):
def __init__(self, x, slope):
super().__init__(x, slope, op=lambda a, b: if_then_else(a >= 0, a, a * b), name='prelu')


class HardSwishOp(UnaryElementwiseOp):
def __init__(self, x: Tensor):
super().__init__(
x,
op=lambda v: if_then_else(
v <= x.dtype(-3), x.dtype.zero, if_then_else(v >= x.dtype(3), v, (v * (v + x.dtype(3))) / x.dtype(6))
),
name='hardswish',
)


def relu(x) -> Tensor:
return ReluOp(x).get_output(0)

@@ -71,6 +99,10 @@ def sigmoid(x: Tensor) -> Tensor:
return SigmoidOp(x).get_output(0)


def hardsigmoid(x: Tensor) -> Tensor:
return HardSigmoidOp(x).get_output(0)


def clip(x: Tensor, min_val: Optional[float], max_val: Optional[float]) -> Tensor:
return ClipOp(x, min_val, max_val).get_output(0)

@@ -83,5 +115,13 @@ def gelu(x: Tensor) -> Tensor:
return GeluOp(x).get_output(0)


def silu(x: Tensor) -> Tensor:
return SiluOp(x).get_output(0)


def prelu(x: Tensor, slope: Tensor) -> Tensor:
return PReluOp(x, slope).get_output(0)


def hardswish(x: Tensor) -> Tensor:
return HardSwishOp(x).get_output(0)
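
The new ops implement the standard definitions: hardsigmoid(x) = clamp(x/6 + 1/2, 0, 1), hardswish(x) = x * hardsigmoid(x), silu(x) = x * sigmoid(x). A small NumPy cross-check sketch (not part of this diff; assumes a CUDA device, mirroring the pattern in hidet.testing.utils):

```python
import numpy as np
import hidet
from hidet.graph import ops


def np_hardsigmoid(v):
    # piecewise definition collapses to a clamp
    return np.clip(v / 6.0 + 0.5, 0.0, 1.0)


a = np.random.randn(4, 8).astype(np.float32)
x = hidet.asarray(a).cuda()

np.testing.assert_allclose(ops.hardsigmoid(x).cpu().numpy(), np_hardsigmoid(a), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(ops.hardswish(x).cpu().numpy(), a * np_hardsigmoid(a), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(ops.silu(x).cpu().numpy(), a / (1.0 + np.exp(-a)), atol=1e-5, rtol=1e-5)
```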
6 changes: 3 additions & 3 deletions python/hidet/ir/primitives/math.py
@@ -78,9 +78,6 @@ def acosh(self, a: Expr) -> Expr:
def atanh(self, a: Expr) -> Expr:
raise NotImplementedError()

def atan2(self, a: Expr, b: Expr) -> Expr:
raise NotImplementedError()

def exp(self, a: Expr) -> Expr:
raise NotImplementedError()

@@ -142,6 +139,9 @@ def mod(self, a: Expr, b: Expr) -> Expr:
def pow(self, a: Expr, b: Expr) -> Expr:
raise NotImplementedError()

def atan2(self, a: Expr, b: Expr) -> Expr:
raise NotImplementedError()

# ternary math functions
def fma(self, a: Expr, b: Expr, c: Expr) -> Expr:
raise NotImplementedError()
2 changes: 1 addition & 1 deletion python/hidet/testing/__init__.py
@@ -13,4 +13,4 @@

from . import models
from . import utils
from .utils import check_unary, check_binary, check_torch_unary, check_torch_binary
from .utils import check_unary, check_binary, check_ternary, check_torch_unary, check_torch_binary, check_torch_ternary
49 changes: 49 additions & 0 deletions python/hidet/testing/utils.py
@@ -50,6 +50,24 @@ def check_binary(
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


def check_ternary(
a_shape, b_shape, c_shape, numpy_op, hidet_op, dtype: Union[str, np.dtype] = np.float32, atol=0.0, rtol=0.0
):
np.random.seed(1)
a = np.array(np.random.randn(*a_shape)).astype(dtype)
b = np.array(np.random.randn(*b_shape)).astype(dtype)
c = np.array(np.random.randn(*c_shape)).astype(dtype)

c = np.abs(c)

numpy_result = numpy_op(a, b, c)
import hidet as hi

hidet_args = [hi.asarray(v).cuda() for v in [a, b, c]]
hidet_result = hidet_op(*hidet_args).cpu().numpy()
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


def check_torch_unary(
shape: Sequence[int], torch_func, hidet_func, device: str = 'all', dtype: str = 'float32', atol=0.0, rtol=0.0
):
@@ -95,3 +113,34 @@ def check_torch_binary(
np.testing.assert_allclose(
actual=hidet_result.cpu().numpy(), desired=torch_result.cpu().numpy(), atol=atol, rtol=rtol
)


def check_torch_ternary(
a_shape: Sequence[int],
b_shape: Sequence[int],
c_shape: Sequence[int],
torch_func,
hidet_func,
device: str = 'all',
dtype: str = 'float32',
atol=0.0,
rtol=0.0,
):
if device == 'all':
for dev in ['cuda', 'cpu']:
check_torch_ternary(a_shape, b_shape, c_shape, torch_func, hidet_func, dev, dtype, atol, rtol)
return
import torch
import hidet

torch_a = torch.randn(*a_shape, dtype=getattr(torch, dtype)).to(device=device)
torch_b = torch.randn(*b_shape, dtype=getattr(torch, dtype)).to(device=device)
torch_c = torch.randn(*c_shape, dtype=getattr(torch, dtype)).to(device=device)
hidet_a = hidet.from_torch(torch_a)
hidet_b = hidet.from_torch(torch_b)
hidet_c = hidet.from_torch(torch_c)
torch_result: torch.Tensor = torch_func(torch_a, torch_b, torch_c)
hidet_result: hidet.Tensor = hidet_func(hidet_a, hidet_b, hidet_c)
np.testing.assert_allclose(
actual=hidet_result.cpu().numpy(), desired=torch_result.cpu().numpy(), atol=atol, rtol=rtol
)
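
A hedged usage sketch for check_torch_ternary (the op below is illustrative and not necessarily one exercised by this PR):

```python
import torch
from hidet.testing import check_torch_ternary

check_torch_ternary(
    a_shape=[4, 8],
    b_shape=[4, 8],
    c_shape=[4, 8],
    torch_func=lambda a, b, c: torch.addcmul(a, b, c),  # computes a + b * c
    hidet_func=lambda a, b, c: a + b * c,               # same arithmetic via hidet Tensor operators
    device='cuda',
    atol=1e-5,
    rtol=1e-5,
)
```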
22 changes: 22 additions & 0 deletions tests/frontends/torch/test_torch_activation.py
@@ -18,3 +18,25 @@
@pytest.mark.parametrize('dtype', [torch.float32])
def test_relu(shape, dtype):
check_module(torch.nn.ReLU(), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize("shape", [(10, 20)])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_hardsigmoid(shape, dtype):
check_module(torch.nn.Hardsigmoid(), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize("shape", [(10, 20)])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_sigmoid(shape, dtype):
check_module(torch.nn.Sigmoid(), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize("shape", [(10, 20)])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_hardswish(shape, dtype):
check_module(torch.nn.Hardswish(), [torch.randn(shape, dtype=dtype)])


if __name__ == '__main__':
pytest.main([__file__])
4 changes: 4 additions & 0 deletions tests/frontends/torch/test_torch_conv2d.py
@@ -23,3 +23,7 @@ def test_conv2d(in_shape, w_shape, stride, padding, dtype):
),
args=[torch.randn(in_shape, dtype=dtype)],
)


if __name__ == '__main__':
pytest.main([__file__])
56 changes: 56 additions & 0 deletions tests/frontends/torch/test_torch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from hidet.testing.torch_utils import check_module


@pytest.mark.parametrize('shape', [[2, 2], [2, 6]])
@pytest.mark.parametrize('num_features', [2, 2])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_instance_norm_1d(shape, num_features, dtype):
check_module(torch.nn.InstanceNorm1d(num_features=num_features), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize('shape', [[2, 2, 2], [2, 6, 4]])
@pytest.mark.parametrize('num_features', [2, 2])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_instance_norm_2d(shape, num_features, dtype):
check_module(torch.nn.InstanceNorm2d(num_features=num_features), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize('shape', [[2, 2, 2, 2], [2, 6, 4, 6]])
@pytest.mark.parametrize('num_features', [2, 2])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_instance_norm_3d(shape, num_features, dtype):
check_module(torch.nn.InstanceNorm3d(num_features=num_features), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize('shape', [[2, 2]])
@pytest.mark.parametrize('normalized_shape', [2])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_layer_norm(shape, normalized_shape, dtype):
check_module(torch.nn.LayerNorm(normalized_shape=normalized_shape), [torch.randn(shape, dtype=dtype)])


@pytest.mark.parametrize('shape', [[2, 2], [2, 2]])
@pytest.mark.parametrize('num_groups', [2, 2])
@pytest.mark.parametrize('num_channels', [2, 2])
@pytest.mark.parametrize('dtype', [torch.float32])
def test_group_norm(shape, num_groups, num_channels, dtype):
check_module(
torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels), [torch.randn(shape, dtype=dtype)]
)


if __name__ == '__main__':
pytest.main([__file__])