[NotImplementedError] operator.lshift #370
Comments
Fixed in #371
Thanks for the quick fix! Does that mean I could simply build from source, and this would work via `torch.compile(backend='hidet')`?
I am not sure about the particulars of your model, but this script worked for me (when built from source):

```python
import hidet
import torch

def test(a):
    return a << 3

t = torch.compile(test, backend='hidet')
t(torch.randn(3, 5, device='cuda').to(torch.int64))
```
Hi @Aalanli, I wanted to follow up with a more detailed example. This particular issue caused several errors for me. Some of these were fixable following the patch you provided. However, there is one error that I'm having difficulty fixing without additional knowledge of the codebase.
I've attached a simplified script that shows this error, and possibly a few others. I would appreciate any help you can offer in resolving this issue. Thank you in advance for your time!

```python
import math
import torch
import jaxtyping
from typing import Tuple

DEFAULT_CONTAINER_NUM_BITS = 8

FloatTensorType = jaxtyping.Float[torch.Tensor, "..."]
UInt8TensorType = jaxtyping.UInt8[torch.Tensor, "..."]
BinaryTensorType = jaxtyping.Bool[torch.Tensor, "..."]
PackedBinaryTensorType = jaxtyping.UInt8[torch.Tensor, "..."]


def from_binary(tensor: BinaryTensorType, num_bits: int) -> UInt8TensorType:
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError
    mask = torch.tensor([2], dtype=torch.float32, device=tensor.device) ** torch.arange(
        num_bits - 1, -1, -1,
        dtype=torch.float32,
        device=tensor.device)
    mask = mask.to(dtype=torch.uint8)
    tensor = tensor.to(dtype=torch.uint8)
    output = torch.sum(mask * tensor, dim=-1)
    output = output.to(dtype=torch.uint8)
    return output


def unpack_uint8_into_bool(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
) -> BinaryTensorType:
    if packed_tensor.ndim != 1:
        raise ValueError
    if packed_tensor.dtype != torch.uint8:
        raise TypeError

    # Some constants
    packed_dtype = torch.uint8
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]
    bits = torch.tensor(
        1,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)

    unpacked_tensor = torch.unsqueeze(
        packed_tensor,
        dim=-1)
    unpacked_tensor = unpacked_tensor & bits
    unpacked_tensor = unpacked_tensor > 0
    unpacked_tensor = unpacked_tensor.to(dtype=torch.bool)
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor


@torch.compile(fullgraph=True, backend="hidet")
def unpack_integer_tensors(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    num_bits: int,
    shape: Tuple[int, ...],
) -> UInt8TensorType:
    packed_size = (
        (math.prod(shape) * num_bits + padding_length) /
        DEFAULT_CONTAINER_NUM_BITS)
    if packed_tensor.shape != (packed_size,):
        raise ValueError

    # [tensor.numel() x num_bits / 8]
    packed_tensor = packed_tensor.contiguous()
    # [tensor.numel() x num_bits]
    binary_tensor = unpack_uint8_into_bool(
        packed_tensor=packed_tensor,
        padding_length=padding_length)
    # [*tensor.shape, num_bits]
    binary_tensor = binary_tensor.view(
        *shape, num_bits)
    return from_binary(
        tensor=binary_tensor,
        num_bits=num_bits)


num_bits = 8
shape = torch.Size([1024, 256, 1])
unpack_integer_tensors(
    torch.randint(
        2 ** 8,
        size=(shape.numel(),),
        dtype=torch.uint8,
        device="cuda"),
    padding_length=0,
    num_bits=num_bits,
    shape=shape,
)
```
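For what it's worth, the shift-and-mask pattern above should behave as expected in eager mode; the failure is specific to compiling with the hidet backend. Here is a tiny eager-mode sanity check of the same bit-unpacking logic (my own sketch, not part of the original report):

```python
import torch

# Unpack one uint8 value into its 8 bits (least-significant bit first),
# using the same shift-and-mask pattern as unpack_uint8_into_bool above.
packed = torch.tensor([0b10110001], dtype=torch.uint8)
bits = torch.tensor(1, dtype=torch.uint8) << torch.arange(8, dtype=torch.uint8)
unpacked = (packed.unsqueeze(-1) & bits.unsqueeze(0)) > 0
print(unpacked.to(torch.uint8))  # expected: [[1, 0, 0, 0, 1, 1, 0, 1]]
```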
No problem!
Amazing, thanks a ton! Will give this a try soon.
Promote version of hidet 0.4.0.dev -> 0.5.0.dev
Describe the Problem

```
BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
  operator.lshift
```
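For context, here is my own illustration (not from the report) of where the `operator.lshift` symbol comes from: when a traced function applies Python's `<<` to a tensor, the FX graph records the call as `operator.lshift`, which is the name the hidet frontend reports as unsupported.

```python
import operator
import torch
import torch.fx


def shift(a):
    return a << 3


# The Python << on a traced tensor is recorded as a call_function node
# whose target is operator.lshift, the same symbol named in the error above.
traced = torch.fx.symbolic_trace(shift)
print(traced.graph)
assert any(node.target is operator.lshift for node in traced.graph.nodes)
```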