[NotImplementedError] operator.lshift #370

Closed
HanGuo97 opened this issue Oct 28, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@HanGuo97

Describe the Problem
BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
operator.lshift
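
For context, a minimal snippet along these lines should reproduce it (my assumption: any tensor `<<` traced by dynamo lowers to operator.lshift):

import hidet  # registers the 'hidet' backend with torch.compile
import torch

@torch.compile(backend='hidet')
def shift(x):
    return x << 1  # traced as operator.lshift

shift(torch.arange(8, device='cuda'))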

HanGuo97 added the bug label on Oct 28, 2023
@yaoyaoding
Member

Thanks @HanGuo97 for reporting this.

Hi @Aalanli, could you help add this operator when you have a chance?

@Aalanli
Collaborator

Aalanli commented Oct 29, 2023

Fixed in #371

@HanGuo97
Author

Thanks for the quick fix!

Does that mean I could simply build from source, and this would work via the torch.compile interface?

@Aalanli
Collaborator

Aalanli commented Oct 29, 2023

I am not sure about the particulars of your model, but this script worked for me (when built from source):

import hidet  # importing hidet registers the 'hidet' backend with torch.compile
import torch

def test(a):
    return a << 3

t = torch.compile(test, backend='hidet')
# shifts are integer-only ops, so cast the float tensor to int64 first
t(torch.randn(3, 5, device='cuda').to(torch.int64))

@HanGuo97
Author

HanGuo97 commented Oct 30, 2023

Hi @Aalanli,
Thank you so much for your help! The patch you provided worked perfectly for me.

I wanted to follow up with a more detailed example. This particular issue caused several similar errors for me. Some of them I was able to fix myself by following the pattern of your patch, such as pow, bitwise_and, and torch.Tensor.max.
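
For reference, the pattern I followed looks roughly like this (the exact decorator and op names below are my guesses from reading the patch, not necessarily the code in #371; the registrations live in python/hidet/graph/frontend/torch/register_functions.py):

import operator
from hidet.graph import ops
from hidet.graph.frontend.torch.registry import register_function

# Assumed names: register_function maps a torch-level callable to a hidet op,
# and ops.bitwise_left_shift is assumed to be the matching hidet graph op.
@register_function(operator.lshift)
def lshift(x, y):
    return ops.bitwise_left_shift(x, y)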

However, there is one error that I’m having difficulty fixing without additional knowledge of the codebase. Here’s the error message:

ValueError: Unknown data type: uint8x4, candidates...

I’ve attached a simplified script that reproduces this error, and possibly a few others. I would appreciate any help you can offer in resolving it. Thank you in advance for your time!

Code
import math
import torch
import jaxtyping
from typing import Tuple

DEFAULT_CONTAINER_NUM_BITS = 8
FloatTensorType = jaxtyping.Float[torch.Tensor, "..."]
UInt8TensorType = jaxtyping.UInt8[torch.Tensor, "..."]
BinaryTensorType = jaxtyping.Bool[torch.Tensor, "..."]
PackedBinaryTensorType = jaxtyping.UInt8[torch.Tensor, "..."]


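# Pack the last dimension (num_bits boolean lanes, most significant bit first)
# of a bool tensor into single uint8 values.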
def from_binary(tensor: BinaryTensorType, num_bits: int) -> UInt8TensorType:
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError
    mask = torch.tensor([2], dtype=torch.float32, device=tensor.device) ** torch.arange(
        num_bits - 1, -1, -1,
        dtype=torch.float32,
        device=tensor.device)
    mask = mask.to(dtype=torch.uint8)
    tensor = tensor.to(dtype=torch.uint8)
    output = torch.sum(mask * tensor, dim=-1)
    output = output.to(dtype=torch.uint8)
    return output


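# Unpack a flat uint8 tensor into a flat bool tensor of its bits (least
# significant bit first within each byte), dropping padding bits at the end.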
def unpack_uint8_into_bool(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
) -> BinaryTensorType:
    if packed_tensor.ndim != 1:
        raise ValueError
    if packed_tensor.dtype != torch.uint8:
        raise TypeError
    # Some constants
    packed_dtype = torch.uint8
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]
    bits = torch.tensor(
        1,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)
    unpacked_tensor = torch.unsqueeze(
        packed_tensor,
        dim=-1)
    unpacked_tensor = unpacked_tensor & bits
    unpacked_tensor = unpacked_tensor > 0
    unpacked_tensor = unpacked_tensor.to(dtype=torch.bool)
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor


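# Compiled entry point: unpack the uint8 containers into bits, then regroup
# them into num_bits-wide uint8 values with the requested shape.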
@torch.compile(fullgraph=True, backend="hidet")
def unpack_integer_tensors(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    num_bits: int,
    shape: Tuple[int, ...],
) -> UInt8TensorType:
    # number of uint8 containers needed to hold all packed bits
    packed_size = (
        (math.prod(shape) * num_bits + padding_length) //
        DEFAULT_CONTAINER_NUM_BITS)
    if packed_tensor.shape != (packed_size,):
        raise ValueError

    # [tensor.numel() x num_bits / 8]
    packed_tensor = packed_tensor.contiguous()
    # [tensor.numel() x num_bits]
    binary_tensor = unpack_uint8_into_bool(
        packed_tensor=packed_tensor,
        padding_length=padding_length)
    # [*tensor.shape, num_bits]
    binary_tensor = binary_tensor.view(
        *shape, num_bits)
    return from_binary(
        tensor=binary_tensor,
        num_bits=num_bits)


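# Driver: running this compiled call with the hidet backend is what raises the
# 'Unknown data type: uint8x4' error above.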
num_bits = 8
shape = torch.Size([1024, 256, 1])
unpack_integer_tensors(
    torch.randint(
        2 ** 8,
        size=(shape.numel(),),
        dtype=torch.uint8,
        device="cuda"),
    padding_length=0,
    num_bits=num_bits,
    shape=shape,
)

@Aalanli
Collaborator

Aalanli commented Oct 31, 2023

No problem!
For me, PR #372 works on the code provided.

@HanGuo97
Author

Amazing, thanks a ton! Will give this a try soon.

vadiklyutiy added a commit that referenced this issue Dec 24, 2024
Promote version of hidet 0.4.0.dev -> 0.5.0.dev