flan t5 torchtext (pytorch#2027)
Summary:
Pull Request resolved: pytorch#2027

[d2go] flan t5 torchtext

This enables support for FLAN-T5 in Torchtext.

So far, only FLAN-T5 encoder-only models are enabled. If an encoder-decoder FLAN-T5 model is needed, it would be straightforward to add support for it.

Reviewed By: joecummings

Differential Revision: D42159825

fbshipit-source-id: 6c2a4430df890131e18d3ebe40bba35ecc6b25b8
Forrest Iandola authored and Nayef211 committed Jan 20, 2023
1 parent 569d48d commit 279f242
Showing 4 changed files with 80 additions and 12 deletions.
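
Before the per-file diffs, a minimal usage sketch (not part of the commit) of the updated bundler API. The checkpoint directory is hypothetical, and the import path simply mirrors the bundler module changed below; the tests in this diff create such a directory with transformers' save_pretrained.

# Minimal usage sketch (not part of this commit).
from torchtext.prototype.models.t5.bundler import T5Bundle

# Assumed local HuggingFace-format checkpoint directory (config.json + pytorch_model.bin),
# e.g. produced by T5EncoderModel.from_pretrained(...).save_pretrained(ckpt_dir).
ckpt_dir = "/tmp/hf_flan_t5_small_enc"  # hypothetical path

# The new required `encoder_only` flag selects the encoder-only configuration,
# which is the only FLAN-T5 variant enabled by this change.
flan_t5_encoder = T5Bundle.build_model_from_huggingface_ckpt(ckpt_dir, encoder_only=True)
flan_t5_encoder.eval()
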
6 changes: 3 additions & 3 deletions test/integration_tests/prototype/test_models.py
@@ -195,7 +195,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_enc"

-            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
+            t5_small_enc = T5EncoderModel.from_pretrained("t5-small", encoder_only=True)
            t5_small_enc.save_pretrained(model_path)

            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)
@@ -218,7 +218,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
            t5_small = T5Model.from_pretrained("t5-small")
            t5_small.save_pretrained(model_path)

-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=False)

            hf_output = t5_small(
                input_ids=self.encoder_input_ids,
@@ -240,7 +240,7 @@ def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
            t5_small_gen.save_pretrained(model_path)

-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
+            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=False)

            hf_output = t5_small_gen(
                input_ids=self.encoder_input_ids,
21 changes: 17 additions & 4 deletions torchtext/prototype/models/t5/bundler.py
@@ -138,6 +138,7 @@ def build_model(
    @staticmethod
    def build_model_from_huggingface_ckpt(
        ckpt_path: Union[str, os.PathLike],
+        encoder_only: bool,
        *,
        freeze_model: bool = False,
        strict: bool = True,
@@ -163,13 +164,14 @@ def build_model_from_huggingface_ckpt(

        # TODO(joecummings): find better way to determine `encoder_only` and `linear_head`
        config = T5Conf(
-            encoder_only="decoder.final_layer_norm.weight" not in hf_weights.keys(),
+            encoder_only=encoder_only,
            linear_head="lm_head.weight" in hf_weights.keys(),
            embedding_dim=config_json["d_model"],
            num_attention_heads=config_json["num_heads"],
            num_encoder_layers=config_json["num_layers"],
            num_decoder_layers=config_json["num_decoder_layers"],
            ffn_dimension=config_json["d_ff"],
+            feed_forward_proj=config_json.get("feed_forward_proj"),
        )

        t5_model = T5Model(config, freeze_model)
@@ -184,9 +186,20 @@ def build_model_from_huggingface_ckpt(
        }
        # Convert encoder layers
        for i in range(config.num_encoder_layers):
-            t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[
-                f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"
-            ]
+            if config.is_gated_act:
+                t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
+                ]
+
+                t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
+                ]
+
+            else:
+                t5_model_state_dict[f"encoder.layers.{i}.linear1.weight"] = hf_weights[
+                    f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight"
+                ]

            t5_model_state_dict[f"encoder.layers.{i}.linear2.weight"] = hf_weights[
                f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight"
            ]
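
The branch added above reflects two checkpoint layouts: classic T5 stores a single first feed-forward projection per block (`...DenseReluDense.wi.weight`), while T5 v1.1 / FLAN-T5 store a gated pair (`...wi_0.weight`, `...wi_1.weight`). As a hedged sketch (not part of this commit), the layout of a HuggingFace state dict could be inspected like this; the checkpoint path is hypothetical.

import torch

# Hypothetical path to a HuggingFace-format checkpoint file.
hf_weights = torch.load("/tmp/hf_flan_t5_small_enc/pytorch_model.bin", map_location="cpu")

# Gated checkpoints carry wi_0/wi_1 keys; classic checkpoints carry a single wi key.
uses_gated_ffn = "encoder.block.0.layer.1.DenseReluDense.wi_0.weight" in hf_weights
print("gated feed-forward layout:", uses_gated_ffn)
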
29 changes: 28 additions & 1 deletion torchtext/prototype/models/t5/model.py
@@ -10,7 +10,7 @@
from .modules import T5Decoder, T5Encoder


-@dataclass(frozen=True)
+@dataclass
class T5Conf:
    encoder_only: bool = False
    linear_head: bool = False
@@ -29,6 +29,32 @@ class T5Conf:
    max_seq_len: int = 512
    vocab_size: int = 32128
    training: bool = False
+    feed_forward_proj: str = None
+    is_gated_act: bool = False
+
+    def __post_init__(self):
+        """
+        the following is modified from
+        https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
+        It's to support T5 1.1 and FLAN-T5.
+        """
+
+        if self.feed_forward_proj:
+            act_info = self.feed_forward_proj.split("-")
+            self.activation = act_info[-1]
+            self.is_gated_act = (act_info[0] == "gated")
+
+            if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+                raise ValueError(
+                    f"`feed_forward_proj`: {self.feed_forward_proj} is not a valid activation function of the dense layer."
+                    "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                    "'gated-gelu' or 'relu'"
+                )
+
+            # for backwards compatibility
+            if self.feed_forward_proj == "gated-gelu":
+                self.activation = "gelu_new"


# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
@@ -102,6 +128,7 @@ def __init__(
            relative_attention_num_buckets=config.relative_attention_num_buckets,
            relative_attention_max_distance=config.relative_attention_max_distance,
            token_embeddings=self.token_embeddings,
+            is_gated_act=config.is_gated_act,
            device=device,
            dtype=dtype,
        )
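
A small sketch (illustration only, not from the commit) of how the `__post_init__` parsing above behaves for the two common cases, assuming `T5Conf` is imported from the module shown in this diff:

from torchtext.prototype.models.t5.model import T5Conf

# FLAN-T5 / T5 v1.1 configs report feed_forward_proj="gated-gelu"; __post_init__ maps
# this to is_gated_act=True and, for backwards compatibility, activation="gelu_new".
flan_conf = T5Conf(encoder_only=True, feed_forward_proj="gated-gelu")
assert flan_conf.is_gated_act is True
assert flan_conf.activation == "gelu_new"

# Classic T5 configs have no feed_forward_proj entry, so the defaults are untouched.
classic_conf = T5Conf(encoder_only=True)
assert classic_conf.is_gated_act is False
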
36 changes: 32 additions & 4 deletions torchtext/prototype/models/t5/modules.py
@@ -540,7 +540,9 @@ def __init__(
        layer_norm_eps: float = 1e-6,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
+        is_gated_act: bool = False,
        compute_relative_attention_bias: bool = False,

        device: Optional[torch.device] = None,
        dtype=None,
    ) -> None:
@@ -549,6 +551,7 @@ def __init__(
        self.compute_relative_attention_bias = compute_relative_attention_bias
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
+        self.is_gated_act = is_gated_act

        self.self_attn = T5MultiheadAttention(
            d_model,
@@ -562,7 +565,15 @@ def __init__(
            device=device,
            dtype=dtype,
        )
-        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)

+        if self.is_gated_act:
+            self.linear1 = None
+            self.linear1_0 = nn.Linear(d_model, dim_feedforward, bias=False)
+            self.linear1_1 = nn.Linear(d_model, dim_feedforward, bias=False)
+        else:
+            self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
+            self.linear1_0 = None
+            self.linear1_1 = None
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.norm1 = T5LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = T5LayerNorm(d_model, eps=layer_norm_eps)
@@ -574,11 +585,15 @@ def __init__(
            assert activation in (
                "relu",
                "gelu",
-            ), f"Do not support '{activation}' activation. Use either 'relu' or 'gelu'"
+                "gelu_new",
+            ), f"Do not support '{activation}' activation. Use 'relu' or 'gelu' or 'gelu_new'"
            if activation == "relu":
                self.activation = F.relu
            elif activation == "gelu":
                self.activation = F.gelu
+            elif activation == "gelu_new":
+                # the following should match the math of https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
+                self.activation = nn.GELU(approximate='tanh')
        else:
            self.activation = activation

@@ -637,8 +652,19 @@ def _sa_block(

    # Feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
-        x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
-        return self.dropout3(x)
+        if self.is_gated_act:
+            wi_0 = self.activation(self.linear1_0(x))
+            wi_1 = self.linear1_1(x)
+            hidden_states = wi_0 * wi_1
+            hidden_states = self.dropout2(hidden_states)
+            hidden_states = self.linear2(hidden_states)
+            hidden_states = self.dropout3(hidden_states)
+            return hidden_states
+
+        else:
+            assert self.linear1 is not None
+            x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
+            return self.dropout3(x)


# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622
@@ -810,6 +836,7 @@ def __init__(
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        token_embeddings: Optional[nn.Module] = None,
+        is_gated_act: bool = False,
        device: Optional[torch.device] = None,
        dtype=None,
    ) -> None:
@@ -827,6 +854,7 @@ def __init__(
                    layer_norm_eps,
                    relative_attention_num_buckets,
                    relative_attention_max_distance,
+                    is_gated_act,
                    compute_relative_attention_bias=True if i == 0 else False,
                    device=device,
                    dtype=dtype,
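
To make the gated variant concrete, here is a standalone sketch (not from the commit) of the two feed-forward computations that `_ff_block` above now supports, written directly against `torch.nn`; the dimensions and dropout probability are arbitrary illustration values.

import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_ff = 512, 1024
x = torch.randn(2, 8, d_model)  # (batch, sequence, d_model)

wo = nn.Linear(d_ff, d_model, bias=False)  # output projection used by both variants

# Classic T5 feed-forward: wo(dropout(act(wi(x))))
wi = nn.Linear(d_model, d_ff, bias=False)
y_classic = wo(F.dropout(F.relu(wi(x)), p=0.1))

# Gated feed-forward (T5 v1.1 / FLAN-T5): the activated wi_0 branch gates the
# linear wi_1 branch elementwise before the output projection, as in _ff_block.
wi_0 = nn.Linear(d_model, d_ff, bias=False)
wi_1 = nn.Linear(d_model, d_ff, bias=False)
gelu_new = nn.GELU(approximate="tanh")  # matches the "gelu_new" activation wired up above
y_gated = wo(F.dropout(gelu_new(wi_0(x)) * wi_1(x), p=0.1))
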
