[Bug] operator.gt got an object with type <class 'int'> #320

Closed
ruofan-wu opened this issue Jul 18, 2023 · 8 comments
Labels: bug (Something isn't working)

Comments

@ruofan-wu

Hi @yaoyaoding, I encountered a bug when running T5Model with hidet:

DEBUG:hidet.graph.frontend.torch.interpreter:interpreting node 34: %gt : [#users=1] = call_function[target=operator.gt](args = (%sub_1, 0), kwargs = {})
Traceback (most recent call last):
  File "hidet-path/python/hidet/graph/frontend/torch/interpreter.py", line 260, in forward
    hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
  File "hidet-path/python/hidet/graph/frontend/torch/register_functions.py", line 694, in gt
    return ops.greater(a, b)
  File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 85, in greater
    return GreaterOp(x, y).get_output(0)
  File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 34, in __init__
    super().__init__(x, y, lambda a, b: a > b, name='gt')
  File "hidet-path/python/hidet/graph/ops/definitions/arithmetic.py", line 125, in __init__
    task=BinaryElementwiseTask(name, input_like(x, 'x'), input_like(y, 'y'), op=op),
  File "hidet-path/python/hidet/graph/ops/definitions/utils/tensor_utils.py", line 26, in input_like
    raise TypeError('Expect a hidet.Tensor, but got an object with type {}'.format(type(tensor)))
TypeError: Expect a hidet.Tensor, but got an object with type <class 'int'>

Could you please help me fix it?
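
For what it's worth, the failure happens because operator.gt reaches ops.greater with a plain Python int as one operand, while input_like insists on a hidet.Tensor. Below is a minimal sketch of one way the registration in register_functions.py could coerce scalar operands first; using hidet.asarray as the scalar-to-tensor helper is an assumption, not the confirmed fix:

import operator
import hidet

# Hedged sketch, placed in register_functions.py where register_function,
# Tensor, and ops are already in scope. Plain Python numbers are wrapped
# into scalar hidet tensors before the comparison is dispatched.
@register_function(operator.gt)
def gt(a, b):
    if not isinstance(a, Tensor):
        a = hidet.asarray(a, device=b.device if isinstance(b, Tensor) else 'cpu')
    if not isinstance(b, Tensor):
        b = hidet.asarray(b, device=a.device)
    return ops.greater(a, b)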

@ruofan-wu ruofan-wu added the bug Something isn't working label Jul 18, 2023
@yaoyaoding
Member

Hi @GisellWu,

Could you share a minimal example that reproduces the error?

@ruofan-wu
Author

import torch
from transformers import T5Tokenizer, T5Model
import hidet

model_name = 't5-base' 
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name).to(device="cuda:0")

model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
tokens = tokenizer.encode_plus(input_text, add_special_tokens=True, return_tensors='pt',
                               padding='max_length', truncation=True, max_length=128).to(device="cuda:0")

outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)

@ruofan-wu
Author

Furthermore, I added some functions to register_functions.py in order to get T5Model running end to end:

@register_function(torch.abs)
def abs(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.abs(..., out=...)")
    return ops.abs(x)

@register_function(torch.log)
def log(x: Tensor, *, out: Optional[Tensor] = None) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.log(..., out=...)")
    return ops.log(x)

@register_function(torch.full_like)
def full_like(
    input, fill_value, *, dtype=None, layout=torch.strided, device=None,
    requires_grad=False, memory_format=torch.preserve_format
):
    if layout not in [None, torch.strided]:
        raise NotImplementedError("hidet: does not support torch.full_like(..., layout=..., ...)")
    if requires_grad and torch.is_grad_enabled():
        warnings.warn_once("hidet: requires_grad=True when torch.is_grad_enabled(), treating as requires_grad=False")
    hidet_device: Device = device_from_torch(torch_device=device)
    hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype)
    return ops.full(input.size(), fill_value, dtype=hidet_dtype, device=hidet_device)

@register_function(torch.zeros_like)
def zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format):
    import hidet

    if layout is not None:
        raise NotImplementedError("layout is not None")

    # input.size() may be a one-element tuple wrapping the actual shape; unwrap it.
    size = input.size()
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = size[0]
    shape = [int(v) for v in size]

    if dtype is None:
        dtype = torch.get_default_dtype()

    _ = requires_grad  # gradients are not tracked on the hidet side

    return hidet.zeros(shape, dtype=dtype_from_torch(dtype), device=device_from_torch(device))

@yaoyaoding
Member

Hi @GisellWu,

I added the missing operators and fixed some bugs for the T5 model in #322. Could you give it another try?

@ruofan-wu
Author

Thanks for your help! I ran it successfully. Closing the issue :)

@ruofan-wu
Author

Hi @yaoyaoding ,

Sorry to bother you again: I'm now trying T5Model with float16, and there are some new unsupported functions. Could you please help me fix them?

The error is:

NotImplementedError: The following modules/functions are not supported by hidet yet:
  torch.clamp
  torch.isinf

And the example code is:

import torch
from transformers import T5Tokenizer, T5Model
import hidet

model_name = 't5-base' 
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5Model.from_pretrained(model_name, torch_dtype=torch.float16).to(device="cuda:0")

model = torch.compile(model, backend='hidet')

input_text = ["translate English to French: Hello, how are you?"]
tokens = tokenizer.encode_plus(input_text, add_special_tokens=True, return_tensors='pt',
                               padding='max_length', truncation=True, max_length=128).to(device="cuda:0")

outputs = model(input_ids=tokens.input_ids, decoder_input_ids=tokens.input_ids)
logits = outputs.last_hidden_state
print("Logits Shape:", logits.shape)

@ruofan-wu ruofan-wu reopened this Aug 4, 2023
@yaoyaoding
Member

Hi @GisellWu,

I added the missing operators in #343, could you give it a try? Thanks!

@ruofan-wu
Author

Hi @yaoyaoding,

It works now. Appreciate your help again!

vadiklyutiy pushed a commit that referenced this issue Jul 22, 2024
…immutable_list` (#320)

As discussed on our Slack channel, the compilation of the model
`dalle2_pytorch` caused a Hidet exception triggered by `expand`:

> RuntimeError: 'immutable_list' object does not support mutation. If
> you are attempting to modify the kwargs or args of a torch.fx.Node
> object, instead create a new copy of it and assign the copy to the node:
>     new_args = ... # copy and mutate args
>     node.args = new_args
> , occurred when interpreting expand with
>     tensor_expand(tensor(...), [-1, 2, -1])

This PR fixes this bug.
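
For illustration, this is the mutate-a-copy pattern the error message prescribes (the helper name is hypothetical; the constraint that node.args must be reassigned rather than mutated in place is real torch.fx behavior):

import torch.fx

# Illustrative sketch only: torch.fx stores node.args in immutable
# containers, so build a mutable copy, edit it, and assign it back.
def normalize_expand_args(node: torch.fx.Node) -> None:
    new_args = list(node.args)   # copying is allowed; in-place mutation raises
    # ... adjust new_args here, e.g. rewriting the expand shape list ...
    node.args = tuple(new_args)  # reassigning a fresh copy is the supported path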
vadiklyutiy pushed a commit that referenced this issue Jul 23, 2024
…immutable_list` (#320)

vadiklyutiy pushed a commit that referenced this issue Dec 26, 2024
…immutable_list` (#320)