[Torch][Graph][Operator] Add and fix various items for torchvision model support (#347)

1. Enhance support for `__setitem__` and `__getitem__` of Tensor; add the SetStridedSlice and Roll ops.
2. Add/update torch mappings for adaptive_avg_pool3d, eq, pad, roll, matmul, new_zeros, batch_norm, and MultiheadAttention.
3. Update the torch Linear mapping to optionally accept transposed weights.
4. Fix a bug where an empty graph would output a zero tensor instead of the input/weight.
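
Taken together, the new mappings are exercised by small graphs like the sketch below (illustrative only, not part of this commit; it assumes hidet is installed and exposes its usual 'hidet' backend for torch.compile):

import torch
import torch.nn.functional as F

def f(x: torch.Tensor) -> torch.Tensor:
    y = x.new_zeros(x.shape)                # Tensor.new_zeros mapping
    y[:, 0] = 1.0                           # Tensor.__setitem__ -> SetStridedSlice
    y = torch.roll(y, shifts=1, dims=1)     # torch.roll -> ops.roll
    y = F.pad(y, [1, 1])                    # F.pad -> ops.pad
    return y

x = torch.randn(2, 4)
compiled = torch.compile(f, backend='hidet')
# compiled(x) is expected to match f(x) up to numerical tolerance.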
hjjq authored Aug 12, 2023
1 parent 8d755f7 commit 740ff3c
Showing 10 changed files with 401 additions and 69 deletions.
8 changes: 8 additions & 0 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -359,6 +359,10 @@ def load_arg(a, env):
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if exec_func.functions[0] is setitem:
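# setitem returns a new hidet tensor instead of mutating in place, so rebind the
# name of the tensor argument (node.args[0]) to the result so later uses see the update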
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
@@ -448,6 +452,10 @@ def load_arg(a, env):

try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if hidet_func.functions[0] is setitem:
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)

107 changes: 102 additions & 5 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -18,10 +18,9 @@
from hidet.graph import ops
from hidet.utils import same_list
from hidet.ir.type import DataType
from hidet.ir.expr import Expr
from hidet.ir import expr
from hidet.ir.dtypes import promote_type
from hidet.ir.expr import Int
from hidet.ir.expr import Expr, Int, is_constant
from hidet.runtime.device import Device
from .interpreter import register_function, register_method
from .interpreter import warnings
@@ -97,6 +96,11 @@ def adaptive_avg_pool2d(x: Tensor, output_size):
return ops.adaptive_avg_pool2d(x, output_size)


@register_function(torch.nn.functional.adaptive_avg_pool3d)
def adaptive_avg_pool3d(x: Tensor, output_size):
return ops.adaptive_avg_pool3d(x, output_size)


@register_function(torch.nn.functional.relu)
def relu(x: Tensor, inplace: bool):
# if inplace:
@@ -130,7 +134,9 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=


@register_function(torch.nn.functional.linear)
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]):
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False):
if len(weight.shape) > 1 and not weight_is_transposed:
weight = ops.transpose(weight, [1, 0])
y = ops.matmul(x, weight)
if bias is not None:
y = y + bias
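
As context for the new weight_is_transposed flag: torch.nn.functional.linear takes weight of shape [out_features, in_features] and computes x @ weight.T + bias, which is why the mapping transposes unless told the weight is already transposed. A small illustrative check (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(5, 16)
w = torch.randn(32, 16)          # [out_features, in_features]
b = torch.randn(32)
ref = F.linear(x, w, b)
manual = x @ w.t() + b           # explicit transpose, matching the mapping's default path
assert torch.allclose(ref, manual, atol=1e-5)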
@@ -205,10 +211,18 @@ def batch_norm(
)
y = ops.batch_norm_infer(x, running_mean, running_var, epsilon=eps)
_ = momentum # unused
if len(x.shape) == 3:
dims = [0, 2]
elif len(x.shape) == 4:
dims = [0, 2, 3]
elif len(x.shape) == 5:
dims = [0, 2, 3, 4]
else:
raise NotImplementedError("batch_norm only accepts 3D, 4D, 5D input")
if weight is not None:
y = y * weight.unsqueeze([0, 2, 3])
y = y * weight.unsqueeze(dims)
if bias is not None:
y = y + bias.unsqueeze([0, 2, 3])
y = y + bias.unsqueeze(dims)
return y
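
The dims computed above just build a broadcastable per-channel shape for each supported input rank; an illustrative plain-PyTorch check for the 5D case (not hidet code):

import torch

x = torch.randn(2, 3, 4, 5, 6)   # [N, C, D, H, W]
weight = torch.randn(3)          # per-channel scale, shape [C]
w = weight
for d in [0, 2, 3, 4]:           # dims chosen above for a 5D input
    w = w.unsqueeze(d)           # ends up with shape [1, 3, 1, 1, 1]
assert tuple(w.shape) == (1, 3, 1, 1, 1)
y = x * w                        # broadcasts over N, D, H, W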


@@ -222,6 +236,70 @@ def getitem(x: Tensor, index):
return x[index]


@register_function(operator.setitem)
def setitem(x: Tensor, item, setvalue):

if isinstance(item, list):
item = tuple(item)
if not isinstance(item, tuple):
item = tuple([item])

if not isinstance(setvalue, (int, float)):
raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values')

# now, the item could have
# 1. integer index
# 2. slice
# 3. Ellipsis
# 4. None
# e.g., [1, 3:5, ..., None]

# process Ellipsis
# e.g., x[1, ..., 2] -> x[1, :, :, 2]
if Ellipsis in item:
if item.count(Ellipsis) > 1:
raise ValueError('Only one ellipsis allowed in index.')
ellipsis_index = item.index(Ellipsis)
ellipsis_ndim = len(x.shape) - sum([1 if axis not in [None, Ellipsis] else 0 for axis in item])
ellipsis_ndim = max(ellipsis_ndim, 0)
item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :]

# normalize index
normalized_item = []
for i, v in enumerate(item):
if isinstance(v, int):
if v < 0:
v = v + x.shape[i]
if is_constant(v, x.shape[i]) and (v < 0 or v >= x.shape[i]):
raise IndexError('index {} is out of bounds for dimension {} with size {}'.format(v, i, x.shape[i]))
normalized_item.append(v)
elif v is not None:
# None affects getitem, but is ignored in setitem
normalized_item.append(v)
item = tuple(normalized_item)

# process slice and integer index
rank = len(x.shape)
while len(item) < rank:
item = item + (slice(None),)
starts, ends, steps = [], [], []
squeeze_dims = []
for dim, v in enumerate(item):
if isinstance(v, (int, Expr)):
squeeze_dims.append(dim)
starts.append(v)
ends.append(v + 1)
steps.append(1)
else:
assert isinstance(v, slice)
starts.append(v.start)
ends.append(v.stop)
steps.append(v.step)

out = ops.set_strided_slice(x, starts, ends, steps, setvalue)
return out
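
As a concrete reference for the lowering above: on a rank-4 tensor, x[1, ..., 2] = 0.0 expands the Ellipsis to full slices and behaves like x[1, :, :, 2] = 0.0, i.e. starts=[1, None, None, 2], ends=[2, None, None, 3], unit steps, with dims 0 and 3 squeezed. A plain-PyTorch check of the intended semantics (illustrative, not hidet code):

import torch

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
ref = x.clone()
ref[1, :, :, 2] = 0.0            # explicit slices
x[1, ..., 2] = 0.0               # Ellipsis form handled by the mapping above
assert torch.equal(x, ref)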


@register_function(operator.mul)
@register_function(torch.mul)
@register_function(torch.ops.aten.mul.Tensor)
@@ -931,6 +1009,13 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor

@register_function(operator.eq)
def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
if isinstance(a, Tensor) or isinstance(b, Tensor):
from hidet.graph.ops.utils import convert_to_tensor

if isinstance(a, Tensor):
return ops.equal(a, convert_to_tensor(b, a))
else:
return ops.equal(b, convert_to_tensor(a, b))
return a == b


@@ -1104,3 +1189,15 @@ def clamp(
@register_function(torch.isinf)
def isinf(x: Tensor) -> Tensor:
return ops.isinf(x)


@register_function(torch.nn.functional.pad)
def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0):
if isinstance(pad, tuple):
pad = list(pad)
return ops.pad(x, pads=pad, mode=mode, value=value)


@register_function(torch.roll)
def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None):
return ops.roll(x, shifts, dims)
24 changes: 24 additions & 0 deletions python/hidet/graph/frontend/torch/register_methods.py
@@ -265,3 +265,27 @@ def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor:
@register_method(torch.Tensor.all)
def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor:
return ops.all(self, axis=dim, keepdims=keepdim)


@register_method(torch.Tensor.matmul)
def tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
return ops.matmul(self, other)


@register_method(torch.Tensor.new_zeros)
def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
if layout is not None:
raise NotImplementedError("layout is not None")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = size
if dtype is None:
dtype = self.dtype
if device is None:
device = self.device

_ = pin_memory
_ = requires_grad

return ops.full(shape, dtype=dtype, device=device, value=dtype.zero)
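
Both PyTorch call forms for the size argument are normalized above (illustrative check, not part of the diff):

import torch

t = torch.ones(4, 4, dtype=torch.float16)
a = t.new_zeros(2, 3)            # size passed as varargs
b = t.new_zeros((2, 3))          # size passed as a single tuple
assert a.shape == b.shape == (2, 3)
assert a.dtype == t.dtype        # dtype defaults to the source tensor's dtype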
85 changes: 80 additions & 5 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -11,6 +11,7 @@
# limitations under the License.
from __future__ import annotations
import torch
from hidet.graph import ops
from hidet.graph.tensor import Tensor
from .interpreter import HidetModule, register_module
from . import register_functions as regs
@@ -117,6 +118,13 @@ def __call__(self, x: Tensor) -> Tensor:
return regs.adaptive_avg_pool2d(x, self.mod.output_size)


@register_module(torch.nn.AdaptiveAvgPool3d)
class HidetAdaptiveAvgPool3d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.AdaptiveAvgPool3d)
return regs.adaptive_avg_pool3d(x, self.mod.output_size)


@register_module(torch.nn.ReLU)
class HidetReLU(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
@@ -158,21 +166,21 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetLinear(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)
from hidet import ops

steal = dynamo_config['steal_weights']

self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])

def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.Linear)
return regs.linear(x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True))
return regs.linear(
x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True
)


@register_module(torch.nn.BatchNorm2d)
@register_module(torch.nn.BatchNorm3d)
class HidetBatchNorm2d(HidetModule):
def __call__(self, x: Tensor) -> Tensor:
assert isinstance(self.mod, torch.nn.BatchNorm2d)
assert isinstance(self.mod, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d))
return regs.batch_norm(
x=x,
running_mean=self.param('running_mean'),
@@ -404,3 +412,70 @@ def __call__(self, x: Tensor) -> Tensor:
align_corners=self.mod.align_corners,
recompute_scale_factor=self.mod.recompute_scale_factor,
)


@register_module(torch.nn.MultiheadAttention)
class HidetMultiheadAttention(HidetModule):
def __init__(self, torch_module: torch.nn.Module):
super().__init__(torch_module)
steal = dynamo_config['steal_weights']
self.in_proj_weight_transposed = ops.transpose(self.param('in_proj_weight', steal=steal), [1, 0])
self.out_proj_weight_transposed = ops.transpose(self.param('out_proj.weight', steal=steal), [1, 0])

def __call__(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
average_attn_weights=True,
is_causal=False,
) -> Tensor:
assert isinstance(self.mod, torch.nn.MultiheadAttention)
supported = (
self.mod._qkv_same_embed_dim
and self.mod.bias_k is None
and self.mod.bias_v is None
and not self.mod.add_zero_attn
and self.mod.batch_first
and key_padding_mask is None
and not need_weights
)
if not supported:
raise NotImplementedError(
"Hidet Multihead Attention currently only supports "
"kdim=vdim=embed_dim, add_bias_kv=False, add_zero_attn=False, "
"batch_first=True, forward(key_padding_mask=None, need_weights=False)."
)

# Input feed forward
wq, wk, wv = ops.split(self.in_proj_weight_transposed, parts_or_sections=3, axis=1)
query = ops.matmul(query, wq)
key = ops.matmul(key, wk)
value = ops.matmul(value, wv)
if self.mod.in_proj_bias is not None:
bq, bk, bv = ops.split(self.param('in_proj_bias'), parts_or_sections=3, axis=0)
query = ops.add(query, bq)
key = ops.add(key, bk)
value = ops.add(value, bv)

# Split heads
split_head_dims = [query.shape[0], query.shape[1], self.mod.num_heads, query.shape[2] // self.mod.num_heads]
query = ops.transpose(query.reshape(split_head_dims), [0, 2, 1, 3])
key = ops.transpose(key.reshape(split_head_dims), [0, 2, 1, 3])
value = ops.transpose(value.reshape(split_head_dims), [0, 2, 1, 3])

# fused multi-head attention via scaled dot-product attention
out = regs.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=self.mod.dropout, is_causal=is_causal
)

# Output feed forward
merge_head_dims = [out.shape[0], out.shape[2], self.mod.embed_dim]
out = ops.transpose(out, [0, 2, 1, 3]).reshape(merge_head_dims)
out = ops.matmul(out, self.out_proj_weight_transposed)
if self.mod.out_proj.bias is not None:
out = ops.add(out, self.param('out_proj.bias'))
return out
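
For reference, a configuration the mapping above targets; other combinations hit the NotImplementedError check. This is plain PyTorch usage, not hidet code:

import torch

mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)              # [batch, seq, embed] because batch_first=True
out, attn = mha(x, x, x, need_weights=False)
assert out.shape == (2, 10, 64) and attn is None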
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -32,7 +32,7 @@
from .arithmetic import reciprocal, exp, expm1, log, log2, log10, log1p, logaddexp, erf
from .arithmetic import bitwise_right_shift, bitwise_left_shift, bitwise_and, bitwise_invert, bitwise_or
from .arithmetic import bitwise_xor, maximum, minimum, clamp
from .arithmetic import isfinite, isinf, isnan, sign, where
from .arithmetic import isfinite, isinf, isnan, sign, where, set_strided_slice, roll
from .arithmetic import sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, atan2
from .complex import real, imag, conj, make_complex
from .compare import equal, not_equal, less, greater, less_equal, greater_equal