[Operator] Add dynamic shape support and tests for Operators. #274

Merged
3 commits merged on Jun 7, 2023
@@ -26,7 +26,7 @@ def __init__(
groups: Optional[int],
output_padding: Optional[int],
):
num_channels, out_channels, length_in = data.const_shape
num_channels, out_channels, length_in = data.shape
out_channels, wc, kernel_size = weight.const_shape
s = normalize_stride(stride, dim=1)[0]
p = normalize_padding(padding, dim=1)[0]
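
The hunk above (and every file below) follows the same pattern: dimensions are read from .shape, which may contain symbolic expressions, instead of .const_shape, and sanity checks that need concrete integers are gated behind is_constant so symbolic dimensions are simply skipped at trace time. A minimal sketch of the gating idea, using a hypothetical helper that is not part of the PR:

    from hidet.ir.expr import is_constant

    def check_channels(in_channels, out_channels, groups):
        # When in_channels is symbolic its value is only known at run time,
        # so the divisibility check cannot be evaluated while tracing.
        if is_constant(in_channels) and (in_channels % groups != 0 or out_channels % groups != 0):
            raise ValueError(
                'expect in_channels % groups == 0 and out_channels % groups == 0, '
                'got {}, {}, groups={}'.format(in_channels, out_channels, groups)
            )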
6 changes: 4 additions & 2 deletions python/hidet/graph/ops/conv2d/conv2d.py
@@ -12,22 +12,24 @@
from typing import List, Union, Sequence
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_stride, normalize_dilations, reduce
from hidet.ir.expr import is_constant


class Conv2dTask(Task):
def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
# pylint: disable=too-many-locals
# we assume that only data needs to have dynamic shape
n, c, h, w = data.shape
oc, wc, kx, ky = weight.shape
sx, sy = stride
dilx, dily = dilations
p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1
if c % groups != 0 or oc % groups != 0:
if is_constant(c) and (c % groups != 0 or oc % groups != 0):
raise ValueError(
'Conv2d expect the in_channels % groups == 0 and out_channels % groups == 0, \n'
'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
)
if wc * groups != c:
if is_constant(c) and wc * groups != c:
raise ValueError(
'Conv2d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n'
'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups)
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/conv2d/conv2d_gemm.py
@@ -25,7 +25,7 @@ def __init__(self, x: TensorNode, kernel: List[int], stride: List[int], dilation
sx, sy = stride
dilx, dily = dilations
p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1
if c % groups != 0:
if is_constant(c) and c % groups != 0:
msg = 'Conv2d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups)
raise ValueError(msg)
gc = c // groups # group channels
3 changes: 2 additions & 1 deletion python/hidet/graph/ops/conv2d/resolve.py
@@ -13,6 +13,7 @@
from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
from hidet.graph import ops
from hidet.ir.expr import is_constant

from .conv2d import Conv2dOp

@@ -28,7 +29,7 @@ def resolve(self, op: Operator) -> Optional[List[Tensor]]:
groups = op.attrs['groups']
dilations = op.attrs['dilations']
channels = op.inputs[1].shape[0]
if groups == channels:
if is_constant(channels) and groups == channels:
return None # use depthwise schedule in the default Task
data, weight = op.inputs
kernel_size = weight.shape[2:]
3 changes: 2 additions & 1 deletion python/hidet/graph/ops/conv2d/utils.py
@@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Sequence
from hidet.ir.expr import is_constant
from ..utils import normalize_stride


@@ -24,7 +25,7 @@ def infer_conv2d_shape(
oc, gc, kx, ky = w_shape
sx, sy = normalize_stride(strides)
dilx, dily = dilations
if gc * groups != c:
if is_constant(c) and gc * groups != c:
msg = 'Conv2d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
raise ValueError(msg)
if oc % groups != 0:
5 changes: 3 additions & 2 deletions python/hidet/graph/ops/conv3d/conv3d.py
@@ -12,6 +12,7 @@
from typing import List, Union, Sequence
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import compute, input_like, normalize_stride, normalize_dilations, reduce
from hidet.ir.expr import is_constant


class Conv3dTask(Task):
@@ -26,12 +27,12 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila
(h - dilx * (kx - 1) - 1) // sx + 1,
(w - dily * (ky - 1) - 1) // sy + 1,
)
if c % groups != 0 or oc % groups != 0:
if is_constant(c) and (c % groups != 0 or oc % groups != 0):
raise ValueError(
'Conv3d expect the in_channels % groups == 0 and out_channels % groups == 0, \n'
'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)
)
if wc * groups != c:
if is_constant(c) and wc * groups != c:
raise ValueError(
'Conv3d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n'
'got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups)
6 changes: 4 additions & 2 deletions python/hidet/graph/ops/conv3d/conv3d_gemm.py
@@ -14,6 +14,7 @@
from hidet.graph.ops.matmul import matmul
from hidet.graph.ops.utils import Task, Operator, Tensor, compute, input_like, TensorNode
from hidet.graph.ops.utils import normalize_kernel, normalize_stride
from hidet.ir.expr import is_constant
from .utils import infer_conv3d_shape


@@ -28,7 +29,7 @@ def __init__(self, x: TensorNode, kernel: List[int], stride: List[int], dilation
(h - dilx * (kx - 1) - 1) // sx + 1,
(w - dily * (ky - 1) - 1) // sy + 1,
)
if c % groups != 0:
if is_constant(c) and c % groups != 0:
msg = 'Conv3d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups)
raise ValueError(msg)
gc = c // groups # group channels
@@ -80,7 +81,8 @@ def conv3d_gemm_inverse_transform(gemm_y: Tensor, out_depth, out_height, out_wid
# output shape: [n, oc, r, p, q] where oc = groups * ogc
r, p, q = out_depth, out_height, out_width
groups, nrpq, ogc = gemm_y.shape
assert nrpq % (r * p * q) == 0
if is_constant(nrpq, r, p, q):
assert nrpq % (r * p * q) == 0
n = nrpq // (r * p * q)
y = gemm_y.reshape([groups, n, r, p, q, ogc])
y = y.rearrange([[1], [0, 5], [2], [3], [4]])
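
The guard above exists because, once the batch dimension is symbolic, nrpq // (r * p * q) stays a symbolic expression during tracing and the divisibility cannot be decided until run time. A small illustration of the assumption behind it (hypothetical sizes; assuming hidet.symbol accepts string names for symbolic dimensions, as the testing helpers in this PR do):

    import hidet
    from hidet.ir.expr import is_constant

    x = hidet.symbol(['n', 8, 4, 4, 4], dtype='float32')  # 'n' is a symbolic batch dim
    nrpq = x.shape[0] * 4 * 4 * 4                          # still a symbolic expression
    assert not is_constant(nrpq)                           # so the gated assert above is skipped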
3 changes: 2 additions & 1 deletion python/hidet/graph/ops/conv3d/utils.py
@@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from hidet.ir.expr import is_constant
from ..utils import normalize_stride


@@ -20,7 +21,7 @@ def infer_conv3d_shape(
oc, gc, kz, kx, ky = w_shape
sz, sx, sy = normalize_stride(strides, dim=3)
dilz, dilx, dily = dilations
if gc * groups != c:
if is_constant(c) and gc * groups != c:
msg = 'Conv3d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
raise ValueError(msg)
if oc % groups != 0:
6 changes: 3 additions & 3 deletions python/hidet/graph/ops/conv3d_transpose/conv3d_transpose.py
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Union, Tuple
from hidet.ir.expr import if_then_else, logical_and
from hidet.ir.expr import if_then_else, logical_and, is_constant
from hidet.ir.compute import compute, reduce
from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode
from hidet.graph.ops.utils import input_like, normalize_stride, normalize_padding
@@ -26,7 +26,7 @@ def __init__(
groups: int,
output_padding: Tuple[int, int, int],
):
n, oc, r, p, q = data.const_shape
n, oc, r, p, q = data.shape
oc, wc, kz, kx, ky = weight.const_shape
c = wc * groups
sz, sx, sy = stride
@@ -40,7 +40,7 @@
'Conv3dTranspose expect the output_padding < stride, \n'
'but got output_padding, stride: {}, {}'.format(output_padding, stride)
)
if any(p < 0 for p in padding):
if is_constant(p) and any(p < 0 for p in padding):
raise ValueError('Negative padding is not supported.')

og = oc // groups # output channels in each group
12 changes: 12 additions & 0 deletions python/hidet/graph/ops/matmul/batch_matmul.py
@@ -704,8 +704,20 @@ def batch_matmul_kernel(
return ir_module


def is_true(expr) -> bool:
if is_constant(expr):
return bool(expr) is True
return False


def is_false(expr) -> bool:
    # Mirror of is_true: report False only for a compile-time constant that
    # folds to False; symbolic expressions are never provably false.
    if is_constant(expr):
        return bool(expr) is False
    return False
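
These helpers are deliberately conservative: is_true answers True only for a compile-time constant that folds to True, and is_false (completed above by symmetry) only for a constant that folds to False, so symbolic comparisons fall through to the run-time path. The commented-out check in BatchMatmulOp.__init__ below hints at the intended use; a hedged sketch of that guard, not enabled in the PR:

    # Reject only shape mismatches that are provable at trace time; symbolic
    # batch or inner dimensions are accepted and validated when the kernel runs.
    if is_false(a.shape[0] == b.shape[0]) or is_false(a.shape[2] == b.shape[1]):
        raise ValueError('batch_matmul expects A [b, m, k] and B [b, k, n], got {} and {}'.format(a.shape, b.shape))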


class BatchMatmulOp(Operator):
def __init__(self, a: Tensor, b: Tensor, mma: str = 'simt'):
# if is_false(a.shape[0] == b.shape[0]) or is_false(a.shape[2] == b.shape[1]):
# raise
if not (
len(a.shape) == len(b.shape) == 3
and (not is_constant(a.shape[0], b.shape[0]) or a.shape[0] == b.shape[0])
24 changes: 14 additions & 10 deletions python/hidet/graph/ops/matmul/matmul_f16.py
@@ -12,7 +12,7 @@
from typing import List, Tuple
from hidet.ir import dtypes
from hidet.ir.dtypes import float16
from hidet.ir.expr import if_then_else
from hidet.ir.expr import if_then_else, Int, Expr
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.compute import TensorNode
@@ -33,7 +33,8 @@ def __init__(self, a: TensorNode, b: TensorNode, parallel_k_parts: int = 1):
if len(a.shape) < 2 or len(b.shape) < 2:
raise ValueError('Matrix multiplication expect at least 2D tensor, got {} and {}'.format(a.shape, b.shape))

if a.shape[-1] != b.shape[-2]:
# TODO: add dynamic shape assertions
if not (isinstance(a.shape[-1], Expr) or isinstance(b.shape[-2], Expr)) and a.shape[-1] != b.shape[-2]:
raise ValueError(
'Matrix multiplication expect tensor A and B with shape [..., M, K] and [..., K, N]'
', got {} and {}'.format(a.shape, b.shape)
@@ -44,9 +45,9 @@ def __init__(self, a: TensorNode, b: TensorNode, parallel_k_parts: int = 1):
'Matrix multiplication expect tensor A and B with compatible broadcast shape, '
'got {} and {}'.format(a.shape, b.shape)
)
a_shape = a.const_shape
b_shape = b.const_shape
k_size = int(a.shape[-1])
a_shape = a.shape
b_shape = b.shape
k_size = a.shape[-1]
c_shape = [parallel_k_parts] + broadcast_shape(a.shape[:-2], b.shape[:-2]) + [a_shape[-2], b_shape[-1]]
k_part_extent = cdiv(k_size, parallel_k_parts)

@@ -104,9 +105,9 @@ def schedule(

# input shapes
node_a, node_b, node_c = self.inputs[0], self.inputs[1], self.outputs[0]
a_shape: Tuple[int, ...] = node_a.const_shape
b_shape: Tuple[int, ...] = node_b.const_shape
c_shape: Tuple[int, ...] = node_c.const_shape
a_shape: Tuple[Int, ...] = node_a.shape
b_shape: Tuple[Int, ...] = node_b.shape
c_shape: Tuple[Int, ...] = node_c.shape
m_size, n_size, k_size = a_shape[-2], b_shape[-1], a_shape[-1]
a_head, b_head, c_head = list(a_shape[:-2]), list(b_shape[:-2]), list(c_shape[:-2])
k_parts = self.attrs['parallel_k_parts']
@@ -121,7 +122,7 @@ def schedule(
warp_count_m, warp_count_n, warp_count_k = block_m // warp_m, block_n // warp_n, block_k // warp_k
mma_count_m, mma_count_n, mma_count_k = warp_m // mma_m, warp_n // mma_n, warp_k // mma_k
threads = warp_count_m * warp_count_n * warp_count_k * 32
grid_dim: Tuple[int, int, int] = cdiv(m_size, block_m), cdiv(n_size, block_n), prod(c_head)
grid_dim: Tuple[Int, Int, Int] = cdiv(m_size, block_m), cdiv(n_size, block_n), prod(c_head)
dynamic_smem_bytes = max(2 * (block_m + block_n) * block_k * 2, block_m * block_n * 2)

tune.check(block_m % warp_m == block_n % warp_n == block_k % warp_k == 0, 'warp dims divide block dims')
@@ -322,7 +323,10 @@ def __init__(self, a: Tensor, b: Tensor, parallel_k_parts=1):
def matmul_f16(a: Tensor, b: Tensor, parallel_k_parts=1) -> Tensor:
if len(a.shape) < 2 or len(b.shape) < 2:
raise ValueError('a and b must have at least 2 dimensions, got shape {} and {}'.format(a.shape, b.shape))
if a.shape[-1] % 8 != 0 or b.shape[-1] % 8 != 0:
# TODO: implement dynamic run-time shape assertion
if not (isinstance(a.shape[-1], Expr) or isinstance(b.shape[-1], Expr)) and (
a.shape[-1] % 8 != 0 or b.shape[-1] % 8 != 0
):
raise ValueError('Expect the last dimension of the input tensors to be a multiple of 8')
if a.dtype != dtypes.float16 or b.dtype != dtypes.float16:
raise ValueError('BatchMatmulF16Op only support float16, got {} and {}'.format(a.dtype, b.dtype))
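
With the constant-only checks relaxed, matmul_f16 can be traced with symbolic dimensions, while the remaining validation (dtype, rank, and the multiple-of-8 rule when sizes are constant) still applies. A hedged usage sketch, assuming hidet.symbol and hidet.trace_from as used by the testing helpers later in this PR; names and sizes are illustrative:

    import hidet
    from hidet.graph.ops.matmul.matmul_f16 import matmul_f16

    a = hidet.symbol(['b', 'm', 64], dtype='float16', device='cuda')
    b = hidet.symbol(['b', 64, 'n'], dtype='float16', device='cuda')
    c = matmul_f16(a, b)                 # traced with symbolic b, m, n
    graph = hidet.trace_from(c, [a, b])  # concrete sizes are bound when the graph is called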
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/softmax.py
@@ -100,7 +100,7 @@ def implement_cuda(self, working_dir: str) -> IRModule:

from hidet.lang.cuda import blockIdx, threadIdx

shape = self.inputs[0].const_shape
shape = self.inputs[0].shape
axis = self.axis
reduce_extent = shape[axis]
reduced_shape = shape[:axis] + shape[axis + 1 :]
5 changes: 3 additions & 2 deletions python/hidet/testing/__init__.py
@@ -13,5 +13,6 @@

from . import models
from . import utils
from .utils import check_unary, check_unary_dynamic, check_binary
from .utils import check_ternary, check_torch_unary, check_torch_binary, check_torch_ternary
from .utils import check_unary, check_unary_dynamic, check_binary, check_binary_dynamic
from .utils import check_ternary, check_torch_unary
from .utils import check_torch_binary, check_torch_binary_dynamic, check_torch_ternary
67 changes: 67 additions & 0 deletions python/hidet/testing/utils.py
@@ -77,6 +77,38 @@ def check_binary(
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


def check_binary_dynamic(
a_shape: Sequence[Union[int, Tuple[str, int]]],
b_shape: Sequence[Union[int, Tuple[str, int]]],
numpy_op,
hidet_op,
device: str = 'all',
dtype: Union[str, np.dtype] = np.float32,
atol=0.0,
rtol=0.0,
):
if device == 'all':
for dev in ['cuda', 'cpu']:
check_binary_dynamic(a_shape, b_shape, numpy_op, hidet_op, dev, dtype, atol, rtol)
return
a_concrete_shape = [(i if isinstance(i, int) else i[1]) for i in a_shape]
a_symbolic_shape = [(i if isinstance(i, int) else i[0]) for i in a_shape]

b_concrete_shape = [(i if isinstance(i, int) else i[1]) for i in b_shape]
b_symbolic_shape = [(i if isinstance(i, int) else i[0]) for i in b_shape]
a = np.array(np.random.randn(*a_concrete_shape)).astype(dtype)
b = np.array(np.random.randn(*b_concrete_shape)).astype(dtype)
numpy_result = numpy_op(a, b)
a_hidet = asarray(a).to(device=device)
b_hidet = asarray(b).to(device=device)
sym_a = symbol(a_symbolic_shape, dtype=a_hidet.dtype, device=a_hidet.device)
sym_b = symbol(b_symbolic_shape, dtype=b_hidet.dtype, device=b_hidet.device)
sym_result = hidet_op(sym_a, sym_b)
func = trace_from(sym_result, [sym_a, sym_b])
hidet_result = func(a_hidet, b_hidet).cpu().numpy()
np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol)


def check_ternary(
a_shape, b_shape, c_shape, numpy_op, hidet_op, dtype: Union[str, np.dtype] = np.float32, atol=0.0, rtol=0.0
):
@@ -142,6 +174,41 @@ def check_torch_binary(
)


def check_torch_binary_dynamic(
a_shape: Sequence[Union[int, Tuple[str, int]]],
b_shape: Sequence[Union[int, Tuple[str, int]]],
torch_func,
hidet_func,
device: str = 'all',
dtype: Union[str, np.dtype] = np.float32,
atol=0.0,
rtol=0.0,
):
if device == 'all':
for dev in ['cuda', 'cpu']:
check_torch_binary_dynamic(a_shape, b_shape, torch_func, hidet_func, dev, dtype, atol, rtol)
return
import torch

a_concrete_shape = [(i if isinstance(i, int) else i[1]) for i in a_shape]
a_symbolic_shape = [(i if isinstance(i, int) else i[0]) for i in a_shape]

b_concrete_shape = [(i if isinstance(i, int) else i[1]) for i in b_shape]
b_symbolic_shape = [(i if isinstance(i, int) else i[0]) for i in b_shape]

a = torch.randn(*a_concrete_shape, dtype=getattr(torch, dtype)).to(device=device)
b = torch.randn(*b_concrete_shape, dtype=getattr(torch, dtype)).to(device=device)
torch_result = torch_func(a, b)
a_hidet = asarray(a.cpu(), dtype=dtype).to(device=device)
b_hidet = asarray(b.cpu(), dtype=dtype).to(device=device)
sym_a = symbol(a_symbolic_shape, dtype=a_hidet.dtype, device=a_hidet.device)
sym_b = symbol(b_symbolic_shape, dtype=b_hidet.dtype, device=b_hidet.device)
sym_result = hidet_func(sym_a, sym_b)
func = trace_from(sym_result, [sym_a, sym_b])
hidet_result = func(a_hidet, b_hidet).cpu().numpy()
np.testing.assert_allclose(actual=hidet_result, desired=torch_result.cpu().numpy(), atol=atol, rtol=rtol)


def check_torch_ternary(
a_shape: Sequence[int],
b_shape: Sequence[int],
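
The new helpers mirror check_binary and check_torch_binary, but each dimension may be given either as a plain int (kept static) or as a (name, size) pair: the name becomes a symbolic dimension when the graph is traced, and the size is the concrete value used when the compiled graph is executed. A hypothetical call, following the same convention as the conv1d test below:

    from hidet import ops
    from hidet.testing import check_binary_dynamic

    # 'n' is symbolic while tracing and bound to 4 at execution time;
    # the inner dimensions stay static.
    check_binary_dynamic(
        a_shape=[('n', 4), 16],
        b_shape=[16, 32],
        numpy_op=lambda a, b: a @ b,
        hidet_op=ops.matmul,
        device='cuda',
        dtype='float32',
        atol=1e-4,
        rtol=1e-4,
    )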
22 changes: 21 additions & 1 deletion tests/operators/test_conv1d.py
@@ -16,7 +16,7 @@
import pytest

from hidet import ops
from hidet.testing import check_binary
from hidet.testing import check_binary, check_binary_dynamic


def torch_conv1d(
@@ -49,5 +49,25 @@ def test_conv1d(hidet_op, n, c, l, oc, k, padding, stride, dilations, groups):
)


@pytest.mark.parametrize("hidet_op", [ops.conv1d])
@pytest.mark.parametrize("n, c, l, oc, k", [[1, 3, 32, 12, 3]])
@pytest.mark.parametrize("padding", [[0], [1]])
@pytest.mark.parametrize("stride", [1])
@pytest.mark.parametrize("dilations", [1])
@pytest.mark.parametrize("groups", [1])
def test_conv1d_dynamic(hidet_op, n, c, l, oc, k, padding, stride, dilations, groups):
check_binary_dynamic(
a_shape=[('n', n), ('c', c), ('l', l)],
b_shape=[oc, c, k],
numpy_op=lambda data, weight: torch_conv1d(data, weight, padding, stride, dilations, groups),
hidet_op=lambda data, weight: hidet_op(
ops.conv_pad(data, padding), weight=weight, stride=stride, dilations=dilations, groups=groups
),
dtype='float32',
atol=2e-5,
rtol=2e-5,
)


if __name__ == '__main__':
pytest.main([__file__])