
[BUG] Exception raised when saving a re-loaded model with a transformer block #878

Closed
oliverholworthy opened this issue Nov 14, 2022 · 4 comments
Labels: bug (Something isn't working), P0

@oliverholworthy (Member)

Bug description

When attempting to save a re-loaded model containing a transformer block, we get an exception about the inputs to the model.

TF 2.9.2 - TypeError: Unable to serialize Schema
    def test_clm_reload():
        model_dir = "/tmp/clm_model"
        reloaded_model = mm.Model.load(model_dir)
>       reloaded_model.save("/tmp/clm_model_2")

tests/unit/tf/transformers/test_block.py:320:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67: in error_handler
    raise e.with_traceback(filtered_tb) from None
/usr/lib/python3.8/json/encoder.py:199: in encode
    chunks = self.iterencode(o, _one_shot=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = <keras.saving.saved_model.json_utils.Encoder object at 0x7fdcc8ba68e0>
o = {'backend': 'tensorflow', 'batch_input_shape': None, 'class_name': 'merlin.models>Model', 'config': {'0': {'class_name...AgYBBgEGAQYBBvs=\n', {...}, None)}, 'dtype': 'float32', 'function_type': 'lambda', ...}, 'shared_object_id': 21}}, ...}, _one_shot = True

    def iterencode(self, o, _one_shot=False):
        """Encode the given object and yield each string
        representation as available.

        For example::

            for chunk in JSONEncoder().iterencode(bigobject):
                mysocket.write(chunk)

        """
        if self.check_circular:
            markers = {}
        else:
            markers = None
        if self.ensure_ascii:
            _encoder = encode_basestring_ascii
        else:
            _encoder = encode_basestring

        def floatstr(o, allow_nan=self.allow_nan,
                _repr=float.__repr__, _inf=INFINITY, _neginf=-INFINITY):
            # Check for specials.  Note that this type of test is processor
            # and/or platform-specific, so do tests which don't depend on the
            # internals.

            if o != o:
                text = 'NaN'
            elif o == _inf:
                text = 'Infinity'
            elif o == _neginf:
                text = '-Infinity'
            else:
                return _repr(o)

            if not allow_nan:
                raise ValueError(
                    "Out of range float values are not JSON compliant: " +
                    repr(o))

            return text


        if (_one_shot and c_make_encoder is not None
                and self.indent is None):
            _iterencode = c_make_encoder(
                markers, self.default, _encoder, self.indent,
                self.key_separator, self.item_separator, self.sort_keys,
                self.skipkeys, self.allow_nan)
        else:
            _iterencode = _make_iterencode(
                markers, self.default, _encoder, self.indent, floatstr,
                self.key_separator, self.item_separator, self.sort_keys,
                self.skipkeys, _one_shot)
>       return _iterencode(o, 0)
E       TypeError: Unable to serialize [{'name': 'item_id_seq', 'tags': {<Tags.ITEM_ID: 'item_id'>, <Tags.ID: 'id'>, <Tags.SEQUENCE: 'sequence'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>}, 'properties': {'domain': {'min': 1, 'max': 51996, 'name': 'item_id_seq'}, 'value_count': {'min': 1, 'max': 4}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}, {'name': 'categories', 'tags': {<Tags.ITEM: 'item'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.SEQUENCE: 'sequence'>, <Tags.LIST: 'list'>}, 'properties': {'domain': {'min': 1, 'max': 331, 'name': 'categories'}, 'value_count': {'min': 1, 'max': 4}}, 'dtype': dtype('int64'), 'is_list': True, 'is_ragged': True}] to JSON. Unrecognized type <class 'merlin.schema.schema.Schema'>.

/usr/lib/python3.8/json/encoder.py:257: TypeError
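For context, this TypeError comes from Python's built-in JSON encoder, which only understands primitive types; anything else (here a `merlin.schema.Schema` inside the Keras config) must be lowered to primitives by a `default` hook, or the encoder raises exactly this error. A minimal sketch of the failure mode and the fix, using a hypothetical `SchemaLike` class rather than the real Merlin type:

```python
import json

class SchemaLike:
    """Hypothetical stand-in for a type the stock JSON encoder can't handle."""
    def __init__(self, columns):
        self.columns = columns

config = {"schema": SchemaLike(["item_id_seq", "categories"])}

# Without a `default` hook, json.dumps raises TypeError, as in the traceback.
try:
    json.dumps(config)
except TypeError as exc:
    error_message = str(exc)

# A `default` hook that lowers the object to primitives makes it serializable.
def encode_schema(obj):
    if isinstance(obj, SchemaLike):
        return {"columns": obj.columns}
    raise TypeError(f"Unable to serialize {obj!r} to JSON")

text = json.dumps(config, default=encode_schema)
```

This is illustrative only; the real fix would need the schema to be handled by (or converted before reaching) Keras's `json_utils.Encoder`.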
TF 2.10.0 - Could not find matching concrete function to call loaded from the SavedModel
args = ({'categories': tf.RaggedTensor(values=Tensor("args_0:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_1:0",...e=(None,), dtype=float32), row_splits=Tensor("args_0_7:0", shape=(None,), dtype=int32)), ...}, None, False, None, None), kwargs = {}
do_return = False, retval_ = <tensorflow.python.autograph.operators.variables.UndefinedReturnValue object at 0x7f37e74460a0>

    def tf___wrapped_model(*args, **kwargs):
        "A concrete tf.function that wraps the model's call function."
        with ag__.FunctionScope('_wrapped_model', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
            do_return = False
            retval_ = ag__.UndefinedReturnValue()
            (args, kwargs) = ag__.converted_call(ag__.ld(model)._call_spec.set_arg_value, ('training', False, ag__.ld(args), ag__.ld(kwargs)), dict(inputs_in_args=True), fscope)
            with ag__.ld(base_layer_utils).call_context().enter(ag__.ld(model), inputs=None, build_graph=False, training=False, saving=True):
>               outputs = ag__.converted_call(ag__.ld(model), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope)
E               ValueError: in user code:
E
E                   File "/home/oliverholworthy/anaconda3/envs/python-3.8-rapids-22.10/lib/python3.8/site-packages/keras/saving/saving_utils.py", line 147, in _wrapped_model  *
E                       outputs = model(*args, **kwargs)
E                   File "/home/oliverholworthy/anaconda3/envs/python-3.8-rapids-22.10/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
E                       raise e.with_traceback(filtered_tb) from None
E
E                   ValueError: Exception encountered when calling layer "model" "                 f"(type merlin.models>Model).
E
E                   Could not find matching concrete function to call loaded from the SavedModel. Got:
E                     Positional arguments (5 total):
E                       * {'categories': tf.RaggedTensor(values=Tensor("inputs:0", shape=(None,), dtype=int64), row_splits=Tensor("inputs_1:0", shape=(None,), dtype=int32)),
E                    'event_hour_cos': tf.RaggedTensor(values=Tensor("inputs_2:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_3:0", shape=(None,), dtype=int32)),
E                    'event_hour_sin': tf.RaggedTensor(values=Tensor("inputs_4:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_5:0", shape=(None,), dtype=int32)),
E                    'event_weekday_cos': tf.RaggedTensor(values=Tensor("inputs_6:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_7:0", shape=(None,), dtype=int32)),
E                    'event_weekday_sin': tf.RaggedTensor(values=Tensor("inputs_8:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_9:0", shape=(None,), dtype=int32)),
E                    'item_age_days_norm': tf.RaggedTensor(values=Tensor("inputs_10:0", shape=(None,), dtype=float32), row_splits=Tensor("inputs_11:0", shape=(None,), dtype=int32)),
E                    'item_id_seq': tf.RaggedTensor(values=Tensor("inputs_12:0", shape=(None,), dtype=int64), row_splits=Tensor("inputs_13:0", shape=(None,), dtype=int32)),
E                    'test_user_id': <tf.Tensor 'inputs_14:0' shape=(None, 1) dtype=int64>,
E                    'user_age': <tf.Tensor 'inputs_15:0' shape=(None, 1) dtype=float32>,
E                    'user_country': <tf.Tensor 'inputs_16:0' shape=(None, 1) dtype=int64>}
E                       * None
E                       * False
E                       * None
E                       * None
E                     Keyword arguments: {}
E
E                    Expected these arguments to match one of the following 2 option(s):
E
E                   Option 1:
E                     Positional arguments (5 total):
E                       * {'categories': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'event_hour_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_hour_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_age_days_norm': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_id_seq': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'test_user_id': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/test_user_id'),
E                    'user_age': TensorSpec(shape=(None, 1), dtype=tf.float32, name='inputs/user_age'),
E                    'user_country': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/user_country')}
E                       * None
E                       * False
E                       * False
E                       * False
E                     Keyword arguments: {}
E
E                   Option 2:
E                     Positional arguments (5 total):
E                       * {'categories': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'event_hour_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_hour_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_cos': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'event_weekday_sin': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_age_days_norm': RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int32),
E                    'item_id_seq': RaggedTensorSpec(TensorShape([None, None]), tf.int64, 1, tf.int32),
E                    'test_user_id': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/test_user_id'),
E                    'user_age': TensorSpec(shape=(None, 1), dtype=tf.float32, name='inputs/user_age'),
E                    'user_country': TensorSpec(shape=(None, 1), dtype=tf.int64, name='inputs/user_country')}
E                       * None
E                       * True
E                       * False
E                       * False
E                     Keyword arguments: {}
E
E                   Call arguments received by layer "model" "                 f"(type merlin.models>Model):
E                     • args=({'item_age_days_norm': 'tf.RaggedTensor(values=Tensor("args_0_10:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_11:0", shape=(None,), dtype=int32))', 'item_id_seq': 'tf.RaggedTensor(values=Tensor("args_0_12:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_13:0", shape=(None,), dtype=int32))', 'event_hour_sin': 'tf.RaggedTensor(values=Tensor("args_0_4:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_5:0", shape=(None,), dtype=int32))', 'test_user_id': 'tf.Tensor(shape=(None, 1), dtype=int64)', 'categories': 'tf.RaggedTensor(values=Tensor("args_0:0", shape=(None,), dtype=int64), row_splits=Tensor("args_0_1:0", shape=(None,), dtype=int32))', 'event_weekday_sin': 'tf.RaggedTensor(values=Tensor("args_0_8:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_9:0", shape=(None,), dtype=int32))', 'event_weekday_cos': 'tf.RaggedTensor(values=Tensor("args_0_6:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_7:0", shape=(None,), dtype=int32))', 'user_country': 'tf.Tensor(shape=(None, 1), dtype=int64)', 'user_age': 'tf.Tensor(shape=(None, 1), dtype=float32)', 'event_hour_cos': 'tf.RaggedTensor(values=Tensor("args_0_2:0", shape=(None,), dtype=float32), row_splits=Tensor("args_0_3:0", shape=(None,), dtype=int32))'}, 'None', 'False', 'None', 'None')
E                     • kwargs=<class 'inspect._empty'>

/tmp/__autograph_generated_file024ppwfs.py:14: ValueError
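The TF 2.10 failure is a signature-matching problem: the SavedModel contains concrete functions traced with the trailing flags `(None, False, False, False)` and `(None, True, False, False)`, but re-saving calls the wrapped model with `(None, False, None, None)`, so neither option matches. A simplified, flags-only sketch of why the lookup fails (the real SavedModel check also compares the TensorSpecs of the inputs):

```python
# Trailing positional flags of the two concrete functions in the SavedModel.
saved_signatures = [
    (None, False, False, False),  # Option 1: traced with training=False
    (None, True, False, False),   # Option 2: traced with training=True
]

# Flags passed when re-saving the reloaded model (from the traceback above).
call_flags = (None, False, None, None)

# Exact equality is required; None is not treated as a wildcard, so no
# traced option matches and the loader raises ValueError.
matching = [sig for sig in saved_signatures if sig == call_flags]
```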

Steps/Code to reproduce bug

  1. Run example test below

Example test

import merlin.models.tf as mm
from merlin.io import Dataset
from merlin.models.tf import Loader
from merlin.schema import Tags


def test_clm(sequence_testing_data: Dataset):

    seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag(
        Tags.CATEGORICAL
    )
    target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
    predict_next = mm.SequencePredictNext(schema=seq_schema, target=target)

    loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next)

    d_model = 48
    model = mm.Model(
        mm.InputBlockV2(
            seq_schema,
            categorical=mm.Embeddings(
                seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
            ),
        ),
        mm.MLPBlock([d_model]),
        mm.GPT2Block(d_model=d_model, n_head=8, n_layer=2),
        mm.CategoricalOutput(
            seq_schema.select_by_name(target), default_loss="categorical_crossentropy"
        ),
    )

    model.compile()
    model.fit(loader)

    model_dir = "/tmp/clm_model"
    model.save(model_dir)

def test_clm_reload():
    model_dir = "/tmp/clm_model"
    reloaded_model = mm.Model.load(model_dir)
    reloaded_model.save("/tmp/clm_model_2")

Produces error:

  • Running test_clm_reload after test_clm.

Doesn't produce error:

  • Putting the contents of test_clm_reload into test_clm (something to do with different state between the two tests)
  • Replacing GPT2Block with mm.ListToDense()

Expected behavior

Saving a re-loaded model should succeed without errors and produce the same saved model artifact as the one that was loaded.
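One way to make the "same saved model artifact" expectation concrete in a test: after save → load → save, the two artifact directories should at least contain the same set of files (a byte-level comparison may be too strict if serialization isn't deterministic). A hypothetical helper, plain pathlib and nothing Merlin-specific:

```python
import tempfile
from pathlib import Path

def artifact_file_sets_match(dir_a, dir_b):
    """Compare the relative file listings of two saved-model directories."""
    files_a = {p.relative_to(dir_a) for p in Path(dir_a).rglob("*") if p.is_file()}
    files_b = {p.relative_to(dir_b) for p in Path(dir_b).rglob("*") if p.is_file()}
    return files_a == files_b

# Demo with two artificial artifact directories of identical layout.
with tempfile.TemporaryDirectory() as dir1, tempfile.TemporaryDirectory() as dir2:
    for d in (dir1, dir2):
        Path(d, "saved_model.pb").write_bytes(b"graph")
        Path(d, "variables").mkdir()
        Path(d, "variables", "variables.index").write_bytes(b"index")
    same = artifact_file_sets_match(dir1, dir2)
```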

Environment details

  • Merlin version: 22.10
  • Python version: 3.8
  • TensorFlow version (GPU): [2.9.1+nv22.8, 2.9.2, 2.10.0]

Additional context

@oliverholworthy oliverholworthy added bug Something isn't working P1 labels Nov 14, 2022
@oliverholworthy oliverholworthy added this to the Merlin 22.12 milestone Nov 14, 2022
@marcromeyn (Contributor)

Does this only happen with a TransformerBlock, or is it more general?

@oliverholworthy (Member, Author)

Replacing mm.GPT2Block(d_model=d_model, n_head=8, n_layer=2) with mm.ListToDense() in the example avoids the error, so it appears to be some interaction between the TransformerBlock and the model. It may be more general, in the sense that it could be possible to construct an example that doesn't use a transformer block, but I haven't been able to produce one.

@oliverholworthy (Member, Author)

I tried this again today, to make sure this wasn't something to do with running the transform in the loader, and to see if it had been fixed by another dependency. I'm still getting the error.

@rnyak rnyak self-assigned this Jan 11, 2023
@rnyak rnyak modified the milestones: Merlin 22.12, Merlin 23.02 Jan 11, 2023
@oliverholworthy (Member, Author)

This issue was resolved in the 22.12 release. I'm not sure exactly which PR fixed it, though.
