[Fix] Fixing a bug in register_methods (#331)
Previously, an error was encountered during a model compilation attempt:

> torch._dynamo.exc.BackendCompilerFailed: backend='hidet' raised:
> RuntimeError: Can not interpreting max given arguments:
>   max(tensor(...))
> Possible candidates are:
> torch_max_v3(x: hidet.Tensor, dim: Union[int, hidet.ir.expr.Expr], keepdim: bool = False, *, out: Union[hidet.Tensor, Tuple[hidet.Tensor, ...], List[hidet.Tensor]] = None) -> Tuple[hidet.Tensor, hidet.Tensor]
> File "/home/bolin/Desktop/hidet/python/hidet/graph/frontend/torch/register_functions.py", line 1067

This happened even though we do have a
[function](https://github.com/CentML/hidet/blob/13a806608d40de2de1fcc682adeea8d204189f3c/python/hidet/graph/frontend/torch/register_functions.py#L1056-L1060)
that can interpret `torch.Tensor.max` with the described arguments.
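
For context, a minimal repro sketch (hypothetical, not taken from the original report) of the kind of call that hits this path, assuming hidet is installed and registered as the `hidet` dynamo backend; the `MaxModel` name is made up for illustration:

```python
import torch


class MaxModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The method form of max dispatches through register_method,
        # which is where the overloads were being dropped.
        return x.max()


model = MaxModel().eval()
model_opt = torch.compile(model, backend='hidet')
out = model_opt(torch.randn(2, 3))   # previously raised BackendCompilerFailed
assert out.shape == torch.Size([])   # torch.Tensor.max() yields a 0-dim tensor
```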
BolinSNLHM authored and vadiklyutiy committed Jul 23, 2024
1 parent ff9445e commit c87c515
Showing 5 changed files with 101 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/interpreter.py
@@ -176,7 +176,7 @@ def _raise_exception(exception: Exception, target, caused_callable: Any, args, k
if isinstance(caused_callable, OverloadedFunction):
dispatched = caused_callable.resolve(*args, **kwargs)
if dispatched is None:
msg = ['Can not interpreting {} given arguments: '.format(target_name)]
msg = ['Can not interpret {} given arguments: '.format(target_name)]
msg.append(' {}({})'.format(target_name, args_string))
msg.append('Possible candidates are: ')
for overload, sig in zip(caused_callable.functions, caused_callable.signatures):
32 changes: 22 additions & 10 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -366,7 +366,7 @@ def cat(tensors: List[Tensor], dim: int = 0):


@register_function(torch.cat)
def cat(tensors: List[Tensor], axis: int): # PyTorch supports axis as well as the argument name
def cat_v2(tensors: List[Tensor], axis: int): # PyTorch supports axis as well as the argument name
dtype = functools.reduce(promote_type, [t.dtype for t in tensors])
tensors = [ops.cast(t, dtype) for t in tensors]
return ops.concat(tensors, axis)
@@ -1063,15 +1063,19 @@ def minimum(x: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensor


@register_function(torch.max)
@register_method(torch.Tensor.max)
def torch_max(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.max(..., out=...)")
return ops.max(x, dims=list(range(len(x.shape))), keep_dim=True)

# According to the PyTorch documentation,
# calling torch.max(tensor(...)) or some_tensor.max() results in a 0-dim tensor with shape torch.Size([]).
return ops.max(x, dims=list(range(len(x.shape))), keep_dim=False)


@register_function(torch.max)
@register_method(torch.Tensor.max)
def torch_max(
def torch_max_v2(
x: Tensor, other: Union[Tensor, int], *, out: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if out is not None:
@@ -1085,7 +1089,7 @@ def torch_max(
@register_function(torch.max)
@register_method(torch.Tensor.max)
def torch_max_v3(
x: Tensor, dim: Int, keepdim: bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None
x: Tensor, dim: Int, keepdim: bool = False, *, out: Optional[Union[Tensor, Tuple[Tensor, ...], List[Tensor]]] = None
) -> Tuple[Tensor, Tensor]:
if out is not None:
raise NotImplementedError("hidet: does not support torch.max(..., out=...)")
@@ -1095,14 +1099,19 @@ def torch_max_v3(


@register_function(torch.min)
@register_method(torch.Tensor.min)
def torch_min(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.min(..., out=...)")
return ops.min(x, dims=list(range(len(x.shape))), keep_dim=True)

# Same as torch.max:
# torch.min(tensor(...)) or some_tensor.min() results in a 0-dim tensor with shape torch.Size([]).
return ops.min(x, dims=list(range(len(x.shape))), keep_dim=False)


@register_function(torch.min)
def torch_min(
@register_method(torch.Tensor.min)
def torch_min_v2(
x: Tensor, other: Union[Tensor, int], *, out: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if out is not None:
@@ -1114,6 +1123,7 @@ def torch_min(


@register_function(torch.min)
@register_method(torch.Tensor.min)
def torch_min_v3(
x: Tensor, dim: Int, keepdim: bool = False, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None
) -> Tuple[Tensor, Tensor]:
@@ -1212,13 +1222,15 @@ def tensor_pow(self: Union[Tensor, Number], exponent: Union[Tensor, Number]) ->
def torch_mean(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.mean(x, dims=list(range(len(x.shape))), keep_dim=True)

# keep_dim should be False here too, matching torch.max/min/sum
output = ops.mean(x, dims=list(range(len(x.shape))), keep_dim=False)
return output


@register_function(torch.mean)
@register_method(torch.Tensor.mean)
def torch_mean(
def torch_mean_v2(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
@@ -1254,13 +1266,13 @@ def torch_var(
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=False)
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
def torch_sum_v2(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
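
The `keep_dim=False` changes above mirror eager PyTorch, where full reductions return 0-dim tensors rather than `(1, ..., 1)`-shaped results. A small eager-mode check (illustrative only, not part of this commit) of the shapes the hidet frontend now has to reproduce:

```python
import torch

x = torch.randn(2, 3)

# Full reductions in eager PyTorch return 0-dim tensors (shape torch.Size([])),
# which is what keep_dim=False now reproduces on the hidet side.
assert torch.max(x).shape == torch.Size([])
assert x.min().shape == torch.Size([])
assert torch.sum(x).shape == torch.Size([])
assert x.mean().shape == torch.Size([])

# The dim= overloads are unchanged: keepdim still controls the output shape.
values, indices = torch.max(x, dim=0)
assert values.shape == torch.Size([3]) and indices.shape == torch.Size([3])
assert torch.max(x, dim=0, keepdim=True).values.shape == torch.Size([1, 3])
```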
2 changes: 1 addition & 1 deletion python/hidet/graph/frontend/torch/registry.py
@@ -96,7 +96,7 @@ def decorator(hidet_func):

def register_method(method: Callable):
def decorator(hidet_method):
if method not in Registry.registered_functions:
if method not in Registry.registered_methods:
Registry.registered_methods[method] = OverloadedFunction()
Registry.registered_methods[method].overload(hidet_method)
return hidet_method
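
This one-word fix is the core of the commit: a torch method never appears in `Registry.registered_functions`, so the old check was always true, every `@register_method` decoration replaced the existing `OverloadedFunction`, and only the last overload survived. That matches the error above, which lists `torch_max_v3` as the only candidate. A simplified sketch of that failure mode (hypothetical stand-in classes, not the actual hidet registry code):

```python
# Simplified sketch of the old behaviour (hypothetical stand-ins,
# not the real hidet Registry/OverloadedFunction implementation).

class OverloadedFunction:
    def __init__(self):
        self.functions = []

    def overload(self, fn):
        self.functions.append(fn)
        return fn

registered_functions = {}   # torch functions, e.g. torch.max
registered_methods = {}     # torch methods, e.g. torch.Tensor.max

def register_method_buggy(method):
    def decorator(fn):
        # BUG: a method is never in registered_functions, so this branch
        # always runs and a fresh OverloadedFunction wipes earlier overloads.
        if method not in registered_functions:
            registered_methods[method] = OverloadedFunction()
        registered_methods[method].overload(fn)
        return fn
    return decorator

@register_method_buggy('Tensor.max')
def torch_max(x): ...

@register_method_buggy('Tensor.max')
def torch_max_v2(x, other): ...

@register_method_buggy('Tensor.max')
def torch_max_v3(x, dim, keepdim=False): ...

# Only the last overload survives, matching the single candidate in the error.
assert [f.__name__ for f in registered_methods['Tensor.max'].functions] == ['torch_max_v3']
```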
4 changes: 4 additions & 0 deletions python/hidet/testing/torch_utils.py
@@ -50,6 +50,10 @@ def check_module(model: torch.nn.Module, args: Sequence[torch.Tensor], atol=1e-4
raise ValueError('torch_outputs and hidet_outputs have different length')

for torch_output, hidet_output in zip(torch_outputs, hidet_outputs):
# np.testing.assert_allclose broadcasts its inputs, so it can pass even if the shapes are different
assert (
torch_output.shape == hidet_output.shape
), f"Shape mismatch --- eager: {torch_output.shape} vs hidet: {hidet_output.shape}"
torch_output = torch_output.detach().cpu().numpy()
hidet_output = hidet_output.detach().cpu().numpy()
numpy.testing.assert_allclose(torch_output, hidet_output, atol=atol, rtol=rtol)
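
The new shape assertion in `check_module` guards against a quirk of `np.testing.assert_allclose`: it broadcasts its arguments, so a `(1, 1)` hidet output could silently match a 0-dim eager output. A small illustration (made-up values, not from the test suite):

```python
import numpy as np

eager = np.array(1.5, dtype=np.float32)             # shape (), as eager torch.max(x) returns
hidet_old = np.full((1, 1), 1.5, dtype=np.float32)  # shape (1, 1), as the old keep_dim=True produced

# Passes despite the shape mismatch, because assert_allclose broadcasts its inputs.
np.testing.assert_allclose(hidet_old, eager, atol=1e-5, rtol=1e-5)

# The explicit shape assertion added to check_module catches exactly this case.
assert hidet_old.shape != eager.shape
```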
81 changes: 73 additions & 8 deletions tests/frontends/torch/test_torch_reduce.py
@@ -36,27 +36,92 @@ def test_minimum(shape):

@pytest.mark.parametrize('shape', [[2], [2, 3], [2, 3, 4]])
def test_max(shape):
check_module(FunctionalModule(op=lambda x: torch.max(x)), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)
check_module(FunctionalModule(op=lambda x: torch.max(x)), args=[torch.randn(shape)], atol=0, rtol=0)
check_module(
FunctionalModule(op=lambda x, y: torch.max(x, y)),
args=[torch.randn(shape), torch.randn(shape)],
atol=1e-5,
rtol=1e-5,
FunctionalModule(op=lambda x, y: torch.max(x, y)), args=[torch.randn(shape), torch.randn(shape)], atol=0, rtol=0
)
check_module(FunctionalModule(op=lambda x, dim: torch.max(x, dim)), args=[torch.randn(shape), 0], atol=0, rtol=0)

# Repeat the checks above, this time for the torch.Tensor.max method
check_module(FunctionalModule(op=lambda x: x.max()), args=[torch.randn(shape)], atol=0, rtol=0)
check_module(FunctionalModule(op=lambda x, dim: x.max(dim)), args=[torch.randn(shape), 0], atol=0, rtol=0)
check_module(
FunctionalModule(op=lambda x, dim: torch.max(x, dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5
FunctionalModule(op=lambda x, y: x.max(y)), args=[torch.randn(shape), torch.randn(shape)], atol=0, rtol=0
)


@pytest.mark.parametrize('shape', [[2], [2, 3], [2, 3, 4]])
def test_min(shape):
check_module(FunctionalModule(op=lambda x: torch.min(x)), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)
check_module(FunctionalModule(op=lambda x: torch.min(x)), args=[torch.randn(shape)], atol=0, rtol=0)
check_module(
FunctionalModule(op=lambda x, y: torch.min(x, y)),
args=[torch.randn(shape), torch.randn(shape)],
atol=1e-5,
rtol=1e-5,
)
check_module(FunctionalModule(op=lambda x, dim: torch.min(x, dim)), args=[torch.randn(shape), 0], atol=0, rtol=0)

# Repeat the checks above, this time for the `torch.Tensor.min` method.
check_module(FunctionalModule(op=lambda x: x.min()), args=[torch.randn(shape)], atol=0, rtol=0)

check_module(FunctionalModule(op=lambda x, dim: x.min(dim)), args=[torch.randn(shape), 0], atol=0, rtol=0)

check_module(
FunctionalModule(op=lambda x, y: x.min(y)), args=[torch.randn(shape), torch.randn(shape)], atol=0, rtol=0
)


@pytest.mark.parametrize('shape', [[2], [2, 3], [2, 3, 4]])
def test_sum(shape):
# Same idea as in test_max and test_min
check_module(FunctionalModule(op=lambda x: torch.sum(x)), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)

check_module(
FunctionalModule(op=lambda x, dim: torch.sum(x, dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5
)

check_module(FunctionalModule(op=lambda x: x.sum()), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)

check_module(FunctionalModule(op=lambda x, dim: x.sum(dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5)

check_module(
FunctionalModule(op=lambda x, dim: torch.min(x, dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5
FunctionalModule(op=lambda x, dim: x.sum(dim)),
args=[torch.randn(shape), list(range(len(shape)))],
atol=1e-5,
rtol=1e-5,
)

check_module(
FunctionalModule(op=lambda x, dim: x.sum(dim, keepdim=True)),
args=[torch.randn(shape), None],
atol=1e-5,
rtol=1e-5,
)


@pytest.mark.parametrize('shape', [[2], [2, 3], [2, 3, 4]])
def test_mean(shape):
# Same idea as in test_sum
check_module(FunctionalModule(op=lambda x: torch.mean(x)), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)

check_module(
FunctionalModule(op=lambda x, dim: torch.mean(x, dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5
)

check_module(
FunctionalModule(op=lambda x, dim: torch.mean(x, dim)),
args=[torch.randn(shape), list(range(len(shape)))],
atol=1e-5,
rtol=1e-5,
)

check_module(FunctionalModule(op=lambda x: x.mean()), args=[torch.randn(shape)], atol=1e-5, rtol=1e-5)

check_module(FunctionalModule(op=lambda x, dim: x.mean(dim)), args=[torch.randn(shape), 0], atol=1e-5, rtol=1e-5)

check_module(
FunctionalModule(op=lambda x, dim: x.mean(dim)),
args=[torch.randn(shape), list(range(len(shape)))],
atol=1e-5,
rtol=1e-5,
)
