Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync CentML -> hidet-org #465

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,9 +1762,11 @@ def torch_all(input):


@register_function(torch.all)
def torch_all_v2(input, dim, keepdim=False, *, out=None):
def torch_all_v2(input, dim: Union[int, Sequence[int]], keepdim=False, *, out=None):
if out is not None:
raise NotImplementedError("hidet: does not support torch.all(..., out=...)")
if isinstance(dim, int):
dim = (dim,)
return ops.all(input, axis=dim, keepdims=keepdim)


Expand All @@ -1790,9 +1792,11 @@ def torch_argmin(x, dim: Int = None, keepdim: bool = False):


@register_function(torch.any)
def torch_any_v1(input: Tensor, dim, keepdim=False, *, out=None) -> Tensor:
def torch_any_v1(input: Tensor, dim: Union[int, Sequence[int]], keepdim=False, *, out=None) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.any(..., out=...)")
if isinstance(dim, int):
dim = (dim,)
return ops.any(input, axis=dim, keepdims=keepdim)


Expand All @@ -1801,6 +1805,15 @@ def torch_any_v2(input: Tensor) -> Tensor:
return ops.any(input)


@register_function(torch.t)
@register_method(torch.Tensor.t)
def torch_t(input: Tensor):
assert 0 <= len(input.shape) <= 2, 'torch.t expects tensors <= 2D'
if len(input.shape) == 2:
return ops.transpose(input, [1, 0])
return input


@register_function(torch.nn.functional.unfold)
def torch_unfold(input: Tensor, kernel_size, dilation=1, padding=0, stride=1) -> Tensor:
assert 3 <= len(input.shape) <= 4, "torch.nn.functional.unfold accepts 3D or 4D tensor only"
Expand Down
2 changes: 2 additions & 0 deletions python/hidet/ir/dtypes/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .floats import float32, float16
from .integer import int8, uint8
from .integer_subbyte import int4b, uint4b
from .boolean import boolean


class VectorType(DataType):
Expand Down Expand Up @@ -116,6 +117,7 @@ def vectorize(base_dtype: DataType, num_lanes: int) -> VectorType:
(float16, 2): float16x2,
(int8, 4): int8x4,
(uint8, 4): uint8x4,
(boolean, 4): int8x4,
}
if (base_dtype, num_lanes) in table:
return table[(base_dtype, num_lanes)]
Expand Down
12 changes: 12 additions & 0 deletions tests/frontends/torch/test_torch_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,15 @@ def test_mean(shape):
atol=1e-5,
rtol=1e-5,
)


@pytest.mark.parametrize(
'shape, dim', [[[2, 4], -1], [[128, 3, 4], 0], [[128, 3, 4], 2], [[72, 5, 64], -1], [[67, 128, 233], 1]]
)
def test_torch_any(shape, dim):
check_module(FunctionalModule(op=lambda x: torch.any(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0)


@pytest.mark.parametrize('shape, dim', [[[2, 3], -1]])
def test_all(shape, dim):
check_module(FunctionalModule(op=lambda x: torch.all(x, dim=dim)), args=[torch.randn(shape) > 0], atol=0, rtol=0)
Loading