Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF: correct TFBart embeddings weights name when load_weight_prefix is passed #18993

Merged
merged 1 commit into from
Sep 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


import random
from contextlib import nullcontext
from typing import Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -743,7 +744,15 @@ def call(
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
with tf.name_scope(self.embed_tokens.name + "/"):
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
if hasattr(self.embed_tokens, "load_weight_prefix"):
context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
else:
context_manager = nullcontext()
with context_manager:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
Expand Down Expand Up @@ -931,7 +940,15 @@ def call(
positions = self.embed_positions(input_shape, position_ids=position_ids)

if inputs_embeds is None:
with tf.name_scope(self.embed_tokens.name + "/"):
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
if hasattr(self.embed_tokens, "load_weight_prefix"):
context_manager = tf.name_scope(self.embed_tokens.load_weight_prefix + "/")
else:
context_manager = nullcontext()
with context_manager:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

hidden_states = inputs_embeds
Expand Down Expand Up @@ -1027,8 +1044,9 @@ class TFBartMainLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
super().__init__(**kwargs)
self.config = config
load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name="model.shared")
# Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

self.encoder = TFBartEncoder(config, self.shared, name="encoder")
self.decoder = TFBartDecoder(config, self.shared, name="decoder")
Expand Down