[Operator] Further performance enhancements for conv2D #290

Merged: 55 commits, Jun 28, 2023

Commits
803d302
lint again for some reason
Aalanli Jun 9, 2023
a59ba09
lint again for some reason
Aalanli Jun 9, 2023
0be7d38
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 9, 2023
06a9cd0
nevermind
Aalanli Jun 9, 2023
f550d2f
Merge branch 'main' of https://github.com/Aalanli/hidet into main
Aalanli Jun 12, 2023
8ae48da
minor fixes
Jun 14, 2023
e0433b8
post-conv-algorithm
Aalanli Jun 14, 2023
455f519
algorithm benchmark changes
Jun 14, 2023
629a54a
docs
Jun 14, 2023
36f7288
todo convgemmfusedfp16
Aalanli Jun 14, 2023
4cfa8e7
prototype kernel
Jun 15, 2023
952d316
further bug fix
Jun 15, 2023
77a2fa3
fused fp16 conv gemm kernel
Aalanli Jun 16, 2023
a572d72
finished kernel for group=1
Aalanli Jun 17, 2023
ac49774
minor bug fix
Aalanli Jun 17, 2023
fd19d60
Merge branch 'hidet-org:main' into main
Aalanli Jun 17, 2023
28f4d88
Merge branch 'main' into conv2d-fp16
Aalanli Jun 17, 2023
3ace805
finished conv
Aalanli Jun 18, 2023
530148f
remove dead code
Aalanli Jun 18, 2023
30886cf
add tests
Aalanli Jun 18, 2023
a565775
parallel_k test
Aalanli Jun 18, 2023
eb92746
pk part heuristic
Aalanli Jun 18, 2023
8cfe845
update heuristic
Jun 18, 2023
29340df
lint
Jun 18, 2023
5ecbfaa
finished conv2d gemm
Aalanli Jun 19, 2023
3935aa5
performance alteration
Jun 19, 2023
3ba62b5
format
Jun 19, 2023
35b081d
disable cpu tests due to numerical instability
Jun 19, 2023
cfd4211
performance enhancement
Jun 20, 2023
3caf05e
format
Jun 20, 2023
e5acafa
Merge branch 'main' into conv2d-fp16
Jun 20, 2023
b435f45
temporary commit
Jun 21, 2023
b56b2c2
make matmul_f16 work with dim a multiple of 4
Aalanli Jun 21, 2023
a5485b8
remove
Aalanli Jun 21, 2023
d99eac7
new transform op
Aalanli Jun 22, 2023
cfd72bd
tests
Jun 23, 2023
8963719
pre_transform
Jun 23, 2023
9e934bf
support sizes a multiple of 2, 4 for matmul_f16
Jun 23, 2023
8f50363
add pad to conv2d
Jun 23, 2023
4aed78e
move pad into conv2d operator
Jun 23, 2023
76e3154
format lint
Jun 23, 2023
bb1b380
fix test
Jun 23, 2023
168ea22
padding fix
Aalanli Jun 23, 2023
605b9e8
format
Aalanli Jun 23, 2023
05ced08
disable prologue and epilogue
Jun 23, 2023
8cbf333
performance improvement
Jun 23, 2023
9d98c41
Revise to default pretransform
Jun 23, 2023
15f220b
fix subgraph rewrite bug
Aalanli Jun 24, 2023
7b1cb44
format / lint
Aalanli Jun 24, 2023
fa90765
fix flaky test
Aalanli Jun 24, 2023
359de5b
.
Aalanli Jun 24, 2023
c939fe4
apply suggestions
Jun 27, 2023
45b1357
moved to conv2d
Jun 27, 2023
8706782
fix resolve, remove pad_value
Jun 27, 2023
9b88f1b
Merge branch 'main' into conv2d-fp16v2
Aalanli Jun 28, 2023
python/hidet/graph/frontend/onnx/onnx.py: 9 additions & 2 deletions
@@ -204,8 +204,15 @@ def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
         dilations = self.attrs.get('dilations', [1, 1])
         padding = self.attrs.get('pads', [0, 0, 0, 0])
         strides = self.attrs.get('strides', [1, 1])
-        x = ops.pad(x, ops.utils.normalize_padding(padding))
-        output = ops.conv2d(x, w, stride=strides, dilations=dilations, groups=groups)
+        padding = ops.utils.normalize_padding(padding)
+        # currently conv2d only supports symmetric padding, like torch
+        if not (padding[0] == padding[2] and padding[1] == padding[3]):
+            x = ops.pad(x, padding)
+            output = ops.conv2d(x, w, stride=strides, dilations=dilations, groups=groups)
+        else:
+            output = ops.conv2d(
+                x, w, padding=(padding[0], padding[1]), stride=strides, dilations=dilations, groups=groups
+            )
         if bias is not None:
             bias = ops.unsqueeze(bias, [0, 2, 3])
             output = output + bias
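A hedged illustration of the branch above: when the ONNX pads are symmetric, the frontend now forwards a single (pad_h, pad_w) pair straight to ops.conv2d; only asymmetric pads fall back to an explicit ops.pad. The sketch below assumes normalize_padding returns pads in [begin_h, begin_w, end_h, end_w] order, which is an assumption about that helper rather than something shown in this diff.

# Sketch only; assumes normalized pads are [begin_h, begin_w, end_h, end_w].
symmetric = [1, 1, 1, 1]      # pads attribute of an ONNX Conv node
asymmetric = [0, 0, 1, 1]

def is_symmetric(p):
    # same check as the branch above: begin and end pads must match per axis
    return p[0] == p[2] and p[1] == p[3]

print(is_symmetric(symmetric))   # True  -> padding=(1, 1) is passed to ops.conv2d
print(is_symmetric(asymmetric))  # False -> falls back to ops.pad + unpadded conv2d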
python/hidet/graph/frontend/torch/register_functions.py: 1 addition & 2 deletions
@@ -50,8 +50,7 @@ def conv1d_transpose(

 @register_function(torch.nn.functional.conv2d)
 def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups):
-    x = ops.conv_pad(x, padding)
-    y = ops.conv2d(x, weight, stride, dilation, groups)
+    y = ops.conv2d(x, weight, stride, dilation, groups, padding=padding)
     if bias is not None:
         y = y + ops.unsqueeze(bias, [0, 2, 3])
     return y
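With this mapping, the padding argument of torch.nn.functional.conv2d is forwarded directly to ops.conv2d instead of going through a separate ops.conv_pad. A minimal way to exercise the path end to end is sketched below; it is not part of the PR and assumes a CUDA device and that importing hidet registers it as a torch.compile backend named 'hidet'.

# Minimal sketch, not from this PR: runs the registered conv2d mapping via torch.compile.
import torch
import torch.nn.functional as F
import hidet  # assumption: importing hidet registers the 'hidet' dynamo backend

def f(x, w):
    return F.conv2d(x, w, stride=1, padding=1)

x = torch.randn(1, 3, 32, 32, device='cuda', dtype=torch.float16)
w = torch.randn(8, 3, 3, 3, device='cuda', dtype=torch.float16)

f_hidet = torch.compile(f, backend='hidet')
torch.testing.assert_close(f_hidet(x, w), f(x, w), rtol=1e-2, atol=1e-2)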
python/hidet/graph/ops/conv2d/conv2d.py: 41 additions & 14 deletions
@@ -15,15 +15,26 @@
 from hidet.graph.ops.utils import compute, input_like, normalize_stride, normalize_dilations, reduce


-# pylint: disable=too-many-locals
 class Conv2dTask(Task):
-    def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dilations: List[int], groups: int):
-        n, c, h, w = data.shape
+    # pylint: disable=too-many-locals
+    def __init__(
+        self,
+        data: TensorNode,
+        weight: TensorNode,
+        padding: List[int],
+        stride: List[int],
+        dilations: List[int],
+        groups: int,
+    ):
+        from hidet.ir.compute.cops import pad
+
+        # we assume that only data needs to have dynamic shape
+        n, c, _, _ = data.shape
         oc, wc, kx, ky = weight.shape
         sx, sy = stride
         dilx, dily = dilations
-        p, q = (h - dilx * (kx - 1) - 1) // sx + 1, (w - dily * (ky - 1) - 1) // sy + 1
+        pad_h, pad_w = padding

         self._assert(
             ir.logical_or(c % groups == 0, oc % groups == 0),
             msg=(
@@ -39,15 +50,22 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila
             ),
         )
         out_group_size = oc // groups
+
+        pads = [0, 0, pad_h, pad_w, 0, 0, pad_h, pad_w]
+        padded = pad(data, pads, value=0.0)  # only zero padding is needed right now
+
+        _, _, ph, pw = padded.shape
+        p, q = (ph - dilx * (kx - 1) - 1) // sx + 1, (pw - dily * (ky - 1) - 1) // sy + 1
+
         output = compute(
             name='out',
             shape=[n, oc, p, q],
             fcompute=lambda ni, oci, pi, qi: reduce(
                 shape=[wc, kx, ky],
-                fcompute=lambda wci, kxi, kyi: (
-                    data[ni, (oci // out_group_size) * wc + wci, pi * sx + kxi * dilx, qi * sy + kyi * dily]
-                    * weight[oci, wci, kxi, kyi]
-                ),
+                fcompute=lambda wci, kxi, kyi: padded[
+                    ni, (oci // out_group_size) * wc + wci, pi * sx + kxi * dilx, qi * sy + kyi * dily
+                ]
+                * weight[oci, wci, kxi, kyi],
                 reduce_type='sum',
             ),
         )
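The p and q computed above are the output spatial extents of a dilated convolution over the zero-padded input. A quick worked check of the formula, with illustrative numbers that are not taken from the PR:

# Worked example of the p, q formula above (illustrative values only).
h = w = 224
kx = ky = 3
pad_h = pad_w = 1
sx = sy = 1
dilx = dily = 1

ph, pw = h + 2 * pad_h, w + 2 * pad_w     # padded extents: 226, 226
p = (ph - dilx * (kx - 1) - 1) // sx + 1  # (226 - 2 - 1) // 1 + 1 = 224
q = (pw - dily * (ky - 1) - 1) // sy + 1  # 224
print(p, q)                               # 224 224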
@@ -100,13 +118,21 @@ def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], dila


 class Conv2dOp(Operator):
-    def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union[int, Sequence[int]], groups: int):
+    def __init__(
+        self,
+        x: Tensor,
+        w: Tensor,
+        padding: Sequence[int],
+        stride: Sequence[int],
+        dilations: Union[int, Sequence[int]],
+        groups: int,
+    ):
         stride = normalize_stride(stride)
         dilations = normalize_dilations(dilations)
         super().__init__(
             inputs=[x, w],
-            attributes={'stride': stride, 'groups': groups, 'dilations': dilations},
-            task=Conv2dTask(input_like(x, 'x'), input_like(w, 'w'), stride, dilations, groups),
+            attributes={'padding': padding, 'stride': stride, 'groups': groups, 'dilations': dilations},
+            task=Conv2dTask(input_like(x, 'x'), input_like(w, 'w'), padding, stride, dilations, groups),
         )

@@ -124,11 +150,12 @@ def __init__(self, x: Tensor, w: Tensor, stride: Sequence[int], dilations: Union
 def conv2d(
     data: Tensor,
     weight: Tensor,
-    stride: Union[int, Sequence[int]] = (1, 1),
-    dilations: Union[int, Sequence[int]] = (1, 1),
+    stride: Sequence[int] = (1, 1),
+    dilations: Sequence[int] = (1, 1),
     groups: int = 1,
+    padding: Sequence[int] = (0, 0),
 ) -> Tensor:
-    return Conv2dOp(data, weight, stride, dilations, groups).get_output(0)
+    return Conv2dOp(data, weight, padding, stride, dilations, groups).get_output(0)


 def conv2d_channel_last(
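With this change the public ops.conv2d accepts symmetric padding directly, so callers no longer need a separate ops.pad for the common case. A small usage sketch follows; the hidet.randn helper, device, and tensor shapes are illustrative assumptions rather than part of the diff.

# Usage sketch for the updated signature; assumes a CUDA build of hidet
# and that hidet.randn / hidet.ops are available.
import hidet

x = hidet.randn([1, 3, 224, 224], dtype='float16', device='cuda')
w = hidet.randn([64, 3, 3, 3], dtype='float16', device='cuda')

# padding=(1, 1) pads both spatial dims symmetrically inside the operator
y = hidet.ops.conv2d(x, w, stride=(1, 1), dilations=(1, 1), groups=1, padding=(1, 1))
print(y.shape)  # expected [1, 64, 224, 224]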