
Issue when saving TF model with a tokenizer as a custom layer #422

Open
Shiro-LK opened this issue Oct 22, 2020 · 6 comments

@Shiro-LK

Shiro-LK commented Oct 22, 2020

Hi,

I am trying to create a TensorFlow model with the Keras API that includes the tokenizing process inside the model. It seems to work for local inference, but when I save the model with tf.saved_model.save, I get an error. I am wondering whether there is something wrong in my current code, or whether this is currently not possible.

AssertionError: Tried to export a function which references untracked object Tensor("139395:0", shape=(), dtype=resource). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

My tokenizer, which uses the BertTokenizer from tensorflow_text (I took the code from a discussion on this forum and modified it):

import tensorflow as tf
import tensorflow_text as tf_text


class TokenizerTF(tf.Module):
    def __init__(self, vocab_file_path, sequence_length=512, lower_case=True, pad_id=1, cls_id=2, sep_id=3):
        self.cls_token_id = tf.constant(cls_id, dtype=tf.int32)
        self.sep_token_id = tf.constant(sep_id, dtype=tf.int32)
        self.pad_token_id = tf.constant(pad_id, dtype=tf.int32)

        self.sequence_length = tf.constant(sequence_length)

        # These two lines are basically what make it work:
        # assigning the vocab to a tf.Module and then later assigning the
        # instantiated Module to e.g. a Keras Model
        self.bert_tokenizer = tf_text.BertTokenizer(
            vocab_file_path,
            lower_case=lower_case,
        )

    @tf.function
    def __call__(self, text: tf.Tensor) -> tf.Tensor:
        """
        Perform the BERT preprocessing from text -> input token ids
        """
        # Convert text into token ids
        tokens = self.bert_tokenizer.tokenize(text)

        # Flatten the ragged tensors
        tokens =  tf.cast(tokens.merge_dims(1, 2), tf.int32)

        # Add start and end token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], self.cls_token_id)
        end_tokens = tf.fill([tf.shape(text)[0], 1], self.sep_token_id)
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)

        # Truncate to sequence length
        tokens = tokens[:, : self.sequence_length]

        # Convert ragged tensor to tensor and pad with PAD_ID
        tokens = tokens.to_tensor(default_value=self.pad_token_id)

        # Pad to sequence length
        pad = self.sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=self.pad_token_id)

        return tf.reshape(tokens, [-1, self.sequence_length])  
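
For context, calling the tokenizer directly seems to work locally, roughly like this (the vocab path and sequence length here are just placeholders):

tokenizer = TokenizerTF("vocab.txt", sequence_length=128)  # placeholder vocab file
token_ids = tokenizer(tf.constant(["hello world", "a second sentence"]))
print(token_ids.shape)  # (2, 128)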

My current model :

def get_model(backbone, max_len, tokenizer):
    """
        backbone = transformer model
    """
    padding_idx = tokenizer.pad_token_id
    input_str = tf.keras.layers.Input(shape=(), dtype=tf.string, name = "input_str")
    input_ids = tf.keras.layers.Lambda(lambda x: tokenizer(x))(input_str)

    #attention_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name = "attention_mask")
    attention_mask = tf.math.not_equal(input_ids, padding_idx)
    predictions = backbone(input_ids, attention_mask=attention_mask)
    outputs = tf.keras.layers.Activation("sigmoid", name="outputs_proba")(predictions)

    model = tf.keras.Model(inputs=input_str, outputs=outputs)
    model.compile(tf.keras.optimizers.Adam(1e-5), loss="binary_crossentropy")
    return model
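
For reference, this is roughly how I build and save the model; the backbone construction is omitted and the vocab/export paths below are placeholders:

tokenizer = TokenizerTF("vocab.txt", sequence_length=512)  # placeholder vocab file
backbone = ...  # transformer model producing one logit per example, omitted here

model = get_model(backbone, max_len=512, tokenizer=tokenizer)
model(tf.constant(["some example text"]))  # local inference works

tf.saved_model.save(model, "export_dir")  # this raises the AssertionError above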

PS: I am using TF 2.3.1

@Shiro-LK
Author

Am I the only one getting this error?

@broken
Member

broken commented Dec 7, 2020

Is it just the BertTokenizer? I'll pass this on to somebody more familiar with Keras.

@fsx950223
Member

fsx950223 commented Dec 24, 2020

Is it just the BertTokenizer? I'll pass this on to somebody more familiar with Keras.

@broken. I have fixed the bug in #460. Could you review it?

@broken
Member

broken commented Jan 5, 2021

Thanks! I missed this over the holidays. We'll take a look.

@jeisinge

jeisinge commented Jul 8, 2021

I am also running into this issue and found a similar workaround. In particular, I found that the BertTokenizer needs to be wrapped in a Lambda layer:

import tensorflow as tf
import tensorflow.keras as keras


class TspBertTokenizer(keras.layers.Layer):
  def __init__(self, vocab_file, cls_token_id=None, sep_token_id=None, **kwargs):
    import tensorflow as tf
    import tensorflow.keras as keras
    import tensorflow.keras.backend as K
    import tensorflow_text as text
    
    super(TspBertTokenizer, self).__init__(**kwargs)
    
    self.vocab_file = vocab_file
    bert_tokenizer = text.BertTokenizer(self.vocab_file, token_out_type=tf.int32, lower_case=True)
    self.tokenize = keras.layers.Lambda(lambda text_input: bert_tokenizer.tokenize(text_input), name="bert_tokenizer")
    
    basic_tokenizer, wordpiece_tokenizer = bert_tokenizer.submodules
    self.cls_token_id = cls_token_id if cls_token_id is not None else K.get_value(wordpiece_tokenizer.tokenize("[CLS]")[0]).item()
    self.sep_token_id = sep_token_id if sep_token_id is not None else K.get_value(wordpiece_tokenizer.tokenize("[SEP]")[0]).item()
    
  def call(self, nlp_input):
    word_tokens = self.tokenize(nlp_input)
    flattened_tokens = word_tokens.merge_dims(1, -1)
    return flattened_tokens
    
  def get_config(self):
    return {
      "vocab_file": self.vocab_file,
      "cls_token_id": self.cls_token_id,
      "sep_token_id": self.sep_token_id,
      **super(TspBertTokenizer, self).get_config()
    }

Then it can be added to a Keras model; a rough sketch of the wiring is included below. I think this functionally works and an export can be done. However, it is not clear whether performance is ideal; I get the warnings shown after the sketch.
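
Roughly how I wire it into a model and export, in case it helps (the downstream layers, vocabulary size, and paths below are placeholders, not my actual setup):

import tensorflow as tf
import tensorflow.keras as keras

nlp_input = keras.layers.Input(shape=(), dtype=tf.string, name="nlp_input")
token_ids = TspBertTokenizer("vocab.txt", name="tsp_bert_tokenizer")(nlp_input)  # placeholder vocab file

# Placeholder downstream model: pad the ragged ids to a fixed length and embed them.
dense_ids = keras.layers.Lambda(lambda t: t.to_tensor(shape=[None, 128]))(token_ids)
embedded = keras.layers.Embedding(input_dim=30522, output_dim=64)(dense_ids)  # 30522 = BERT vocab size
pooled = keras.layers.GlobalAveragePooling1D()(embedded)
output = keras.layers.Dense(1, activation="sigmoid")(pooled)

model = keras.Model(inputs=nlp_input, outputs=output)
model.save("export_dir")  # the SavedModel export goes through with the Lambda wrapping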

[1,0]<stderr>:WARNING:tensorflow:AutoGraph could not transform <bound method TspBertTokenizer.call of <__main__.TspBertTokenizer object at 0x7fc57fa9d3a0>> and will run it as-is.
[1,0]<stderr>:Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
[1,0]<stderr>:Cause: Unable to locate the source code of <bound method TspBertTokenizer.call of <__main__.TspBertTokenizer object at 0x7fc57fa9d3a0>>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code
[1,0]<stderr>:To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

[1,0]<stderr>:2021-07-08 16:12:23.837909: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:906] Skipping loop optimization for Merge node with control input: pericles/nlp_input/cross_nlp/tsp_bert_tokenizer/bert_tokenizer/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/assert_greater_equal/Assert/AssertGuard/branch_executed/_107

I do not know whether any of these warnings degrade performance or hurt model accuracy. Any feedback on whether these warnings are an issue, or on better workarounds, would be much appreciated!

This is on TF 2.5. I also filed tensorflow/models#10115 as a downstream issue; see the gist there for the export issue without the Lambda.

@markomernick
Member

Thanks for the report! I'll take a look at this and see if we can get a fix pushed soon.
