diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py
index 17b61b2135..fdd3a85c09 100644
--- a/keras_nlp/models/albert/albert_backbone.py
+++ b/keras_nlp/models/albert/albert_backbone.py
@@ -271,6 +271,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index 7cdc7cd346..647da9cfd9 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -27,6 +27,11 @@ class Backbone(keras.Model):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    @property
+    def token_embedding(self):
+        """A `keras.layers.Embedding` instance for embedding token ids."""
+        raise NotImplementedError
+
     def get_config(self):
         # Don't chain to super here. The default `get_config()` for functional
         # models is nested and cannot be passed to our Backbone constructors.
diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py
index 475f0b2e5c..7a5cc5d784 100644
--- a/keras_nlp/models/bart/bart_backbone.py
+++ b/keras_nlp/models/bart/bart_backbone.py
@@ -244,3 +244,7 @@ def get_config(self):
             "name": self.name,
             "trainable": self.trainable,
         }
+
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py
index d59fc11e91..3e7d46734e 100644
--- a/keras_nlp/models/bert/bert_backbone.py
+++ b/keras_nlp/models/bert/bert_backbone.py
@@ -190,16 +190,16 @@ def __init__(
             },
             **kwargs,
         )
+        # All references to `self` below this line
         self.vocabulary_size = vocabulary_size
-        self.hidden_dim = hidden_dim
-        self.intermediate_dim = intermediate_dim
         self.num_layers = num_layers
         self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
         self.max_sequence_length = max_sequence_length
         self.num_segments = num_segments
-        self.dropout = dropout
-        self.token_embedding = token_embedding_layer
         self.cls_token_index = cls_token_index
 
     def get_config(self):
@@ -207,17 +207,21 @@ def get_config(self):
         config.update(
             {
                 "vocabulary_size": self.vocabulary_size,
-                "hidden_dim": self.hidden_dim,
-                "intermediate_dim": self.intermediate_dim,
                 "num_layers": self.num_layers,
                 "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
                 "max_sequence_length": self.max_sequence_length,
                 "num_segments": self.num_segments,
-                "dropout": self.dropout,
             }
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
index 4473ebcb59..e9152903a0 100644
--- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
+++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py
@@ -200,6 +200,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py
index 451b646d22..4920083cc6 100644
--- a/keras_nlp/models/distil_bert/distil_bert_backbone.py
+++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py
@@ -180,6 +180,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_and_position_embedding").token_embedding
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py
index 83acd9e255..2e6dc6dd18 100644
--- a/keras_nlp/models/f_net/f_net_backbone.py
+++ b/keras_nlp/models/f_net/f_net_backbone.py
@@ -216,6 +216,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py
index 17aec1396c..b52778ae6f 100644
--- a/keras_nlp/models/gpt2/gpt2_backbone.py
+++ b/keras_nlp/models/gpt2/gpt2_backbone.py
@@ -187,6 +187,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("token_embedding")
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py
index 5b06e0fc06..ebeb984226 100644
--- a/keras_nlp/models/roberta/roberta_backbone.py
+++ b/keras_nlp/models/roberta/roberta_backbone.py
@@ -174,6 +174,10 @@ def get_config(self):
         )
         return config
 
+    @property
+    def token_embedding(self):
+        return self.get_layer("embeddings").token_embedding
+
     @classproperty
     def presets(cls):
         return copy.deepcopy(backbone_presets)
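The net effect of the diff is a uniform, read-only `token_embedding` property on every backbone, resolved through `get_layer()` (or through a nested attribute where the embedding lives inside a combined layer, as in DistilBERT and RoBERTa), replacing the ad-hoc instance attribute BERT previously set in `__init__`. A minimal usage sketch, assuming `keras_nlp` is installed and the `"bert_base_en_uncased"` preset weights are available for download:

```python
import numpy as np

import keras_nlp

# Any backbone works here; the accessor is now uniform across architectures.
backbone = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")

# `token_embedding` is a regular `keras.layers.Embedding`, so its weights and
# attributes are directly accessible, e.g. for weight tying in a task head.
embedding = backbone.token_embedding
print(embedding.input_dim)  # vocabulary size, 30522 for this preset

# The layer can also be called on token ids outside the full model.
token_ids = np.array([[101, 7592, 102]])  # example ids: [CLS], "hello", [SEP]
print(embedding(token_ids).shape)  # (1, 3, hidden_dim)
```

Note that `Backbone.token_embedding` in `backbone.py` raises `NotImplementedError`, so every subclass is forced to override it; on architectures without a layer literally named `"token_embedding"`, calling `get_layer("token_embedding")` would otherwise fail at runtime rather than at definition time.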