flan t5 torchtext #2027

Closed · wants to merge 1 commit
6 changes: 3 additions & 3 deletions test/integration_tests/prototype/test_models.py
@@ -195,7 +195,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
model_path = f"{tmp_dir}/hf_t5_small_enc"

- t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
+ t5_small_enc = T5EncoderModel.from_pretrained("t5-small", encoder_only=True)
t5_small_enc.save_pretrained(model_path)

our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)
@@ -218,7 +218,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
t5_small = T5Model.from_pretrained("t5-small")
t5_small.save_pretrained(model_path)

- our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+ our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=False)

hf_output = t5_small(
input_ids=self.encoder_input_ids,
@@ -240,7 +240,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_small_gen.save_pretrained(model_path)

- our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+ our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=False)

hf_output = t5_small_gen(
input_ids=self.encoder_input_ids,
21 changes: 17 additions & 4 deletions torchtext/prototype/models/t5/bundler.py
@@ -138,6 +138,7 @@ def build_model(
@staticmethod
def build_model_from_huggingface_ckpt(
ckpt_path: Union[str, os.PathLike],
+ encoder_only: bool,
*,
freeze_model: bool = False,
strict: bool = True,
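For context on the signature change above: `encoder_only` becomes a required positional argument, so callers now state explicitly whether to build only the encoder stack. A minimal usage sketch, assuming the prototype import path and a hypothetical local checkpoint directory:

from torchtext.prototype.models import T5Bundle  # prototype namespace, subject to change

# encoder_only is now an explicit caller decision rather than being inferred
# from whether decoder weights exist in the checkpoint.
flan_t5 = T5Bundle.build_model_from_huggingface_ckpt(
    "/path/to/flan-t5-base",  # hypothetical local directory produced by save_pretrained()
    encoder_only=False,
)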
@@ -163,13 +164,14 @@ def build_model_from_huggingface_ckpt(

# TODO(joecummings): find better way to determine `encoder_only` and `linear_head`
config = T5Conf(
encoder_only="decoder.final_layer_norm.weight" not in hf_weights.keys(),
encoder_only=encoder_only,
linear_head="lm_head.weight" in hf_weights.keys(),
embedding_dim=config_json["d_model"],
num_attention_heads=config_json["num_heads"],
num_encoder_layers=config_json["num_layers"],
num_decoder_layers=config_json["num_decoder_layers"],
ffn_dimension=config_json["d_ff"],
+ feed_forward_proj=config_json.get("feed_forward_proj"),
)

t5_model = T5Model(config, freeze_model)
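To make the new `feed_forward_proj` plumbing concrete, a hedged sketch of what this hunk reads from a T5 1.1 / FLAN-T5 style config.json; the numbers are illustrative (loosely modeled on the small variants), not copied from a real checkpoint. Classic T5 checkpoints simply omit the key, so `.get()` returns None and the non-gated defaults apply.

from torchtext.prototype.models import T5Conf  # prototype namespace, subject to change

config_json = {
    "d_model": 512,
    "num_heads": 6,
    "num_layers": 8,
    "num_decoder_layers": 8,
    "d_ff": 1024,
    "feed_forward_proj": "gated-gelu",  # absent from classic T5 config.json files
}

config = T5Conf(
    encoder_only=False,  # supplied by the caller, no longer inferred from the weights
    linear_head=True,    # in the real code: "lm_head.weight" in hf_weights
    embedding_dim=config_json["d_model"],
    num_attention_heads=config_json["num_heads"],
    num_encoder_layers=config_json["num_layers"],
    num_decoder_layers=config_json["num_decoder_layers"],
    ffn_dimension=config_json["d_ff"],
    feed_forward_proj=config_json.get("feed_forward_proj"),
)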
@@ -184,9 +186,20 @@
}
# Convert encoder layers
for i in range(config.num_encoder_layers):
t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[
f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"
]
if config.is_gated_act:
t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
]

t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[
f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
]

else:
t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[
f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"
]

t5_model_state_dict[f"encoder.layers.{i}.linear2.weight"] = hf_weights[
f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"
]
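A compact summary of the encoder feed-forward key mapping implemented above, torchtext keys on the left and HuggingFace keys on the right; layer 0 is shown and the pattern repeats for every encoder layer (the dicts are illustrative, not used by the loader):

i = 0  # any encoder layer index; the loop above applies this per layer

# Gated checkpoints (T5 1.1 / FLAN-T5): two input projections, wi_0 and wi_1.
gated_mapping = {
    f"encoder.layers.{i}.linear1_0.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight",
    f"encoder.layers.{i}.linear1_1.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight",
}

# Classic T5 checkpoints: a single input projection, wi.
plain_mapping = {
    f"encoder.layers.{i}.linear1.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight",
}

# Both variants share the output projection:
#   encoder.layers.{i}.linear2.weight  <-  encoder.block.{i}.layer.1.DenseReluDense.wo.weight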
29 changes: 28 additions & 1 deletion torchtext/prototype/models/t5/model.py
@@ -10,7 +10,7 @@
from .modules import T5Decoder, T5Encoder


- @dataclass(frozen=True)
+ @dataclass
class T5Conf:
encoder_only: bool = False
linear_head: bool = False
@@ -29,6 +29,32 @@ class T5Conf:
max_seq_len: int = 512
vocab_size: int = 32128
training: bool = False
+ feed_forward_proj: str = None
+ is_gated_act: bool = False

+ def __post_init__(self):
+     """
+     Adapted from
+     https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
+     to support T5 1.1 and FLAN-T5 checkpoints.
+     """
+     if self.feed_forward_proj:
+         act_info = self.feed_forward_proj.split("-")
+         self.activation = act_info[-1]
+         self.is_gated_act = act_info[0] == "gated"
+
+         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+             raise ValueError(
+                 f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer. "
+                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                 "'gated-gelu' or 'relu'"
+             )
+
+         # for backwards compatibility: "gated-gelu" is the legacy name for the tanh-approximated GELU
+         if self.feed_forward_proj == "gated-gelu":
+             self.activation = "gelu_new"


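A small illustration of how the parsing above behaves for the two common cases, assuming `T5Conf` is importable from the prototype namespace; the backwards-compatibility branch maps the legacy name "gated-gelu" onto "gelu_new":

from torchtext.prototype.models import T5Conf  # prototype namespace, subject to change

# T5 1.1 / FLAN-T5 style: gated feed-forward with the tanh-approximated GELU.
flan_conf = T5Conf(feed_forward_proj="gated-gelu")
assert flan_conf.is_gated_act is True
assert flan_conf.activation == "gelu_new"

# Classic T5 style: plain ReLU feed-forward.
t5_conf = T5Conf(feed_forward_proj="relu")
assert t5_conf.is_gated_act is False
assert t5_conf.activation == "relu"

# Malformed values such as "gated-gelu-extra" raise a ValueError in __post_init__.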
# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
@@ -102,6 +128,7 @@ def __init__(
relative_attention_num_buckets=config.relative_attention_num_buckets,
relative_attention_max_distance=config.relative_attention_max_distance,
token_embeddings=self.token_embeddings,
+ is_gated_act=config.is_gated_act,
device=device,
dtype=dtype,
)
36 changes: 32 additions & 4 deletions torchtext/prototype/models/t5/modules.py
@@ -540,7 +549,9 @@ def __init__(
layer_norm_eps: float = 1e-6,
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
+ is_gated_act: bool = False,
compute_relative_attention_bias: bool = False,
device: Optional[torch.device] = None,
dtype=None,
) -> None:
@@ -549,6 +551,7 @@
self.compute_relative_attention_bias = compute_relative_attention_bias
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
+ self.is_gated_act = is_gated_act

self.self_attn = T5MultiheadAttention(
d_model,
@@ -562,7 +565,15 @@
device=device,
dtype=dtype,
)
- self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+ if self.is_gated_act:
+     self.linear1 = None
+     self.linear1_0 = nn.Linear(d_model, dim_feedforward, bias=False)
+     self.linear1_1 = nn.Linear(d_model, dim_feedforward, bias=False)
+ else:
+     self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+     self.linear1_0 = None
+     self.linear1_1 = None
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps)
@@ -574,11 +585,15 @@
assert activation in (
"relu",
"gelu",
), f"Do not support '{activation}' activation. Use either 'relu' or 'gelu'"
"gelu_new",
), f"Do not support '{activation}' activation. Use 'relu' or 'gelu' or 'gelu_new'"
if activation == "relu":
self.activation = F.relu
elif activation == "gelu":
self.activation = F.gelu
elif activation == "gelu_new":
# the following should match the math of https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
self.activation = nn.GELU(approximate='tanh')
else:
self.activation = activation

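On the "gelu_new" choice: HuggingFace's NewGELUActivation is the tanh approximation of GELU, and torch exposes the same approximation via `nn.GELU(approximate="tanh")`. A hedged sanity check, illustrative rather than a claim of bit-exact parity across library versions:

import math

import torch
import torch.nn as nn


def hf_style_new_gelu(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU, as used by "gelu_new" in HF transformers.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.randn(4, 8)
torch.testing.assert_close(nn.GELU(approximate="tanh")(x), hf_style_new_gelu(x))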
@@ -637,8 +652,19 @@

# Feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
- x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
- return self.dropout3(x)
+ if self.is_gated_act:
+     wi_0 = self.activation(self.linear1_0(x))
+     wi_1 = self.linear1_1(x)
+     hidden_states = wi_0 * wi_1
+     hidden_states = self.dropout2(hidden_states)
+     hidden_states = self.linear2(hidden_states)
+     hidden_states = self.dropout3(hidden_states)
+     return hidden_states
+ else:
+     assert self.linear1 is not None
+     x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
+     return self.dropout3(x)


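For readers new to the gated variant, a self-contained sketch of the GEGLU-style feed-forward that the `is_gated_act` branch above implements; the layer names mirror the diff, but the class itself is illustrative and not part of the PR:

import torch
import torch.nn as nn


class GatedFeedForwardSketch(nn.Module):
    """Illustrative gated block: act(W_0 x) * (W_1 x), projected back to d_model by the output layer."""

    def __init__(self, d_model: int, dim_feedforward: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.linear1_0 = nn.Linear(d_model, dim_feedforward, bias=False)  # gate branch
        self.linear1_1 = nn.Linear(d_model, dim_feedforward, bias=False)  # value branch
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU(approximate="tanh")  # "gelu_new"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.activation(self.linear1_0(x)) * self.linear1_1(x)
        return self.dropout(self.linear2(self.dropout(hidden)))


out = GatedFeedForwardSketch(d_model=512, dim_feedforward=1024)(torch.randn(2, 7, 512))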
# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622
@@ -810,6 +836,7 @@ def __init__(
relative_attention_num_buckets: int = 32,
relative_attention_max_distance: int = 128,
token_embeddings: Optional[nn.Module] = None,
+ is_gated_act: bool = False,
device: Optional[torch.device] = None,
dtype=None,
) -> None:
@@ -827,6 +854,7 @@
layer_norm_eps,
relative_attention_num_buckets,
relative_attention_max_distance,
+ is_gated_act,
compute_relative_attention_bias=True if i == 0 else False,
device=device,
dtype=dtype,