Skip to content

Commit

Permalink
[Model Support] Add support for wav2vec (#303)
Browse files Browse the repository at this point in the history
Major:
1. Enhance the hidet dynamo backend to support the pytorch model that
used
[weight_norm](https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html).
2. Add resolve rule for conv1d.

The example code that use hidet could be found at
https://github.com/egorsmkv/wav2vec2-hidet.

Currently, hidet is still slower than pytorch (hidet 30ms vs. pytorch
25ms on RTX 4090). Will optimize these models when we have more hands.
  • Loading branch information
yaoyaoding authored Jul 5, 2023
1 parent a15f5c0 commit 6ebb6ec
Show file tree
Hide file tree
Showing 14 changed files with 213 additions and 19 deletions.
9 changes: 6 additions & 3 deletions gallery/how-to-guides/add-subgraph-rewrite-rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:

# %%
# We can check that the rewrite rule has been registered:
from hidet.graph.transforms import registered_rewrite_rules
from hidet.graph.transforms import (
registered_rewrite_rules,
clear_registered_rewrite_rules,
)

print('Registered rewrite rules:')
for rule in registered_rewrite_rules:
for rule in registered_rewrite_rules():
assert isinstance(rule, SubgraphRewriteRule)
print(rule.name)

Expand All @@ -146,7 +149,7 @@ def target(self, matched: MatchDict) -> Optional[List[Tensor]]:
# Besides the predefined rewrite rules, we can see that the rewrite rule we just registered is also included at the
# last line. In this tutorial, to prevent the default rewrite rules from being applied, we first clear the registered
# rewrite rules and then register the rewrite rule we just defined:
registered_rewrite_rules.clear()
clear_registered_rewrite_rules()
register_rewrite_rule(
FuseTwoMatmulRewriteRule()
) # a second way to register the rewrite rule
Expand Down
22 changes: 22 additions & 0 deletions python/hidet/graph/frontend/torch/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,33 @@ def __init__(self, torch_module: torch.nn.Module):
def __call__(self, *args, **kwargs):
raise NotImplementedError()

def _get_weight_norm_hook(self, name: str):
from torch.nn.utils.weight_norm import WeightNorm

for hook in self.mod._forward_pre_hooks.values(): # pylint: disable=protected-access
if isinstance(hook, WeightNorm) and hook.name == name:
return hook
return None

def _used_weight_norm(self, name: str) -> bool:
return self._get_weight_norm_hook(name) is not None

def _compute_weight_norm(self, name: str) -> Tensor:
hook = self._get_weight_norm_hook(name)
return hook.compute_weight(self.mod)

def param(self, name: str, optional=False) -> Optional[Tensor]:
if name not in self.torch_params:
# see https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
# to learn more about weight norm.
if self._used_weight_norm(name):
self.torch_params[name] = self._compute_weight_norm(name)
return self.param(name, optional)

if optional:
return None
raise RuntimeError(f"hidet: {self.mod} has no parameter/buffer {name}")

if name not in self.hidet_params:
if self.torch_params[name] is None:
self.hidet_params[name] = None
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-builtin
from .conv1d import conv1d
from .matmul import batch_matmul, matmul, matmul_x86
from .conv1d import conv1d, conv1d_gemm
from .conv1d_transpose import conv1d_transpose
from .conv2d import (
conv2d,
Expand All @@ -24,7 +25,6 @@
from .conv2d_transpose import conv2d_transpose, conv2d_transpose_gemm
from .conv3d import conv3d, conv3d_gemm
from .conv3d_transpose import conv3d_transpose
from .matmul import batch_matmul, matmul, matmul_x86
from .pool import avg_pool2d, avg_pool3d, adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d
from .pool import max_pool2d, max_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d
from .activation import relu, leaky_relu, sigmoid, hardsigmoid, clip, relu6, prelu, gelu, silu, hardswish
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/conv1d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# limitations under the License.
from .conv1d import conv1d
from .conv1d import Conv1dOp
from .conv1d_gemm import conv1d_gemm

from . import resolve
93 changes: 93 additions & 0 deletions python/hidet/graph/ops/conv1d/conv1d_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from hidet.graph.ops.utils import Operator, input_like
from hidet.graph.ops.utils import normalize_kernel, normalize_stride
from hidet.graph.tensor import Tensor
from hidet.ir.compute import TensorNode
from hidet.ir.compute import compute
from hidet.ir.expr import is_constant
from hidet.ir.task import Task
from .utils import infer_conv1d_shape


class Conv1dGemmImageTransformTask(Task):
def __init__(self, x: TensorNode, kernel: int, stride: int, dilation: int, groups: int):
n, c, h = x.shape
kx = kernel
sx = stride
dilx = dilation
p = (h - dilx * (kx - 1) - 1) // sx + 1
self._assert(
c % groups == 0,
msg='Conv1d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups),
)
gc = c // groups # group channels
gemm_x = compute(
name='gemm_x',
shape=[groups, n * p, gc * kx],
fcompute=lambda g, i, k: x[i // p, g * gc + k // kx, i % p * sx + k % kx * dilx],
)
super().__init__(name='conv1d_gemm_image_transform', inputs=[x], outputs=[gemm_x])


class Conv1dGemmImageTransformOp(Operator):
def __init__(self, x: Tensor, kernel, stride, dilations, groups):
(kernel,) = normalize_kernel(kernel, dim=1)
(stride,) = normalize_stride(stride, dim=1)
super().__init__(
inputs=[x],
attributes={'kernel': kernel, 'stride': stride, 'groups': groups, 'dilations': dilations},
task=Conv1dGemmImageTransformTask(input_like(x, 'x'), kernel, stride, dilations, groups),
)


def conv1d_gemm_image_transform(x: Tensor, kernel: int, stride: int, dilation: int, groups: int = 1) -> Tensor:
return Conv1dGemmImageTransformOp(x, kernel, stride, dilation, groups).outputs[0]


def conv1d_gemm_filter_transform(w: Tensor, groups: int = 1) -> Tensor:
# weight shape: [oc, c, kx]
# output shape: [groups, c * kx, ogc] where ogc = oc // groups
oc, c, kx = w.shape
# TODO: current assertion mechanism does not cover this use case (only on the task-level)
if is_constant(oc, groups) and oc % groups != 0:
raise ValueError('invalid conv1d groups {} for out channels {}'.format(groups, oc))
ogc = oc // groups
w = w.reshape([groups, ogc, c, kx]) # [groups, ogc, c, kx]
w = w.rearrange([[0], [2, 3], [1]]) # [groups, c * kx, ogc]
return w


def conv1d_gemm_inverse_transform(gemm_y: Tensor, out_height) -> Tensor:
# gemm_y shape: [groups, n * p, ogc]
# output shape: [n, oc, p] where oc = groups * ogc
p = out_height
groups, npq, ogc = gemm_y.shape
# TODO: current assertion mechanism does not cover this use case (only on the task-level)
if is_constant(npq, p) and npq % p != 0:
raise ValueError('invalid conv1d output shape {} for dimension {}'.format(npq, p))
n = npq // p
y = gemm_y.reshape([groups, n, p, ogc])
y = y.rearrange([[1], [0, 3], [2]])
return y


def conv1d_gemm(data: Tensor, weight: Tensor, stride, dilation: int = 1, groups: int = 1) -> Tensor:
from hidet import ops

gemm_x = conv1d_gemm_image_transform(data, kernel=weight.shape[2], stride=stride, dilation=dilation, groups=groups)
gemm_w = conv1d_gemm_filter_transform(weight, groups=groups)
gemm_y = ops.matmul(gemm_x, gemm_w, require_prologue=True)

y_shape = infer_conv1d_shape(data.shape, weight.shape, stride, groups, dilation)
y = conv1d_gemm_inverse_transform(gemm_y, out_height=y_shape[2])
return y
28 changes: 28 additions & 0 deletions python/hidet/graph/ops/conv1d/resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import List, Optional
from hidet.graph.operator import Operator, Tensor
from hidet.graph.transforms import ResolveRule, register_resolve_rule
from hidet.graph import ops
from hidet.ir.expr import is_constant

from .conv1d import Conv1dOp


@register_resolve_rule(Conv1dOp)
class Conv1dResolveRule(ResolveRule):
def __init__(self, enable_winograd=False):
self.enable_winograd = enable_winograd

def resolve(self, op: Operator) -> Optional[List[Tensor]]:
assert isinstance(op, Conv1dOp)
(stride,) = ops.utils.normalize_stride(op.attrs['stride'], dim=1)
groups = op.attrs['groups']
(dilations,) = op.attrs['dilations']
channels = op.inputs[1].shape[0]

if is_constant(channels) and groups == channels:
return None # use depthwise schedule in the default Task

data, weight = op.inputs
# implicit gemm algorithm
out = ops.conv1d_gemm(data, weight, stride, dilations, groups)
return [out]
31 changes: 31 additions & 0 deletions python/hidet/graph/ops/conv1d/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Sequence
from hidet.ir.expr import is_constant
from ..utils import normalize_stride


def infer_conv1d_shape(
x_shape: Sequence[int], w_shape: Sequence[int], stride: int, groups: int, dilation: int
) -> List[int]:
n, c, d = x_shape
oc, gc, kd = w_shape
(sx,) = normalize_stride(stride, dim=1)
dilx = dilation
if is_constant(c) and gc * groups != c:
msg = 'Conv2d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)
raise ValueError(msg)
if oc % groups != 0:
msg = 'Conv2d expects out_channels % groups == 0, got out_channels {} and groups {}'.format(oc, groups)
raise ValueError(msg)
p = (d - dilx * (kd - 1) - 1) // sx + 1
return [n, oc, p]
2 changes: 1 addition & 1 deletion python/hidet/graph/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .resolve_variant import ResolveRule, register_resolve_rule, get_resolve_chain
from .graph_patterns import TensorPattern, OperatorPattern, SubgraphRewriteRule, register_rewrite_rule, op_pattern
from .graph_patterns import registered_rewrite_rules
from .graph_patterns import registered_rewrite_rules, clear_registered_rewrite_rules


def optimize(graph: FlowGraph) -> FlowGraph:
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/graph/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __enter__(self) -> PassContext:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
from ..transforms.graph_patterns import deregister_attn_patterns
from ..transforms.graph_patterns.attn_patterns import deregister_attn_patterns

deregister_attn_patterns()
popped = self._stack.pop()
Expand Down Expand Up @@ -166,7 +166,7 @@ def set_use_attention(self, flag=False) -> PassContext:
if cc < (7, 5):
return self

from ..transforms.graph_patterns import register_attn_patterns, deregister_attn_patterns
from ..transforms.graph_patterns.attn_patterns import register_attn_patterns, deregister_attn_patterns

self.configs['use_attention'] = flag
if flag:
Expand Down
6 changes: 1 addition & 5 deletions python/hidet/graph/transforms/graph_patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@
# limitations under the License.
from .base import TensorPattern, OperatorPattern, SubgraphRewriteRule, MatchDict, Usage, graph_pattern_match
from .base import register_rewrite_rule, op_pattern, registered_rewrite_rules, deregister_rewrite_rule
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
from .conv2d_patterns import conv2d_patterns
from .matmul_patterns import matmul_patterns
from .base import clear_registered_rewrite_rules
20 changes: 16 additions & 4 deletions python/hidet/graph/transforms/graph_patterns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,19 @@ def graph_pattern_match(pattern: TensorPattern, target: Tensor, usage: Usage) ->
return None


registered_rewrite_rules: List[SubgraphRewriteRule] = []
_registered_rewrite_rules: List[SubgraphRewriteRule] = []


def registered_rewrite_rules():
# pylint: disable=unused-import

from . import register_all_patterns # register on demand

return list(_registered_rewrite_rules)


def clear_registered_rewrite_rules():
_registered_rewrite_rules.clear()


def register_rewrite_rule(rule: Union[SubgraphRewriteRule, Type[SubgraphRewriteRule]]):
Expand All @@ -300,10 +312,10 @@ def register_rewrite_rule(rule: Union[SubgraphRewriteRule, Type[SubgraphRewriteR
should be an instance of SubgraphRewriteRule.
"""
if isinstance(rule, SubgraphRewriteRule):
registered_rewrite_rules.append(rule)
_registered_rewrite_rules.append(rule)
return None
elif issubclass(rule, SubgraphRewriteRule):
registered_rewrite_rules.append(rule())
_registered_rewrite_rules.append(rule())
return rule
else:
raise TypeError('rule should be a SubgraphRewriteRule or a subclass of SubgraphRewriteRule')
Expand All @@ -319,7 +331,7 @@ def deregister_rewrite_rule(rule: SubgraphRewriteRule):
The rule to be deregistered.
"""
if isinstance(rule, SubgraphRewriteRule):
registered_rewrite_rules.remove(rule)
_registered_rewrite_rules.remove(rule)
return None
else:
raise TypeError('rule should be a SubgraphRewriteRule')
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# pylint: disable=unused-import
from .arithmetic_patterns import arithmetic_patterns
from .transform_patterns import transform_patterns
from .attn_patterns import attn_patterns, register_attn_patterns, deregister_attn_patterns
from .conv2d_patterns import conv2d_patterns
from .matmul_patterns import matmul_patterns
2 changes: 1 addition & 1 deletion python/hidet/graph/transforms/subgraph_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SubgraphRewritePass(GraphPass):
def process_graph(self, graph: FlowGraph) -> FlowGraph:
graph = graph_utils.functors.clone(graph)
for _ in range(self.max_num_transforms):
updated, graph = self.try_transform(graph, registered_rewrite_rules)
updated, graph = self.try_transform(graph, registered_rewrite_rules())
if not updated:
graph.update_nodes()
return graph
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate(model, text, num_hidden_layers, num_heads, head_dim, device, tokens

@pytest.mark.parametrize('device,opt', [('cpu', False), ('cpu', True), ('cuda', False), ('cuda', True)])
def test_gpt2(device: str, opt: bool):
gpt2_module = hidet.testing.models.gpt2.model()
gpt2_module = hidet.testing.models.gpt2.model(disable_cache=True)

if device == 'cuda':
gpt2_module.cuda()
Expand Down

0 comments on commit 6ebb6ec

Please sign in to comment.