[Integration] Sync to main #353

Merged · 5 commits · Aug 24, 2023
8 changes: 8 additions & 0 deletions python/hidet/graph/frontend/torch/interpreter.py
@@ -359,6 +359,10 @@ def load_arg(a, env):
hidet_kwargs = load_arg(node.kwargs, hidet_env)
try:
hidet_env[node.name] = exec_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if exec_func.functions[0] is setitem:
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, exec_func, hidet_args, hidet_kwargs)
elif node.op == "call_method":
@@ -448,6 +452,10 @@ def load_arg(a, env):

try:
hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
from .register_functions import setitem

if hidet_func.functions[0] is setitem:
hidet_env[str(node.args[0])] = hidet_env[node.name]
except Exception as e:
self._raise_exception(e, node.target, hidet_func, hidet_args, hidet_kwargs)
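Note on the special-casing above: operator.setitem mutates its first argument in place in eager PyTorch, while hidet ops are functional and return a fresh tensor, so the interpreter re-points the name of the mutated input (str(node.args[0])) at the result. A minimal sketch of how setitem appears in an FX graph (plain torch.fx, nothing hidet-specific; the module name M is illustrative):

import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        x[0] = 1.0  # traced as call_function with target operator.setitem
        return x + 1

traced = torch.fx.symbolic_trace(M())
print(any(n.op == "call_function" and n.target is operator.setitem
          for n in traced.graph.nodes))  # True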

220 changes: 210 additions & 10 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -18,14 +18,13 @@
from hidet.graph import ops
from hidet.utils import same_list
from hidet.ir.type import DataType
from hidet.ir.expr import Expr
from hidet.ir import expr
from hidet.ir.dtypes import promote_type
from hidet.ir.expr import Int
from hidet.ir.expr import Expr, Int, is_constant
from hidet.runtime.device import Device
from .interpreter import register_function, register_method
from .interpreter import warnings
from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar
from .utils import dtype_from_torch, device_from_torch, normalize_to_scalar, convert_to_scalar_if_possible

Number = Union[int, float, bool]

@@ -97,6 +96,11 @@ def adaptive_avg_pool2d(x: Tensor, output_size):
return ops.adaptive_avg_pool2d(x, output_size)


@register_function(torch.nn.functional.adaptive_avg_pool3d)
def adaptive_avg_pool3d(x: Tensor, output_size):
return ops.adaptive_avg_pool3d(x, output_size)


@register_function(torch.nn.functional.relu)
def relu(x: Tensor, inplace: bool):
# if inplace:
@@ -109,11 +113,9 @@ def relu(x: Tensor, inplace: bool):
def max_pool2d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=False, return_indices=False):
if dilation != 1 and not same_list(dilation, [1, 1]):
raise NotImplementedError("dilation != 1")
if ceil_mode:
raise NotImplementedError("ceil_mode=True")
if return_indices:
raise NotImplementedError("return_indices=True")
y = ops.max_pool2d(x, kernel_size, stride, padding)
y = ops.max_pool2d(x, kernel_size, stride, padding, ceil_mode=ceil_mode)
return y
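For reference, ceil_mode only changes the rounding of the output size; with dilation=1, torch's pooling formula reduces to floor((n + 2p - k) / s) + 1, and ceil_mode swaps floor for ceil:

import math

n, k, s, p = 8, 3, 2, 0  # input length, kernel, stride, padding
assert math.floor((n + 2 * p - k) / s) + 1 == 3  # ceil_mode=False
assert math.ceil((n + 2 * p - k) / s) + 1 == 4   # ceil_mode=True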


@@ -130,7 +132,9 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=


@register_function(torch.nn.functional.linear)
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]):
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False):
if len(weight.shape) > 1 and not weight_is_transposed:
weight = ops.transpose(weight, [1, 0])
y = ops.matmul(x, weight)
if bias is not None:
y = y + bias
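The added transpose matches eager semantics: torch.nn.functional.linear computes x @ weight.T with weight stored as [out_features, in_features], so a weight that is not pre-transposed must be transposed before the matmul:

import torch

x, w = torch.randn(2, 4), torch.randn(3, 4)  # weight is [out, in]
assert torch.allclose(torch.nn.functional.linear(x, w), x @ w.T)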
@@ -205,10 +209,18 @@ def batch_norm(
)
y = ops.batch_norm_infer(x, running_mean, running_var, epsilon=eps)
_ = momentum # unused
if len(x.shape) == 3:
dims = [0, 2]
elif len(x.shape) == 4:
dims = [0, 2, 3]
elif len(x.shape) == 5:
dims = [0, 2, 3, 4]
else:
raise NotImplementedError("batch_norm only accepts 3D, 4D, or 5D input")
if weight is not None:
y = y * weight.unsqueeze([0, 2, 3])
y = y * weight.unsqueeze(dims)
if bias is not None:
y = y + bias.unsqueeze([0, 2, 3])
y = y + bias.unsqueeze(dims)
return y
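The dims lists above reshape the per-channel weight/bias so they broadcast over every non-channel axis; e.g. for a 5D NCDHW input, unsqueeze([0, 2, 3, 4]) turns a [C] vector into [1, C, 1, 1, 1], illustrated here with plain torch:

import torch

x = torch.randn(2, 3, 4, 5, 6)    # N, C, D, H, W
w = torch.randn(3)                # per-channel scale
y = x * w.reshape(1, 3, 1, 1, 1)  # same broadcast as unsqueeze([0, 2, 3, 4])
assert y.shape == x.shape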


@@ -222,6 +234,70 @@ def getitem(x: Tensor, index):
return x[index]


@register_function(operator.setitem)
def setitem(x: Tensor, item, setvalue):

if isinstance(item, list):
item = tuple(item)
if not isinstance(item, tuple):
item = tuple([item])

if not isinstance(setvalue, (int, float)):
raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values')

# now, the item could have
# 1. integer index
# 2. slice
# 3. Ellipsis
# 4. None
# e.g., [1, 3:5, ..., None]

# process Ellipsis
# e.g., x[1, ..., 2] -> x[1, :, :, 2]
if Ellipsis in item:
if item.count(Ellipsis) > 1:
raise ValueError('Only one ellipsis allowed in index.')
ellipsis_index = item.index(Ellipsis)
ellipsis_ndim = len(x.shape) - sum([1 if axis not in [None, Ellipsis] else 0 for axis in item])
ellipsis_ndim = max(ellipsis_ndim, 0)
item = item[:ellipsis_index] + (slice(None),) * ellipsis_ndim + item[ellipsis_index + 1 :]

# normalize index
normalized_item = []
for i, v in enumerate(item):
if isinstance(v, int):
if v < 0:
v = v + x.shape[i]
if is_constant(v, x.shape[i]) and (v < 0 or v >= x.shape[i]):
raise IndexError('index {} is out of bounds for dimension {} with size {}'.format(v, i, x.shape[i]))
normalized_item.append(v)
elif v is not None:
# None affects getitem, but is ignored in setitem
normalized_item.append(v)
item = tuple(normalized_item)

# process slice and integer index
rank = len(x.shape)
while len(item) < rank:
item = item + (slice(None),)
starts, ends, steps = [], [], []
squeeze_dims = []
for dim, v in enumerate(item):
if isinstance(v, (int, Expr)):
squeeze_dims.append(dim)
starts.append(v)
ends.append(v + 1)
steps.append(1)
else:
assert isinstance(v, slice)
starts.append(v.start)
ends.append(v.stop)
steps.append(v.step)

out = ops.set_strided_slice(x, starts, ends, steps, setvalue)
return out
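A worked example of the Ellipsis expansion step in plain Python: for a rank-4 tensor, x[1, ..., 2] normalizes to x[1, :, :, 2]:

item, rank = (1, Ellipsis, 2), 4
i = item.index(Ellipsis)
ellipsis_ndim = rank - sum(1 for a in item if a not in (None, Ellipsis))
item = item[:i] + (slice(None),) * ellipsis_ndim + item[i + 1:]
assert item == (1, slice(None), slice(None), 2)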


@register_function(operator.mul)
@register_function(torch.mul)
@register_function(torch.ops.aten.mul.Tensor)
@@ -516,6 +592,7 @@ def permute(x: Tensor, *args):
return ops.transpose(x, dims)


@register_function(torch.swapaxes)
@register_function(torch.transpose)
@register_method(torch.Tensor.transpose)
def transpose(x: Tensor, dim0: int, dim1: int):
@@ -590,7 +667,7 @@ def addmm(


@register_function(torch.where)
def where(condition: Tensor, x: Tensor, y: Tensor):
def where(condition: Tensor, x: Union[Tensor, Number], y: Union[Tensor, Number]):
return ops.where(cond=condition, x=x, y=y)


@@ -697,6 +774,7 @@ def sigmoid(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:


@register_function(torch.exp)
@register_method(torch.Tensor.exp)
def exp(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
warnings.warn_once("hidet: does not support torch.exp(..., out=...)")
@@ -931,6 +1009,13 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:

@register_function(operator.eq)
def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor:
if isinstance(a, Tensor) or isinstance(b, Tensor):
from hidet.graph.ops.utils import convert_to_tensor

if isinstance(a, Tensor):
return ops.equal(a, convert_to_tensor(b, a))
else:
return ops.equal(b, convert_to_tensor(a, b))
return a == b
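The new branch mirrors eager semantics, where comparing a tensor against a Python scalar broadcasts the scalar to a tensor (convert_to_tensor plays that role here); the torch equivalent:

import torch

t = torch.tensor([1, 2, 3])
assert torch.equal(torch.eq(t, 2), t == torch.tensor(2))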


@@ -981,6 +1066,28 @@ def torch_mean(
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=dim, keep_dim=keepdim)
return output
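Registering torch.sum twice with different signatures works because registration happens at decoration time (the second def only shadows the first name at module scope), and the registry keeps all overloads, cf. exec_func.functions[0] in interpreter.py above. A minimal sketch of that dispatch style, with illustrative names rather than hidet's actual registry code:

import inspect

def dispatch(overloads, *args, **kwargs):
    # Call the first overload whose signature binds the given arguments.
    for fn in overloads:
        try:
            inspect.signature(fn).bind(*args, **kwargs)
        except TypeError:
            continue
        return fn(*args, **kwargs)
    raise TypeError("no overload matched")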


@register_function(torch.cumsum)
def torch_cumsum(x: Tensor, dim, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
@@ -1069,3 +1176,96 @@ def zeros_like(
hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if dtype else x.dtype

return ops.full(x.shape, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero)


@register_function(torch.clamp)
def clamp(
x: Tensor,
min: Optional[Union[Tensor, Number]] = None,
max: Optional[Union[Tensor, Number]] = None,
*,
out: Optional[Tensor] = None,
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.clamp(..., out=...)")

min = convert_to_scalar_if_possible(min)
max = convert_to_scalar_if_possible(max)

if min is None and max is None:
return x
elif min is None:
if not isinstance(max, Tensor):
assert isinstance(max, (int, float, complex))
max = ops.full([], value=max, dtype=x.dtype, device=x.device)
return ops.minimum(x, max)
elif max is None:
if not isinstance(min, Tensor):
assert isinstance(min, (int, float, complex))
min = ops.full([], value=min, dtype=x.dtype, device=x.device)
return ops.maximum(x, min)
else:
return ops.clamp(x, min, max)
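Sanity check for the single-sided branches, using eager semantics: clamp with only a min bound is an elementwise maximum (and symmetrically, max-only is a minimum):

import torch

x = torch.tensor([-2.0, 0.5, 3.0])
assert torch.equal(torch.clamp(x, min=0.0),
                   torch.maximum(x, torch.tensor(0.0)))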


@register_function(torch.isinf)
def isinf(x: Tensor) -> Tensor:
return ops.isinf(x)


@register_function(torch.nn.functional.pad)
def torch_pad(x: Tensor, pad: Union[Tuple[int], List[int]], mode: str = 'constant', value=0):
if isinstance(pad, tuple):
pad = list(pad)
# Torch's pad list has form [p2left, p2right, p1left, p1right, p0left, p0right]
# Hidet's pad list has form [p0left, p1left, p2left, p0right, p1right, p2right]
left = []
right = []
for i, p in enumerate(pad):
if i % 2 == 0:
left.append(p)
else:
right.append(p)
left.reverse()
right.reverse()
pad = []
for p in left:
pad.append(p)
for p in right:
pad.append(p)
return ops.pad(x, pads=pad, mode=mode, value=value)
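A worked example of the reordering: torch lists pads from the last dimension outward, so for a 3-D tensor [p2l, p2r, p1l, p1r, p0l, p0r] becomes hidet's [p0l, p1l, p2l, p0r, p1r, p2r]:

pad = [1, 2, 3, 4, 5, 6]  # p2l, p2r, p1l, p1r, p0l, p0r
left = pad[0::2][::-1]    # [5, 3, 1] = p0l, p1l, p2l
right = pad[1::2][::-1]   # [6, 4, 2] = p0r, p1r, p2r
assert left + right == [5, 3, 1, 6, 4, 2]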


@register_function(torch.roll)
def torch_roll(x: Tensor, shifts: Union[int, Sequence[int]], dims: Union[int, Sequence[int]] = None):
return ops.roll(x, shifts, dims)


@register_function(torch.nn.functional.normalize)
def torch_normalize(x: Tensor, p=2.0, dim=1, eps=1e-12, out=None):
if out is not None:
raise NotImplementedError("out is not None")
return ops.lp_norm(x, p, dim, eps)


@register_function(torch.clone)
@register_method(torch.Tensor.clone)
def torch_clone(x: Tensor, *, memory_format=torch.preserve_format):
if memory_format is not torch.preserve_format:
warnings.warn_once(
"torch.clone got memory_format not torch.preserve_format, treating it as torch.preserve_format"
)
if x.is_symbolic():
return x
else:
return x.copy()


@register_function(torch.chunk)
def torch_chunk(x: Tensor, chunks: int, dim: int = 0):
return ops.split(x, parts_or_sections=chunks, axis=dim)


@register_function(torch.einsum)
def torch_einsum(equation, *operands):
return ops.einsum(equation, operands)
49 changes: 48 additions & 1 deletion python/hidet/graph/frontend/torch/register_methods.py
@@ -90,6 +90,11 @@ def tensor_to(self: Tensor, *args, **kwargs) -> Tensor:
if self.is_symbolic() and instantiate_device(device_from_torch(arg)) != self.device:
raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.')
device = arg
elif isinstance(arg, Tensor):
dtype = arg.dtype
if self.is_symbolic() and arg.device != self.device:
raise NotImplementedError('hidet: Tensor.to(..., device=...) is not supported for symbolic tensors.')
device = arg.device
else:
raise ValueError(f'Unsupported argument type: {type(arg)}')

@@ -222,7 +227,10 @@ def tensor_type(self: Tensor, dtype: Union[str, torch.dtype], non_blocking: bool

@register_method(torch.Tensor.expand)
def tensor_expand(self: Tensor, *sizes: int) -> Tensor:
sizes: List[int] = list(sizes)
if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
sizes = sizes[0]
else:
sizes: List[int] = list(sizes)
assert len(sizes) >= len(self.shape)
for i in range(len(sizes)):
if sizes[i] == -1:
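The branch added here normalizes the two calling conventions torch allows for expand, shown below with eager tensors; -1 keeps the existing size:

import torch

t = torch.randn(1, 3)
assert t.expand(4, 3).shape == t.expand([4, 3]).shape == (4, 3)
assert t.expand(-1, 3).shape == (1, 3)  # -1 keeps the existing size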
@@ -255,3 +263,42 @@ def tensor_repeat(self: Tensor, *sizes: int) -> Tensor:
@register_method(torch.Tensor.detach)
def tensor_detach(self: Tensor) -> Tensor:
return self


@register_method(torch.Tensor.any)
def tensor_any(self: Tensor, dim=None, keepdim=False) -> Tensor:
return ops.any(self, axis=dim, keepdims=keepdim)


@register_method(torch.Tensor.all)
def tensor_all(self: Tensor, dim=None, keepdim=False) -> Tensor:
return ops.all(self, axis=dim, keepdims=keepdim)


@register_method(torch.Tensor.matmul)
def tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
return ops.matmul(self, other)


@register_method(torch.Tensor.new_zeros)
def tensor_new_zeros(self: Tensor, *size, dtype=None, layout=None, device=None, pin_memory=False, requires_grad=False):
if layout is not None:
raise NotImplementedError("layout is not None")
if len(size) == 1:
if isinstance(size[0], (list, tuple)):
size = size[0]
shape = size
if dtype is None:
dtype = self.dtype
if device is None:
device = self.device

_ = pin_memory
_ = requires_grad

return ops.full(shape, dtype=dtype, device=device, value=dtype.zero)
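new_zeros accepts the same two size conventions and inherits dtype/device from self when they are not given, which is what the defaults above implement; the eager behavior for reference:

import torch

t = torch.empty(2, 2, dtype=torch.float32)
assert t.new_zeros(3, 4).shape == t.new_zeros((3, 4)).shape == (3, 4)
assert t.new_zeros(3, 4).dtype == t.dtype  # inherited from self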


@register_method(torch.Tensor.zero_)
def tensor_zero_(self: Tensor):
return ops.full(self.shape, dtype=self.dtype, device=self.device, value=self.dtype.zero)