[Fix] Added missing torch.multiply and torch.nn.functional.unfold ops for conv-bert-base model (#351)

Added support for `torch.multiply` and `torch.nn.functional.unfold`. These ops are needed by the `conv-bert-base` model.
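For context, a minimal sketch (not from this commit; assumes hidet is installed so its 'hidet' backend is available to torch.compile, and that a CUDA device is present) of a fragment that exercises both new ops through the dynamo frontend:

import torch
import torch.nn.functional as F

def f(x, w):
    patches = F.unfold(x, kernel_size=2)  # now lowered to hidet's ops.im2col
    return torch.multiply(patches, w)     # now lowered to hidet's mul

# a [1, 3, 10, 10] input unfolds to [1, 3 * 2 * 2, 9 * 9] = [1, 12, 81]
x = torch.randn(1, 3, 10, 10, device='cuda')
w = torch.randn(1, 12, 81, device='cuda')
compiled = torch.compile(f, backend='hidet')
torch.testing.assert_close(compiled(x, w), f(x, w))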

---------

Co-authored-by: Zhumakhan <nazirzhumakhan@gmail.com>
2 people authored and vadiklyutiy committed Jul 23, 2024
1 parent c87c515 commit 18842ee
Showing 4 changed files with 108 additions and 2 deletions.
python/hidet/graph/frontend/torch/register_functions.py (7 additions, 0 deletions)
@@ -353,6 +353,7 @@ def setitem(x: Tensor, item, setvalue):

@register_function(operator.mul)
@register_function(torch.mul)
@register_function(torch.multiply)
@register_function(torch.ops.aten.mul.Tensor)
def mul(x: Tensor, y: Tensor):
    return x * y
@@ -1559,6 +1560,12 @@ def torch_any_v2(input: Tensor) -> Tensor:
    return ops.any(input)


@register_function(torch.nn.functional.unfold)
def torch_unfold(input: Tensor, kernel_size, dilation=1, padding=0, stride=1) -> Tensor:
    assert 3 <= len(input.shape) <= 4, "torch.nn.functional.unfold accepts 3D or 4D tensor only"
    return ops.im2col(input, kernel_size, dilation, padding, stride)


# The torch functions below might appear in an fx graph at the dynamo level, but dynamo resolves them by itself.
# Hidet should never see them.
@register_function(torch._C._has_torch_function)
python/hidet/graph/ops/__init__.py (1 addition, 1 deletion)
@@ -40,7 +40,7 @@
from .quant import symmetric_quantize, symmetric_dequantize
from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any
from .cumulative import cumsum
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape
from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape, im2col
from .transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather, index_select, triu, tril
from .transform import permute_dims, meshgrid, repeat_interleave
from .fusion import fused_operator
python/hidet/graph/ops/transform.py (79 additions, 1 deletion)
@@ -11,7 +11,7 @@
# limitations under the License.
from typing import List, Optional, Union, Sequence
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant
from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant, logical_or
from hidet.ir.expr import Int
from hidet.ir.layout import RowMajorLayout
from hidet.ir.utils import index_deserialize, index_serialize
@@ -405,6 +405,50 @@ def fmap(*indices):
        super().__init__(name='tril', inputs=[x], attributes={'diagonal': diagonal}, outputs=[out])


class Im2ColTask(Task):
    def __init__(
        self, x: TensorNode, kernel_size: List[Int], dilation: List[Int], padding: List[Int], stride: List[Int]
    ):
        dtype = x.type.dtype
        batch_size = x.shape[0]
        n_input_plane = x.shape[1]
        input_height = x.shape[2]
        input_width = x.shape[3]

        output_height = (input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) // stride[0] + 1
        output_width = (input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) // stride[1] + 1
        n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
        output_length = output_height * output_width

        output_shape = [batch_size, n_output_plane, output_length]
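        # e.g. for x of shape [1, 3, 10, 10] with kernel_size=[2, 2], dilation=[1, 1],
        # padding=[0, 0], stride=[1, 1]: output_height = (10 + 0 - 2) // 1 + 1 = 9,
        # so output_shape = [1, 3 * 2 * 2, 9 * 9] = [1, 12, 81]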

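        # fmap maps each output element (batch, channel-and-kernel offset, flattened
        # output position) back to the input pixel it copies; positions that fall in
        # the padding region read as zero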
        def fmap(*indices):
            b_i, ck_i, o_i = indices
            c_i = ck_i // prod(kernel_size)
            ck_i %= prod(kernel_size)
            k_hi = ck_i // kernel_size[1]
            k_wi = ck_i % kernel_size[1]
            o_hi = o_i // output_width
            o_wi = o_i % output_width

            i_hi = o_hi * stride[0] - padding[0] + k_hi * dilation[0]
            i_wi = o_wi * stride[1] - padding[1] + k_wi * dilation[1]
            return if_then_else(
                logical_or(i_hi < 0, i_wi < 0, i_hi >= input_height, i_wi >= input_width),
                dtype.zero,
                x[b_i, c_i, i_hi, i_wi],
            )

        out = compute(name='out', shape=output_shape, fcompute=fmap)

        super().__init__(
            name='im2col',
            inputs=[x],
            attributes={'kernel_size': kernel_size, 'dilation': dilation, 'padding': padding, 'stride': stride},
            outputs=[out],
        )


class ReshapeOp(Operator):
    def __init__(self, x: Tensor, shape):
        task = ReshapeTask(input_like(x, 'x'), shape)
@@ -590,6 +634,15 @@ def __init__(self, x: Tensor, diagonal: Int = 0):
        super().__init__(inputs=[x], attributes={'diagonal': diagonal}, task=TrilTask(input_like(x, 'x'), diagonal))


class Im2ColOp(Operator):
    def __init__(self, x: Tensor, kernel_size: List[Int], dilation: List[Int], padding: List[Int], stride: List[Int]):
        super().__init__(
            inputs=[x],
            attributes={'kernel_size': kernel_size, 'dilation': dilation, 'padding': padding, 'stride': stride},
            task=Im2ColTask(input_like(x, 'x'), kernel_size, dilation, padding, stride),
        )


def reshape(x: Tensor, shape) -> Tensor:
    if same_shape(x.shape, shape):
        return x
@@ -838,3 +891,28 @@ def meshgrid(*tensors: Tensor, indexing: str = "ij") -> List[Tensor]:
            grid = transpose(grid, (1, 0))
        outputs.append(grid)
    return outputs


def im2col(
    x: Tensor,
    kernel_size: Union[int, List[int]],
    dilation: Union[int, List[int]] = 1,
    padding: Union[int, List[int]] = 0,
    stride: Union[int, List[int]] = 1,
):
    nd = len(x.shape)
    if nd == 3:
        x = x.unsqueeze(0)
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * 2
    if isinstance(dilation, int):
        dilation = [dilation] * 2
    if isinstance(padding, int):
        padding = [padding] * 2
    if isinstance(stride, int):
        stride = [stride] * 2
    x = Im2ColOp(x, kernel_size, dilation, padding, stride).outputs[0]

    if nd == 3:
        return x.squeeze(0)
    return x
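For reference, a minimal usage sketch (not part of the commit; assumes hidet is installed and a CUDA device is available) comparing ops.im2col against torch.nn.functional.unfold:

import torch
import hidet
from hidet.graph import ops

x_torch = torch.randn(2, 3, 10, 10)
x_hidet = hidet.from_torch(x_torch.cuda())

# output_height = output_width = (10 + 2 * 1 - 3) // 2 + 1 = 5, so both
# calls produce a [2, 3 * 3 * 3, 5 * 5] = [2, 27, 25] tensor
ref = torch.nn.functional.unfold(x_torch, kernel_size=3, dilation=1, padding=1, stride=2)
out = ops.im2col(x_hidet, kernel_size=3, dilation=1, padding=1, stride=2)
torch.testing.assert_close(out.torch().cpu(), ref)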
tests/operators/test_transform.py (21 additions, 0 deletions)
@@ -412,5 +412,26 @@ def test_repeat_interleave(input_shape, repeats, dim):
    np.testing.assert_allclose(output_tensor.numpy(), output_tensor_hidet.numpy(), atol=0, rtol=0)


@pytest.mark.parametrize(
    "shape, kernel_size, dilation, padding, stride",
    [
        ([1, 3, 10, 10], 2, 1, 0, 1),
        ([2, 3, 99, 99], 3, 2, 1, 3),
        ([3, 1, 10, 9], 3, 4, 5, 6),
        ([3, 4, 5, 7], 3, 1, 4, 2),
        ([3, 44, 41], 2, 3, 0, 5),
        ([4, 24, 122], 4, 2, 1, 2),
        ([3, 1, 45, 33], 3, 4, 5, 6),
        ([4, 10, 13], 2, 1, 0, 1),
    ],
)
def test_unfold(shape, kernel_size, dilation, padding, stride):
    check_transform_torch(
        shape,
        lambda x: torch.nn.functional.unfold(x, kernel_size, dilation, padding, stride),
        lambda x: ops.im2col(x, kernel_size, dilation, padding, stride),
    )


if __name__ == '__main__':
    pytest.main([__file__])
