Adding
soma2000-lang committed Feb 7, 2023
1 parent 60c3002 commit 8d680df
Showing 1 changed file with 26 additions and 34 deletions.
60 changes: 26 additions & 34 deletions keras_nlp/models/bert/bert_backbone.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The KerasNLP Authors
+# Copyright 2023 The KerasNLP Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,6 @@
 from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.bert.bert_presets import backbone_presets
 from keras_nlp.utils.python_utils import classproperty
-from keras_nlp.utils.python_utils import format_docstring


 def bert_kernel_initializer(stddev=0.02):
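
The body of bert_kernel_initializer is collapsed in this diff. For context, a minimal sketch of what this helper does in keras_nlp (the exact body is an assumption here, not shown in the hunk): BERT draws its kernel weights from a truncated normal distribution with a small standard deviation.

from tensorflow import keras

def bert_kernel_initializer(stddev=0.02):
    # BERT-style weight init: truncated normal with a small stddev.
    # Sketch for context; the real body is collapsed in this diff.
    return keras.initializers.TruncatedNormal(stddev=stddev)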
@@ -77,7 +76,7 @@ class BertBackbone(Backbone):
     }
     # Pretrained BERT encoder
-    model = keras_nlp.models.BertBackbone.from_preset("base_base_en_uncased")
+    model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")
     output = model(input_data)
     # Randomly initialized BERT encoder with a custom config
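
The hunk above fixes a typo in the docstring's preset name ("base_base_en_uncased" is not a real preset). A minimal, self-contained sketch of the corrected usage, assuming keras_nlp is installed and the preset weights can be downloaded:

import numpy as np
import keras_nlp

# Load the pretrained encoder by its corrected preset name.
model = keras_nlp.models.BertBackbone.from_preset("bert_base_en_uncased")

# BERT takes token ids, segment ids, and a padding mask.
input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "segment_ids": np.zeros((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
outputs = model(input_data)  # {"sequence_output": ..., "pooled_output": ...}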
@@ -105,7 +104,6 @@ def __init__(
         num_segments=2,
         **kwargs,
     ):
-
         # Index of classification token in the vocabulary
         cls_token_index = 0
         # Inputs
@@ -163,7 +161,7 @@ def __init__(
                 x, approximate=True
             ),
             dropout=dropout,
-            layer_norm_epilson=1e-05,
+            layer_norm_epsilon=1e-12,
             kernel_initializer=bert_kernel_initializer(),
             name=f"transformer_layer_{i}",
         )(x, padding_mask=padding_mask)
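
This hunk corrects the misspelled keyword (epilson → epsilon) and changes the value to 1e-12, the layer-norm epsilon used by the original BERT release. A standalone sketch of the corrected construction; the sizes below are illustrative BERT-base-like assumptions, not values from this diff:

import keras_nlp

# The encoder layer now receives the correctly spelled epsilon,
# set to BERT's canonical 1e-12.
encoder = keras_nlp.layers.TransformerEncoder(
    num_heads=12,
    intermediate_dim=3072,
    dropout=0.1,
    layer_norm_epsilon=1e-12,
)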
@@ -191,44 +189,38 @@
             },
             **kwargs,
         )

         # All references to `self` below this line
         self.vocabulary_size = vocabulary_size
-        self.hidden_dim = hidden_dim
-        self.intermediate_dim = intermediate_dim
         self.num_layers = num_layers
         self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
         self.max_sequence_length = max_sequence_length
         self.num_segments = num_segments
-        self.dropout = dropout
+        self.token_embedding = token_embedding_layer
         self.cls_token_index = cls_token_index

     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "max_sequence_length": self.max_sequence_length,
-            "num_segments": self.num_segments,
-            "dropout": self.dropout,
-            "name": self.name,
-            "trainable": self.trainable,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+                "num_segments": self.num_segments,
+            }
+        )
+        return config

-    @property
-    def token_embedding(self):
-        return self.get_layer("token_embedding")
-
     @classproperty
     def presets(cls):
-        return copy.deepcopy(backbone_presets)
-
-    @classmethod
-    def from_preset(cls, preset, load_weights=True, **kwargs):
-        return super().from_preset(preset, load_weights, **kwargs)
-
-
-BertBackbone.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
-format_docstring(
-    model_name=BertBackbone.__name__,
-    example_preset_name="bert_base_en_uncased",
-    preset_names='", "'.join(BertBackbone.presets),
-)(BertBackbone.from_preset.__func__)
+        return copy.deepcopy(backbone_presets)
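
With the new get_config, base-class keys such as "name" and "trainable" come from super().get_config() instead of being listed by hand, so the config stays in sync with the base class and the model round-trips cleanly through serialization. A small sketch of that round trip; the constructor sizes are illustrative assumptions:

import keras_nlp

model = keras_nlp.models.BertBackbone(
    vocabulary_size=30522,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    intermediate_dim=128,
    max_sequence_length=128,
)

config = model.get_config()  # includes "name", "trainable", etc. via super()
restored = keras_nlp.models.BertBackbone.from_config(config)
assert restored.hidden_dim == model.hidden_dim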
