[Operators] Adding support for torch.nn.TransformerEncoder (#327)
Closes #219
BolinSNLHM authored and vadiklyutiy committed Jul 23, 2024
1 parent 3638a0b commit d625146
Showing 3 changed files with 374 additions and 53 deletions.
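For context, a minimal usage sketch (not part of the commit; model sizes, shapes, and dtypes are illustrative) of how the newly supported torch.nn.TransformerEncoder is exercised through hidet's torch.compile backend, which this frontend implements:

import torch
import hidet  # noqa: F401  (importing hidet registers the 'hidet' torch.compile backend)

layer = torch.nn.TransformerEncoderLayer(
    d_model=512, nhead=8, batch_first=True, device='cuda', dtype=torch.float16
)
encoder = torch.nn.TransformerEncoder(layer, num_layers=6).eval()

src = torch.randn(32, 77, 512, device='cuda', dtype=torch.float16)
mask = torch.full((77, 77), float('-inf'), device='cuda', dtype=torch.float16).triu(1)

compiled = torch.compile(encoder, backend='hidet')
with torch.no_grad():
    out = compiled(src, mask)  # dispatched to HidetTransformerEncoder registered below
print(out.shape)  # torch.Size([32, 77, 512])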
338 changes: 286 additions & 52 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -10,19 +10,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Tuple, Optional

import torch
from hidet.graph import ops
from hidet.graph.tensor import Tensor
from .interpreter import HidetModule, register_module
from . import register_functions as regs
from .interpreter import HidetModule, register_module, warnings
from . import register_functions as reg_funcs, register_methods as reg_methods
from .dynamo_config import dynamo_config


@register_module(torch.nn.Conv1d)
class HidetConv1d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Conv1d)
return regs.conv1d(
return reg_funcs.conv1d(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -37,7 +40,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetConvTranspose1d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose1d)
return regs.conv1d_transpose(
return reg_funcs.conv1d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -53,7 +56,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetConv2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Conv2d)
return regs.conv2d(
return reg_funcs.conv2d(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -68,7 +71,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetConvTranspose2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose2d)
return regs.conv2d_transpose(
return reg_funcs.conv2d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -84,7 +87,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetConv3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Conv3d)
return regs.conv3d(
return reg_funcs.conv3d(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -99,7 +102,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetConvTranspose3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ConvTranspose3d)
return regs.conv3d_transpose(
return reg_funcs.conv3d_transpose(
x=x,
weight=self.param('weight'),
bias=self.param('bias', optional=True),
@@ -115,28 +118,28 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetAdaptiveAvgPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AdaptiveAvgPool2d)
return regs.adaptive_avg_pool2d(x, self.mod.output_size)
return reg_funcs.adaptive_avg_pool2d(x, self.mod.output_size)


@register_module(torch.nn.AdaptiveAvgPool3d)
class HidetAdaptiveAvgPool3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AdaptiveAvgPool3d)
return regs.adaptive_avg_pool3d(x, self.mod.output_size)
return reg_funcs.adaptive_avg_pool3d(x, self.mod.output_size)


@register_module(torch.nn.ReLU)
class HidetReLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ReLU)
return regs.relu(x, self.mod.inplace)
return reg_funcs.relu(x, self.mod.inplace)


@register_module(torch.nn.MaxPool2d)
class HidetMaxPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.MaxPool2d)
return regs.max_pool2d(
return reg_funcs.max_pool2d(
x=x,
kernel_size=self.mod.kernel_size,
stride=self.mod.stride,
@@ -151,7 +154,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetMaxPool3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.MaxPool3d)
return regs.max_pool3d(
return reg_funcs.max_pool3d(
x=x,
kernel_size=self.mod.kernel_size,
stride=self.mod.stride,
@@ -166,7 +169,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetZeroPad2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ZeroPad2d)
return regs.torch_pad(x=x, pad=self.mod.padding, mode='constant', value=0.0)
return reg_funcs.torch_pad(x=x, pad=self.mod.padding, mode='constant', value=0.0)


@register_module(torch.nn.Linear)
@@ -181,7 +184,7 @@ def __init__(self, torch_module: torch.nn.Module):

def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Linear)
return regs.linear(
return reg_funcs.linear(
x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True
)

@@ -191,7 +194,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetBatchNorm2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d))
return regs.batch_norm(
return reg_funcs.batch_norm(
x=x,
running_mean=self.param('running_mean'),
running_var=self.param('running_var'),
@@ -210,14 +213,14 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetDropout2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, (torch.nn.Dropout, torch.nn.Dropout1d, torch.nn.Dropout2d, torch.nn.Dropout3d))
return regs.dropout(x, self.mod.p, self.mod.training, self.mod.inplace)
return reg_funcs.dropout(x, self.mod.p, self.mod.training, self.mod.inplace)


@register_module(torch.nn.LayerNorm)
class HidetLayerNorm(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.LayerNorm)
return regs.layer_norm(
return reg_funcs.layer_norm(
x=x,
normalized_shape=self.mod.normalized_shape,
weight=self.param('weight'),
@@ -230,7 +233,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetGroupNorm(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.GroupNorm)
return regs.group_norm(
return reg_funcs.group_norm(
x=x,
num_groups=self.mod.num_groups,
num_channels=self.mod.num_channels,
@@ -244,21 +247,21 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetTanh(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Tanh)
return regs.tanh(x)
return reg_funcs.tanh(x)


@register_module(torch.nn.Hardtanh)
class HidetHardtanh(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardtanh)
return regs.hardtanh(x, self.mod.min_val, self.mod.max_val)
return reg_funcs.hardtanh(x, self.mod.min_val, self.mod.max_val)


@register_module(torch.nn.Embedding)
class HidetEmbedding(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Embedding)
return regs.embedding(
return reg_funcs.embedding(
x=x,
weight=self.param('weight'),
padding_idx=self.mod.padding_idx,
@@ -273,28 +276,28 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetReLU6(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.ReLU6)
return regs.relu6(x, self.mod.inplace)
return reg_funcs.relu6(x, self.mod.inplace)


@register_module(torch.nn.Sigmoid)
class HidetSigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Sigmoid)
return regs.sigmoid(x)
return reg_funcs.sigmoid(x)


@register_module(torch.nn.Hardsigmoid)
class HidetHardsigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardsigmoid)
return regs.hardsigmoid(x, self.mod.inplace)
return reg_funcs.hardsigmoid(x, self.mod.inplace)


@register_module(torch.nn.AvgPool2d)
class HidetAvgPool2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AvgPool2d)
return regs.avg_pool2d(
return reg_funcs.avg_pool2d(
x=x,
kernel_size=self.mod.kernel_size,
stride=self.mod.stride,
@@ -309,98 +312,98 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetFlatten(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Flatten)
return regs.flatten(x, self.mod.start_dim, self.mod.end_dim)
return reg_funcs.flatten(x, self.mod.start_dim, self.mod.end_dim)


@register_module(torch.nn.Hardswish)
class HidetHardswish(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardswish)
return regs.hardswish(x, self.mod.inplace)
return reg_funcs.hardswish(x, self.mod.inplace)


@register_module(torch.nn.GELU)
class HidetGELU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.GELU)
return regs.gelu(x, self.mod.approximate)
return reg_funcs.gelu(x, self.mod.approximate)


@register_module(torch.nn.SiLU)
class HidetSiLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.SiLU)
return regs.silu(x, self.mod.inplace)
return reg_funcs.silu(x, self.mod.inplace)


@register_module(torch.nn.Softmax)
class HidetSoftmax(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softmax)
return regs.softmax(x, self.mod.dim)
return reg_funcs.softmax(x, self.mod.dim)


@register_module(torch.nn.Softmin)
class HidetSoftmin(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softmin)
return regs.softmin(x, self.mod.dim)
return reg_funcs.softmin(x, self.mod.dim)


@register_module(torch.nn.Softplus)
class HidetSoftplus(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softplus)
return regs.softplus(x, self.mod.beta, self.mod.threshold)
return reg_funcs.softplus(x, self.mod.beta, self.mod.threshold)


@register_module(torch.nn.Softsign)
class HidetSoftsign(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softsign)
return regs.softsign(x)
return reg_funcs.softsign(x)


@register_module(torch.nn.Softshrink)
class HidetSoftshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Softshrink)
return regs.softshrink(x, self.mod.lambd)
return reg_funcs.softshrink(x, self.mod.lambd)


@register_module(torch.nn.Tanhshrink)
class HidetTanhshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Tanhshrink)
return regs.tanhshrink(x)
return reg_funcs.tanhshrink(x)


@register_module(torch.nn.Hardshrink)
class HidetHardshrink(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Hardshrink)
return regs.hardshrink(x, self.mod.lambd)
return reg_funcs.hardshrink(x, self.mod.lambd)


@register_module(torch.nn.CELU)
class HidetCELU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.CELU)
return regs.celu(x, self.mod.alpha)
return reg_funcs.celu(x, self.mod.alpha)


@register_module(torch.nn.LogSigmoid)
class HidetLogSigmoid(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.LogSigmoid)
return regs.logsigmoid(x)
return reg_funcs.logsigmoid(x)


@register_module(torch.nn.Mish)
class HidetMish(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Mish)
return regs.mish(x, self.mod.inplace)
return reg_funcs.mish(x, self.mod.inplace)


@register_module(torch.nn.Identity)
@@ -414,7 +417,7 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetUpsample(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Upsample)
return regs.interpolate(
return reg_funcs.interpolate(
x,
size=self.mod.size,
scale_factor=self.mod.scale_factor,
@@ -435,6 +438,9 @@ def __init__(self, torch_module: torch.nn.Module):
self.torch_params['out_proj.weight'] = None
self.hidet_params['in_proj_weight'] = None
self.hidet_params['out_proj.weight'] = None

self.num_heads = self.mod.num_heads
self.head_dim = self.mod.head_dim
torch.cuda.empty_cache()

def __call__(
@@ -447,25 +453,37 @@ def __call__(
attn_mask=None,
average_attn_weights=True,
is_causal=False,
) -> Tensor:
) -> Tuple[Tensor, Optional[Tensor]]:
assert isinstance(self.mod, torch.nn.MultiheadAttention)
# pylint: disable=protected-access
supported = (
self.mod._qkv_same_embed_dim
and self.mod.bias_k is None
and self.mod.bias_v is None
and not self.mod.add_zero_attn
and self.mod.batch_first
and key_padding_mask is None
and not need_weights
)
# pylint: enable=protected-access
if not supported:
raise NotImplementedError(
"Hidet Multihead Attention currently only supports "
"kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False, "
"batch_first=True, forward(key_padding_mask=None, need_weights=False)."
f"""
HidetMultiheadAttention got: kdim={self.mod.kdim}, vdim={self.mod.vdim}, embed_dim={self.mod.embed_dim},
bias_k={self.mod.bias_k}, bias_v={self.mod.bias_v}, add_zero_attn={self.mod.add_zero_attn},
batch_first={self.mod.batch_first}, key_padding_mask={key_padding_mask},
need_weights={need_weights}, average_attn_weights={average_attn_weights}, is_causal={is_causal}.
HidetMultiheadAttention currently only supports kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False,
batch_first=True, forward(key_padding_mask=None, need_weights=False).
"""
)

if need_weights:
warnings.warn_once(
"""HidetMultiheadAttention: had need_weights=True, but
currently need_weights will be treated as False, as it forces a much slower computation of SDPA,
and can likely be turned off in most production scenarios."""
)

# Input feed forward
wq, wk, wv = ops.split(self.in_proj_weight_transposed, parts_or_sections=3, axis=1)
query = ops.matmul(query, wq)
key = ops.matmul(key, wk)
@@ -476,14 +494,74 @@ def __call__(
key = ops.add(key, bk)
value = ops.add(value, bv)

# Split heads
split_head_dims = [query.shape[0], query.shape[1], self.mod.num_heads, query.shape[2] // self.mod.num_heads]
assert (
self.mod.bias_k is None and self.mod.bias_v is None
), "HidetMultiheadAttention currently does not support bias_k and bias_v."

if not self.mod.batch_first:
return self._forward_not_batch_first(query, key, value, attn_mask, is_causal)

else:
return self._forward_batch_first(query, key, value, attn_mask, is_causal)

def _forward_not_batch_first(self, query: Tensor, key: Tensor, value: Tensor, attn_mask=None, is_causal=False):
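# Inputs here are laid out as (seq_len, batch, embed_dim); the views and transposes below
# rearrange them into (batch, num_heads, seq_len, head_dim) so the same fused SDPA path can be reused.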
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
# Preparing attention mask
if attn_mask is not None:
# ensure attn_mask is 3D
if len(attn_mask.shape) == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but it should be {correct_2d_size}."
)
attn_mask = attn_mask.unsqueeze(0)
elif len(attn_mask.shape) == 3:
correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.size}, but it should be {correct_3d_size}."
)
else:
raise RuntimeError(f"attn_mask's dimensionality is {len(attn_mask.shape)}, but it should be 2 or 3.")

if attn_mask.shape[0] == 1 and len(attn_mask.shape) == 3:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = reg_methods.tensor_view(
attn_mask, bsz, self.num_heads, attn_mask.size / (bsz * self.num_heads * src_len), src_len
)

q = reg_methods.tensor_view(query, tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = reg_methods.tensor_view(key, src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = reg_methods.tensor_view(value, value.shape[0], bsz * self.num_heads, self.head_dim).transpose(0, 1)

# the new source seq length
src_len = k.shape[1]

q = reg_methods.tensor_view(q, bsz, self.num_heads, tgt_len, self.head_dim)
k = reg_methods.tensor_view(k, bsz, self.num_heads, src_len, self.head_dim)
v = reg_methods.tensor_view(v, bsz, self.num_heads, src_len, self.head_dim)

attn_output = reg_funcs.scaled_dot_product_attention(q, k, v, attn_mask, is_causal)
attn_output = reg_funcs.permute(attn_output, 2, 0, 1, 3)
attn_output = reg_methods.tensor_view(attn_output, bsz * tgt_len, embed_dim)

attn_output = ops.matmul(attn_output, self.out_proj_weight_transposed)
if self.mod.out_proj.bias is not None:
attn_output = ops.add(attn_output, self.param('out_proj.bias'))
attn_output = reg_methods.tensor_view(attn_output, tgt_len, bsz, attn_output.shape[1])
return attn_output, None

def _forward_batch_first(self, query: Tensor, key: Tensor, value: Tensor, attn_mask=None, is_causal=False):
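# Batch-first layout: split (batch, seq_len, embed_dim) into (batch, num_heads, seq_len, head_dim)
# heads before the fused attention call.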
split_head_dims = [query.shape[0], query.shape[1], self.num_heads, query.shape[2] // self.num_heads]
query = ops.transpose(query.reshape(split_head_dims), [0, 2, 1, 3])
key = ops.transpose(key.reshape(split_head_dims), [0, 2, 1, 3])
value = ops.transpose(value.reshape(split_head_dims), [0, 2, 1, 3])

# fmha
out = regs.scaled_dot_product_attention(
out = reg_funcs.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=self.mod.dropout, is_causal=is_causal
)

@@ -493,4 +571,160 @@ def __call__(
out = ops.matmul(out, self.out_proj_weight_transposed)
if self.mod.out_proj.bias is not None:
out = ops.add(out, self.param('out_proj.bias'))
return out

return out, None


@register_module(torch.nn.TransformerEncoderLayer)
class HidetTransformerEncoderLayer(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)

self.self_attn = HidetMultiheadAttention(self.mod.self_attn)

self.linear1 = HidetLinear(self.mod.linear1)
self.dropout = HidetDropout2d(self.mod.dropout)

self.linear2 = HidetLinear(self.mod.linear2)

self.norm_first = self.mod.norm_first
self.norm1 = HidetLayerNorm(self.mod.norm1)
self.norm2 = HidetLayerNorm(self.mod.norm2)

self.dropout1 = HidetDropout2d(self.mod.dropout1)
self.dropout2 = HidetDropout2d(self.mod.dropout2)

from hidet.graph.frontend.torch.interpreter import Registry

mod_activation = self.mod.activation
if mod_activation.__class__ in Registry.registered_modules:
self.activation = Registry.registered_modules[mod_activation.__class__](mod_activation)
elif mod_activation in Registry.registered_functions:
self.activation = Registry.registered_functions[mod_activation]
else:
import torchmultimodal

# torchmultimodal.modules.layers.activation.SiLU is encountered
# while compiling the model torch_multimodal_clip from TorchBench
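# That SiLU variant computes x * sigmoid(1.702 * x) (a sigmoid approximation of GELU),
# which the lambda below replicates elementwise.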
if isinstance(mod_activation, torchmultimodal.modules.layers.activation.SiLU):
self.activation = lambda x: reg_funcs.sigmoid(1.702 * x) * x
else:
raise NotImplementedError(
f"HidetTransformerEncoder: activation function {mod_activation} is not supported."
)

def supported(self):
# pylint: disable=protected-access
return (
self.mod.self_attn._qkv_same_embed_dim
and self.mod.self_attn.bias_k is None
and self.mod.self_attn.bias_v is None
and not self.mod.self_attn.add_zero_attn
)
# pylint: enable=protected-access

def print_info(self):
# pylint: disable=protected-access
info_str = f"""
self_attn._qkv_same_embed_dim = {self.mod.self_attn._qkv_same_embed_dim},\n
self_attn.bias_k = {self.mod.self_attn.bias_k},\n
self_attn.bias_v = {self.mod.self_attn.bias_v},\n
self_attn.add_zero_attn = {self.mod.self_attn.add_zero_attn},\n
self_attn.batch_first = {self.mod.self_attn.batch_first},\n
"""
# pylint: enable=protected-access
return info_str

def __call__(self, src: Tensor, src_mask=None, src_key_padding_mask=None, is_causal: bool = False) -> Tensor:
assert isinstance(self.mod, torch.nn.TransformerEncoderLayer)

if src_key_padding_mask is not None:
raise NotImplementedError(
f"""HidetTransformerEncoderLayer currently only supports src_key_padding_mask=None,
but got src_key_padding_mask={src_key_padding_mask}."""
)

x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), attn_mask=src_mask, is_causal=is_causal)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, attn_mask=src_mask, is_causal=is_causal))
x = self.norm2(x + self._ff_block(x))
return x

def _sa_block(self, x: Tensor, attn_mask: Tensor, is_causal: bool) -> Tensor:
x = self.self_attn(x, x, x, attn_mask=attn_mask, need_weights=False, is_causal=is_causal)[0]
return self.dropout1(x)

def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)


@register_module(torch.nn.TransformerEncoder)
class HidetTransformerEncoder(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)
self.layers = [HidetTransformerEncoderLayer(layer) for layer in self.mod.layers]
self.num_layers = self.mod.num_layers
assert self.num_layers == len(self.layers)
self.norm = HidetLayerNorm(self.mod.norm) if self.mod.norm is not None else None
self.mask_check = self.mod.mask_check

def __call__(self, src: Tensor, mask=None, src_key_padding_mask=None, is_causal=None) -> Tensor:
self_first_layer = self.layers[0]
if not isinstance(self_first_layer, HidetTransformerEncoderLayer):
raise NotImplementedError(
f"""Hidet Transformer Encoder currently only HidetTransformerEncoderLayer,
but got {self_first_layer.__class__}."""
)

if not self_first_layer.supported():
raise NotImplementedError(
f"""Hidet Transformer Encoder currently only supports self_attn with
kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False,
batch_first=True, forward(src_key_padding_mask=None).
\n But we got:
\n {self_first_layer.print_info()}.
"""
)

if mask is not None and mask.device != src.device:
mask = ops.transfer(mask, src.device)

if not (src_key_padding_mask is None and all(layer.supported() for layer in self.layers)):
raise NotImplementedError(
f"""Hidet Transformer Encoder currently only supports self_attn with
kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False,
batch_first=True, forward(src_key_padding_mask=None),
but we got src_key_padding_mask={src_key_padding_mask} and is_causal={is_causal},
"""
)

output = src

batch_first = self.layers[0].mod.self_attn.batch_first
src_size = len(src.shape)
if src_size == 2:
seq_len = src.shape[0]
else:
seq_len = src.shape[1 if batch_first else 0]
is_causal = self._detect_is_causal_mask(mask, is_causal, seq_len)

for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, is_causal=is_causal)
if self.norm is not None:
output = self.norm(output)
return output

def _detect_is_causal_mask(self, mask: Optional[Tensor], is_causal, sz: Optional[int] = None):
make_causal = is_causal is True
if is_causal is None and mask is not None:
sz = mask.shape[-2] if sz is None else sz
causal_mask = ops.triu(ops.full((sz, sz), float('-inf'), dtype=mask.dtype, device=mask.device), diagonal=1)
if mask.shape == causal_mask.shape:
make_causal = bool(ops.all(ops.equal(mask, causal_mask)))
else:
make_causal = False
return make_causal
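As a side note, a standalone PyTorch sketch (function name hypothetical, not part of the commit) of the check that _detect_is_causal_mask performs above: a mask is only treated as causal when it exactly equals the canonical upper-triangular -inf mask for its sequence length.

import torch

def looks_causal(mask: torch.Tensor, seq_len: int) -> bool:
    # Canonical causal mask: -inf strictly above the diagonal, 0 on and below it.
    causal = torch.triu(torch.full((seq_len, seq_len), float('-inf'), dtype=mask.dtype), diagonal=1)
    return mask.shape == causal.shape and bool(torch.equal(mask, causal))

mask = torch.full((4, 4), float('-inf')).triu(1)
assert looks_causal(mask, 4)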
87 changes: 87 additions & 0 deletions tests/frontends/torch/test_torch_interoperability.py
@@ -124,5 +124,92 @@ def test_torch_var(shape, dim):
)


@pytest.mark.parametrize('embed_dim', [512])
@pytest.mark.parametrize('num_heads', [8])
@pytest.mark.parametrize('batch_first', [False, True])
@pytest.mark.parametrize('batch_size', [32])
@pytest.mark.parametrize('target_len, src_len', [[77, 77]])
@pytest.mark.parametrize('have_mask', [True])
@pytest.mark.parametrize('is_causal', [False])
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32])
def test_torch_multihead_attention(
embed_dim, num_heads, batch_first, batch_size, target_len, src_len, have_mask, is_causal, dtype
):
torch_attention = torch.nn.MultiheadAttention(
embed_dim, num_heads, batch_first=batch_first, device='cuda', dtype=dtype
)
query_shape = [target_len, batch_size, embed_dim] if not batch_first else [batch_size, target_len, embed_dim]

query = torch.randn(query_shape, dtype=dtype, device='cuda')
key = query
value = query

if have_mask:
mask = torch.full((target_len, src_len), float('-inf'), dtype=dtype, device='cuda').triu(1)
else:
mask = None
if not have_mask:
is_causal = False

# only compare the first element of the output tuple (the attention output);
# with need_weights=False the second element is None
check_module(
model=FunctionalModule(op=lambda *args: torch_attention(*args)[0]),
args=[query, key, value, None, False, mask, False, is_causal],
atol=1e-2,
rtol=1e-2,
)


@pytest.mark.parametrize('d_model', [512])
@pytest.mark.parametrize('nhead', [8])
@pytest.mark.parametrize('dim_feedforward', [2048])
@pytest.mark.parametrize('dropout', [0.0])
@pytest.mark.parametrize('activation', [torch.nn.functional.relu])
@pytest.mark.parametrize('batch_first', [False])
@pytest.mark.parametrize('norm_first', [True])
@pytest.mark.parametrize('src_shape', [[77, 32, 512]])
@pytest.mark.parametrize('need_mask', [True])
@pytest.mark.parametrize('mask_shape', [[77, 77]])
@pytest.mark.parametrize('is_causal', [True])
@pytest.mark.parametrize('dtype', [torch.float16, torch.float32])
def test_torch_transformer_encoder(
d_model,
nhead,
dim_feedforward,
dropout,
activation,
batch_first,
norm_first,
src_shape,
need_mask,
mask_shape,
is_causal,
dtype,
):
torch_layer = torch.nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
batch_first=batch_first,
norm_first=norm_first,
device='cuda',
dtype=dtype,
)

src = torch.randn(src_shape, dtype=dtype, device='cuda')
mask = torch.full(mask_shape, float('-inf'), dtype=dtype, device='cuda').triu(1) if need_mask else None

if not need_mask:
is_causal = False

torch_encoder = torch.nn.TransformerEncoder(torch_layer, num_layers=12)

# Use atol=5e-2 since this test is quite flaky:
# with atol=1e-2 it sometimes fails even though far less than 1% of the elements mismatch.
check_module(model=torch_encoder, args=[src, mask, None, is_causal], atol=5e-2, rtol=1e-2)


if __name__ == '__main__':
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tests/operators/test_image.py
@@ -23,7 +23,7 @@
from hidet.graph.tensor import asarray
from hidet.utils.ort_utils import create_ort_session, ort_inference
from hidet.testing import check_torch_unary
from hidet.graph.frontend.torch import register_functions as regs
from hidet.graph.frontend.torch import register_functions as reg_funcs


class TorchResizeModel(torch.nn.Module):
