
ReformerForQuestionAnswering : int() argument must be a string, a bytes-like object or a number, not 'NoneType' #10370

harikc456 opened this issue Feb 24, 2021 · 7 comments · Fixed by #11117

@harikc456 commented Feb 24, 2021

Environment info

  • transformers version:
  • Platform:
  • Python version: 3.7.10
  • PyTorch version (GPU?): 1.7
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten

Information

Model I am using (Bert, XLNet ...): Reformer

The problem arises when using:

  • my own modified scripts: calling backward() after passing the question and text to the ReformerForQuestionAnswering model.

The task I am working on is:

  • an official GLUE/SQuAD task: a subset of SQuAD

To reproduce

Steps to reproduce the behavior:

Performing backward on the loss throws an error.

Minimal code to reproduce the error:

from transformers import ReformerTokenizer, ReformerForQuestionAnswering
import torch

tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerForQuestionAnswering.from_pretrained('google/reformer-crime-and-punishment')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors='pt')
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
loss = outputs.loss
loss.backward()

Error Traceback

create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 

/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90 
     91 

/usr/local/lib/python3.7/dist-packages/transformers/models/reformer/modeling_reformer.py in backward(***failed resolving arguments***)
   1673                 head_mask=head_mask[len(layers) - idx - 1],
   1674                 attention_mask=attention_mask,
-> 1675                 buckets=buckets,
   1676             )
   1677 

/usr/local/lib/python3.7/dist-packages/transformers/models/reformer/modeling_reformer.py in backward_pass(self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, attention_mask, head_mask, buckets)
   1527 
   1528             # set seed to have correct dropout
-> 1529             torch.manual_seed(self.feed_forward_seed)
   1530             # g(Y_1)
   1531             res_hidden_states = self.feed_forward(next_attn_output)

/usr/local/lib/python3.7/dist-packages/torch/random.py in manual_seed(seed)
     30             `0xffff_ffff_ffff_ffff + seed`.
     31     """
---> 32     seed = int(seed)
     33     import torch.cuda
     34 

TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

From debugging, I believe the error is caused by self.feed_forward_seed in the ReformerLayer class being None.

I have tried the same code with Longformer and it worked perfectly.
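
For context, the pattern behind this failure is roughly the following (a minimal sketch with illustrative names, not the actual modeling_reformer.py code): Reformer's reversible layers recompute activations with the same dropout mask during backward, so a layer records an RNG seed during the forward pass, but only when the module is in training mode. If the forward ran in eval mode, the seed stays None and torch.manual_seed(None) raises exactly this TypeError.

import torch

class ReversibleBlockSketch(torch.nn.Module):
    # Illustrative only: mimics how a reversible layer records and replays a dropout seed.
    def __init__(self):
        super().__init__()
        self.feed_forward_seed = None
        self.dropout = torch.nn.Dropout(p=0.1)

    def forward(self, x):
        if self.training:
            # record the seed so the backward pass can replay the same dropout mask
            self.feed_forward_seed = int(torch.seed())
            torch.manual_seed(self.feed_forward_seed)
        return self.dropout(x)

    def replay(self, x):
        # called from the custom backward pass; fails if forward ran in eval mode
        torch.manual_seed(self.feed_forward_seed)  # TypeError if the seed is still None
        return self.dropout(x)

block = ReversibleBlockSketch().eval()  # from_pretrained() also returns a model in eval mode
x = torch.randn(2, 4)
block(x)           # no seed is recorded in eval mode
# block.replay(x)  # would raise: int() argument must be ... not 'NoneType'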

Expected behavior

loss.backward() runs without error.

@harikc456 harikc456 changed the title ReformerForQuestionAnswering : int() argument must be a string, a bytes-like object or a number, not 'NoneType', when loss.backward() is called. ReformerForQuestionAnswering : int() argument must be a string, a bytes-like object or a number, not 'NoneType' Feb 24, 2021
@patrickvonplaten (Contributor)

Hey @harikc456,

The problem is that the model is not put into training mode. If you run the following code:

from transformers import ReformerTokenizer, ReformerForQuestionAnswering
from transformers.models.reformer.modeling_reformer import PositionEmbeddings
import torch

tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerForQuestionAnswering.from_pretrained('google/reformer-crime-and-punishment')

# change to position embeddings to prevent error
model.reformer.embeddings.position_embeddings = PositionEmbeddings(model.config)

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors='pt')
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
loss = outputs.loss
loss.backward()

you can see that the code runs without error.
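
As a quick, generic PyTorch check (not Reformer-specific), you can also confirm which mode the model is in before running backward(); from_pretrained() returns the model in eval mode, so it has to be switched explicitly:

model = ReformerForQuestionAnswering.from_pretrained('google/reformer-crime-and-punishment')
print(model.training)  # False: from_pretrained() puts the model in eval mode
model.train()          # training mode, so the reversible layers record their dropout seeds
print(model.training)  # True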

@forest1988 (Contributor)

@patrickvonplaten

Hello, I've just come across the same issue.

I tried the code below,

from transformers import ReformerTokenizer, ReformerForQuestionAnswering
from transformers.models.reformer.modeling_reformer import PositionEmbeddings
import torch

tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerForQuestionAnswering.from_pretrained('google/reformer-crime-and-punishment')

# change to position embeddings to prevent error
model.reformer.embeddings.position_embeddings = PositionEmbeddings(model.config)

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors='pt')
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
loss = outputs.loss
loss.backward()

and got the following error message.

Some weights of the model checkpoint at google/reformer-crime-and-punishment were not used when initializing ReformerForQuestionAnswering: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing ReformerForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ReformerForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ReformerForQuestionAnswering were not initialized from the model checkpoint at google/reformer-crime-and-punishment and are newly initialized: ['reformer.encoder.layers.0.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.0.attention.self_attention.mask_value_float32', 'reformer.encoder.layers.1.attention.self_attention.self_mask_value_float16', 'reformer.encoder.layers.1.attention.self_attention.self_mask_value_float32', 'reformer.encoder.layers.1.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.1.attention.self_attention.mask_value_float32', 'reformer.encoder.layers.2.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.2.attention.self_attention.mask_value_float32', 'reformer.encoder.layers.3.attention.self_attention.self_mask_value_float16', 'reformer.encoder.layers.3.attention.self_attention.self_mask_value_float32', 'reformer.encoder.layers.3.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.3.attention.self_attention.mask_value_float32', 'reformer.encoder.layers.4.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.4.attention.self_attention.mask_value_float32', 'reformer.encoder.layers.5.attention.self_attention.self_mask_value_float16', 'reformer.encoder.layers.5.attention.self_attention.self_mask_value_float32', 'reformer.encoder.layers.5.attention.self_attention.mask_value_float16', 'reformer.encoder.layers.5.attention.self_attention.mask_value_float32', 'qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/path/to/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported.
  warnings.warn("Setting attributes on ParameterList is not supported.")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-60eb084822c0> in <module>
     16 outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
     17 loss = outputs.loss
---> 18 loss.backward()

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129 
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90 
     91 

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py in backward(***failed resolving arguments***)
   1666 
   1667             # backprop
-> 1668             output = layer.backward_pass(
   1669                 next_attn_output=output.attn_output,
   1670                 hidden_states=output.hidden_states,

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py in backward_pass(self, next_attn_output, hidden_states, grad_attn_output, grad_hidden_states, attention_mask, head_mask, buckets)
   1527 
   1528             # set seed to have correct dropout
-> 1529             torch.manual_seed(self.feed_forward_seed)
   1530             # g(Y_1)
   1531             res_hidden_states = self.feed_forward(next_attn_output)

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/torch/random.py in manual_seed(seed)
     30             `0xffff_ffff_ffff_ffff + seed`.
     31     """
---> 32     seed = int(seed)
     33     import torch.cuda
     34 

TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

I first tried to use:

    tokenizer = AutoTokenizer.from_pretrained("google/reformer-crime-and-punishment")
    model = AutoModelForSequenceClassification.from_pretrained(
        "google/reformer-crime-and-punishment", return_dict=True
    )

It failed, then I found this issue and added:

    # change to position embeddings to prevent error
    model.reformer.embeddings.position_embeddings = PositionEmbeddings(model.config)

However, the same error occurs.

  • transformers version: 4.1.1
  • Platform: Linux-4.15.0-135-generic-x86_64-with-glibc2.10
  • Python version: 3.8.3
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): 2.4.1 (False)
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Could the problem be that the version of transformers I am using is too old?

Thank you in advance.

@forest1988 (Contributor)

The same issue still occurs after updating transformers to the latest stable version via pip.

  • transformers version: 4.4.1
  • Platform: Linux-4.15.0-135-generic-x86_64-with-glibc2.10
  • Python version: 3.8.3
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): 2.4.1 (False)
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Does the problem depend on the version of some other library?

@forest1988 (Contributor)

Excuse me for my frequent posting.

Instead of overwriting position_embeddings, inserting model.train() seems to work (though it surfaces another issue).

from transformers import ReformerTokenizer, ReformerForQuestionAnswering
from transformers.models.reformer.modeling_reformer import PositionEmbeddings
import torch

tokenizer = ReformerTokenizer.from_pretrained('google/reformer-crime-and-punishment')
model = ReformerForQuestionAnswering.from_pretrained('google/reformer-crime-and-punishment')

# # change to position embeddings to prevent error
# model.reformer.embeddings.position_embeddings = PositionEmbeddings(model.config)

model.train()

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors='pt')
start_positions = torch.tensor([1])
end_positions = torch.tensor([3])

outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
loss = outputs.loss

loss.backward()

A different error message is shown, but it seems it can be handled simply by padding the input.

~/.pyenv/versions/anaconda3-2020.07/lib/python3.8/site-packages/transformers/models/reformer/modeling_reformer.py in forward(self, position_ids)
    154 
    155         if self.training is True:
--> 156             assert (
    157                 reduce(mul, self.axial_pos_shape) == sequence_length
    158             ), "If training, make sure that config.axial_pos_shape factors: {} multiply to sequence length. Got prod({}) != sequence_length: {}. You might want to consider padding your sequence length to {} or changing config.axial_pos_shape.".format(

AssertionError: If training, make sure that config.axial_pos_shape factors: (512, 1024) multiply to sequence length. Got prod((512, 1024)) != sequence_length: 28. You might want to consider padding your sequence length to 524288 or changing config.axial_pos_shape.

I'm now trying to pad the input, and it seems to work:

tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=524288, return_tensors='pt')

I apologize if this is not an appropriate solution.
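
As a small, hedged addition to the padding route: the required length does not have to be hard-coded, since it can be read from the model config, mirroring the reduce(mul, ...) check in the assertion above.

from functools import reduce
from operator import mul

# product of the axial position embedding factors; 512 * 1024 = 524288 for this checkpoint
required_len = reduce(mul, model.config.axial_pos_shape)

tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(question, text, padding='max_length', truncation=True,
                   max_length=required_len, return_tensors='pt')

The assertion message also mentions changing config.axial_pos_shape so that its factors multiply to a smaller sequence length, which would avoid padding to 524288 tokens, but note that this changes the shape of the pretrained axial position embeddings.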

@patrickvonplaten (Contributor)

We could maybe add a better error message that fires when Reformer is not in training mode but .backward() is called. @forest1988, feel free to open a PR if you want :-)
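
A minimal sketch of the kind of guard being suggested (the helper name is hypothetical, and the actual change is whatever lands in PR #11117):

def check_reformer_training_mode(model):
    # Illustrative: raise a clearer error than the downstream TypeError when
    # backward() is about to run on a model that was left in eval mode.
    if not model.training:
        raise RuntimeError(
            "Reformer's reversible backward pass replays dropout seeds recorded during a "
            "training-mode forward pass. Call model.train() before the forward/backward step."
        )

In the library itself such a check would more naturally live inside the backward pass of the reversible layers, where the seed is consumed.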

@forest1988 (Contributor)

@patrickvonplaten
Thanks, I'll open a PR!
I'm a little busy right now, but I'll make time to work on it soon.

@forest1988 (Contributor)

Hi @patrickvonplaten,
Sorry for the delay. I've just opened PR #11117 for this issue. All checks have passed.
Could you please have a look at it when you have time?
