BeamSearchDecoder segmentation fault (on GPU) #1109

Closed
georgesterpu opened this issue Feb 19, 2020 · 7 comments
Labels: bug (Something isn't working), custom-ops, seq2seq

Comments

@georgesterpu
Contributor

georgesterpu commented Feb 19, 2020

System information

  • OS Platform and Distribution: Manjaro Linux testing
  • TensorFlow version: pypi tf-nightly 2.2.0.dev20200218
  • TensorFlow-Addons version: pypi tfa-nightly 0.9.0.dev20200219
  • Python version: 3.7.6
  • Is GPU used? (yes/no): yes, Nvidia Titan XP

Describe the bug

Calling a BeamSearchDecoder results in a Segmentation fault (core dumped).

Code to reproduce the issue

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.python.ops import array_ops
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['AUTOGRAPH_VERBOSITY'] = '10'

cell = tf.keras.layers.LSTMCell(3)
mechanism = tfa.seq2seq.LuongAttention(units=3)
cell = tfa.seq2seq.AttentionWrapper(
    cell=cell,
    attention_mechanism=mechanism)

embedding_layer = tf.keras.layers.Embedding(
    input_dim=3,
    output_dim=3)

decoder = tfa.seq2seq.BeamSearchDecoder(
    cell=cell,
    beam_width=10,
    embedding_fn=embedding_layer,
    maximum_iterations=8)

dataset = tf.data.Dataset.from_tensor_slices(tf.ones((100, 7, 3))).batch(2)
my_iterator = iter(dataset)

@tf.function
def decode(it):

    data = next(it)
    bs = tf.shape(data)[0]

    tiled_memory = tfa.seq2seq.tile_batch(data, multiplier=10)
    mechanism.setup_memory(tiled_memory)

    attention_state = cell.get_initial_state(batch_size=bs*10, dtype=tf.float32)

    return decoder(
        embedding=None,
        start_tokens=array_ops.fill([2], 1),
        end_token=2,
        initial_state=attention_state,
    )


print(10*'=' + ' start ' + 10*'=')
print(decode(my_iterator))
print(10*'=' + ' end ' + 10*'=')

Other info / logs

Code on colab
Not sure if this is related to #990, as that issue targets tensorflow-cpu.

In eager mode the code still crashes with the following output:

========== start ==========
(FinalBeamSearchDecoderOutput(predicted_ids=<tf.Tensor: shape=(2, 8, 10), dtype=int32, numpy=
[...................]
========== end ==========
corrupted size vs. prev_size
Aborted (core dumped)
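
To get a Python-level traceback for a crash like this one, one option is to run the reproduction script in a child process with `faulthandler` enabled and inspect the exit status. The following is a minimal sketch using only the standard library. The helper name `run_and_classify` is hypothetical, and the sketch assumes a POSIX system, where a negative return code is the number of the signal that terminated the child.

```python
import signal
import subprocess
import sys


def run_and_classify(args):
    """Run `python -X faulthandler <args>` in a child process.

    Returns (status, stderr), where status is "ok", "exit <code>", or the
    name of the signal that terminated the child (for example "SIGSEGV").
    Assumes POSIX semantics: a negative return code means the child was
    killed by a signal.
    """
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", *args],
        capture_output=True,
        text=True,
    )
    if proc.returncode == 0:
        status = "ok"
    elif proc.returncode < 0:
        # subprocess reports "killed by signal N" as returncode == -N.
        status = signal.Signals(-proc.returncode).name
    else:
        status = f"exit {proc.returncode}"
    return status, proc.stderr
```

For example, `status, err = run_and_classify(["beamsearcherror.py"])` would run the snippet above in isolation. With `faulthandler` enabled, `err` should include the Python stack that was active when the fatal signal arrived, which can help tell a crash inside the decode call from one during interpreter shutdown.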
@guillaumekln
Contributor

Does it also segfault when using the tagged releases Addons 0.8.2 and TensorFlow 2.1.0?

@georgesterpu
Contributor Author

georgesterpu commented Feb 19, 2020

@guillaumekln This example (almost) works with:

tensorflow                2.1.0                    pypi_0    pypi
tensorflow-addons         0.8.2                    pypi_0    pypi

However, the snippet above raises a TypeError:

beamsearcherror.py:39 decode  *
        return decoder(

    TypeError: __call__() missing 1 required positional argument: 'inputs'

It can be made to work by passing None positionally as the first argument, instead of passing it as the embedding keyword. @qlzh727, could you please help clarify this behaviour?
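
The error message suggests that the object's `__call__` has a first parameter named `inputs`, so a call that passes everything by keyword leaves `inputs` unbound. The sketch below reproduces that behaviour in plain Python. `StandInLayer` is a hypothetical stand-in written for illustration, not the real Keras or TensorFlow Addons class, and its `call` signature is only assumed to resemble the decoder's.

```python
class StandInLayer:
    """Hypothetical stand-in, not the real Keras or TensorFlow Addons class."""

    def __call__(self, inputs, *args, **kwargs):
        # The first argument must be given positionally or as `inputs=...`.
        return self.call(inputs, *args, **kwargs)

    def call(self, embedding, start_tokens=None, end_token=None,
             initial_state=None):
        return {
            "embedding": embedding,
            "start_tokens": start_tokens,
            "end_token": end_token,
        }


decoder = StandInLayer()

# Keyword form from the snippet: `embedding=None` lands in **kwargs, so
# nothing binds to `inputs` and Python raises a TypeError.
try:
    decoder(embedding=None, start_tokens=[1, 1], end_token=2)
except TypeError as exc:
    error_message = str(exc)

# Workaround described above: pass None positionally as the first argument.
result = decoder(None, start_tokens=[1, 1], end_token=2)
```

Under that assumption, the positional `None` is bound to `inputs` and forwarded to `call` as `embedding`, which is why the workaround succeeds.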

My internal project (soon to be publicly released) doesn't seem to work on TF 2.1; training triggers the following error:

ValueError: Insufficient elements in branch_graphs[0].outputs.
    Expected: 7
    Actual: 6

This appears to be fixed in the nightlies (probably by tensorflow/tensorflow@7f347e6), since with them the segfault occurs in evaluation. Unfortunately, that means I cannot fall back to the tagged releases anyway.

@guillaumekln
Contributor

Let's close this issue as TensorFlow Addons is built against TensorFlow stable. In consequence, TensorFlow nightlies are not supported.

#1281 added explicit runtime checks on the required TensorFlow version.
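
As an illustration of what a runtime check of this kind can look like, here is a minimal sketch in plain Python. It is not the code from #1281: the function names and the version bounds (`2.1.0` inclusive, `2.2.0` exclusive) are assumptions chosen to match the versions discussed in this thread.

```python
import re
import warnings

# Assumed bounds for illustration only; not the values used by the project.
MIN_TF_VERSION = "2.1.0"
MAX_TF_VERSION = "2.2.0"  # exclusive upper bound


def parse_release(version):
    """Return the leading numeric release as a tuple.

    Suffixes such as ".dev20200218" or "-rc1" are ignored, so a nightly of
    2.2.0 compares equal to the 2.2.0 release.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def is_supported(tf_version, min_version=MIN_TF_VERSION,
                 max_version=MAX_TF_VERSION):
    """True if min_version <= tf_version < max_version."""
    release = parse_release(tf_version)
    return parse_release(min_version) <= release < parse_release(max_version)


def check_tf_version(tf_version):
    """Warn (rather than fail) when the TensorFlow version is out of range."""
    if is_supported(tf_version):
        return True
    warnings.warn(
        f"TensorFlow {tf_version} is outside the assumed supported range "
        f"[{MIN_TF_VERSION}, {MAX_TF_VERSION}); custom ops may crash.",
        UserWarning,
    )
    return False
```

In practice the argument would be `tf.__version__`, checked once at import time. With the assumed bounds, the nightly from this report (`2.2.0.dev20200218`) would be flagged before any custom op is loaded.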

@avnx

avnx commented Mar 19, 2020

> Let's close this issue as TensorFlow Addons is built against TensorFlow stable. In consequence, TensorFlow nightlies are not supported.
>
> #1281 added explicit runtime checks on the required TensorFlow version.

@guillaumekln
When are you going to fix this?
Do you have an estimate of when we can expect a working version of BeamSearchDecoder?

@guillaumekln
Contributor

The issue only affects the non-stable versions of TensorFlow. Can you try with TensorFlow 2.1 instead?

@avnx

avnx commented Mar 19, 2020

> The issue only affects the non-stable versions of TensorFlow. Can you try with TensorFlow 2.1 instead?

@guillaumekln

As @georgesterpu mentioned earlier, BeamSearchDecoder with @tf.function works with neither the TF 2.2 nightly nor TF 2.1.
The only difference is the error. In the first case it is:

Aborted (core dumped)

In the second case:

ValueError: Insufficient elements in branch_graphs[0].outputs.
    Expected: 7
    Actual: 6

I have encountered exactly the same problem as @georgesterpu.
I also want to add that BasicDecoder doesn't work with TF 2.1 either; the error is the same as with BeamSearchDecoder. However, BasicDecoder does work with the TF 2.2 nightly, which makes me happy.

For now, the only way I see to run BeamSearchDecoder is to move to TF 1, but I don't think that is a good solution.

@guillaumekln
Contributor

Ok, I thought you were referring to the Segmentation fault issue. Do you mind opening a separate issue for the Insufficient elements error along with a code snippet to reproduce it?
