From f08fe1e085646f53645dd6cf938b6e9b3a2214f9 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 6 Dec 2021 19:25:39 -0800
Subject: [PATCH 1/9] [t5] faster/leaner custom layer norm

---
 src/transformers/models/t5/modeling_t5.py | 39 ++++++++++++++++++-----
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index e5c1f340ea8417..48f85bc35eacca 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -66,6 +66,13 @@
 ]


+# deal with deprecated torch.norm
+if hasattr(torch, "linalg") and hasattr(torch.linalg, "norm"):
+    torch_norm = torch.linalg.norm
+else:
+    torch_norm = torch.norm
+
+
 ####################################################
 # This is a conversion method from TF 1.0 to PyTorch
 # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
@@ -231,22 +238,38 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
 class T5LayerNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
-        Construct a layernorm module in the T5 style No bias and no subtraction of mean.
+        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps

     def forward(self, hidden_states):
-        # layer norm should always be calculated in float32
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
+        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+        # half-precision inputs is done in fp32, so the original code, which might be easier to
+        # understand was:
+        #
+        # variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        # if self.weight.dtype in [torch.float16, torch.bfloat16]:
+        #     hidden_states = hidden_states.to(self.weight.dtype)
+        #
+        # but it's more efficient to use the fused kernel norm and it doesn't require converting the
+        # inputs to fp32, just the intermediary results with norm which are much smaller, we just
+        # need to make a correction for sqrt(N) - note that this is also similar to weight
+        # normalization
+
+        hidden_states = (
+            hidden_states
+            / torch_norm(hidden_states, dim=-1, keepdim=True, dtype=torch.float32)
+            * math.sqrt(hidden_states.shape[-1])
+            * self.weight
+        )
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-
-        return self.weight * hidden_states
+        return hidden_states


 class T5DenseReluDense(nn.Module):
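The refactor above leans on a simple identity: dividing by the L2 norm along the last dimension and scaling by sqrt(N) is the same as dividing by the root of the mean of squares. A minimal standalone sketch (not part of the patch; shapes and tolerance are illustrative only) to check the equivalence, up to the epsilon term that the norm-based version drops:

```python
import torch

x = torch.randn(2, 5, 512, dtype=torch.float32)
eps = 1e-6

# original formulation: scale by 1/sqrt(mean(x^2) + eps)
variance = x.pow(2).mean(-1, keepdim=True)
y_ref = x * torch.rsqrt(variance + eps)

# norm-based formulation used in the patch: x / ||x||_2 * sqrt(N)
y_norm = x / torch.linalg.norm(x, dim=-1, keepdim=True) * (x.shape[-1] ** 0.5)

print(torch.allclose(y_ref, y_norm, atol=1e-5))  # True for non-degenerate inputs
```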
From 0b719ce8a029265083f62caed599e2f6773ededc Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 7 Dec 2021 08:45:20 -0800
Subject: [PATCH 2/9] wip

---
 src/transformers/models/t5/modeling_t5.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 48f85bc35eacca..3e37f1f236c02e 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -265,11 +265,10 @@ def forward(self, hidden_states):
             hidden_states
             / torch_norm(hidden_states, dim=-1, keepdim=True, dtype=torch.float32)
             * math.sqrt(hidden_states.shape[-1])
-            * self.weight
         )
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
-        return hidden_states
+        return self.weight * hidden_states


 class T5DenseReluDense(nn.Module):

From 0af2f9fcbdaec8420fb200e0e28e40ec5822ce1f Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Thu, 3 Feb 2022 18:23:53 -0800
Subject: [PATCH 3/9] apex.normalization.FusedRMSNorm

---
 src/transformers/models/t5/modeling_t5.py | 52 +++++++++++++++--------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 3e37f1f236c02e..4bae32d6d4e3a0 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -66,13 +66,6 @@
 ]


-# deal with deprecated torch.norm
-if hasattr(torch, "linalg") and hasattr(torch.linalg, "norm"):
-    torch_norm = torch.linalg.norm
-else:
-    torch_norm = torch.norm
-
-
 ####################################################
 # This is a conversion method from TF 1.0 to PyTorch
 # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
@@ -242,6 +235,7 @@ def __init__(self, hidden_size, eps=1e-6):
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps

     def forward(self, hidden_states):
@@ -250,27 +244,49 @@ def forward(self, hidden_states):
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
         # half-precision inputs is done in fp32, so the original code, which might be easier to
         # understand was:
-        #
-        # variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # if self.weight.dtype in [torch.float16, torch.bfloat16]:
-        #     hidden_states = hidden_states.to(self.weight.dtype)
-        #
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
         # but it's more efficient to use the fused kernel norm and it doesn't require converting the
         # inputs to fp32, just the intermediary results with norm which are much smaller, we just
         # need to make a correction for sqrt(N) - note that this is also similar to weight
         # normalization

-        hidden_states = (
-            hidden_states
-            / torch_norm(hidden_states, dim=-1, keepdim=True, dtype=torch.float32)
-            * math.sqrt(hidden_states.shape[-1])
-        )
+        # attempt to fuse:
+        # with torch.jit.fuser("fuser2"):
+        #     hidden_states = (
+        #         hidden_states
+        #         / (torch_norm(hidden_states, dim=-1, keepdim=True, dtype=torch.float32) + self.variance_epsilon)
+        #         * math.sqrt(hidden_states.shape[-1])
+        #     )
+        #     if self.weight.dtype in [torch.float16, torch.bfloat16]:
+        #         hidden_states = hidden_states.to(self.weight.dtype)
+
+        # layer norm should always be calculated in float32
+        # variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        # convert into half-precision if necessary
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
+
         return self.weight * hidden_states


+try:
+    from apex.normalization import FusedRMSNorm
+
+    T5LayerNorm = FusedRMSNorm  # noqa
+
+    print("XXX: using FusedRMSNorm")
+except ImportError:
+    print("XXX: using T5LayerNorm")
+except Exception:
+    print("XXX: using T5LayerNorm: unknown exception")
+    pass
+
+
 class T5DenseReluDense(nn.Module):
     def __init__(self, config):
         super().__init__()
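The module-level override above works because `apex.normalization.FusedRMSNorm` accepts the same `(hidden_size, eps)` positional arguments as `T5LayerNorm`, so no call sites need to change. A short usage sketch, assuming apex is installed and a CUDA device is available:

```python
import torch
from apex.normalization import FusedRMSNorm  # requires the apex package

hidden_size = 512
norm = FusedRMSNorm(hidden_size, eps=1e-6).cuda()  # same constructor call the T5 blocks make

x = torch.randn(2, 5, hidden_size, device="cuda", dtype=torch.float16)
print(norm(x).shape)  # torch.Size([2, 5, 512])
```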
From e7523dd479ee518b92f5a32d02c5e278f6f2c46f Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Thu, 3 Feb 2022 18:39:47 -0800
Subject: [PATCH 4/9] cleanup

---
 src/transformers/models/t5/modeling_t5.py | 22 +---------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index fdc61b3929f42c..fb6b8724f51dbb 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -248,31 +248,11 @@ def forward(self, hidden_states):
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
-        # half-precision inputs is done in fp32, so the original code, which might be easier to
-        # understand was:
+        # half-precision inputs is done in fp32

         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

-        # but it's more efficient to use the fused kernel norm and it doesn't require converting the
-        # inputs to fp32, just the intermediary results with norm which are much smaller, we just
-        # need to make a correction for sqrt(N) - note that this is also similar to weight
-        # normalization
-
-        # attempt to fuse:
-        # with torch.jit.fuser("fuser2"):
-        #     hidden_states = (
-        #         hidden_states
-        #         / (torch_norm(hidden_states, dim=-1, keepdim=True, dtype=torch.float32) + self.variance_epsilon)
-        #         * math.sqrt(hidden_states.shape[-1])
-        #     )
-        #     if self.weight.dtype in [torch.float16, torch.bfloat16]:
-        #         hidden_states = hidden_states.to(self.weight.dtype)
-
-        # layer norm should always be calculated in float32
-        # variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
         # convert into half-precision if necessary
         if self.weight.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.to(self.weight.dtype)
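Pieced together from the hunks above, the pure-Python fallback after this cleanup looks roughly like the following (a sketch reconstructed from the diffs, not copied verbatim from the repository): the forward is back to the original mean-of-squares formulation, and only the condensed comment and the apex override remain from the experiment.

```python
import torch
from torch import nn


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """Construct a layernorm module in the T5 style. No bias and no subtraction of mean."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # RMS norm: no mean subtraction, no bias; accumulate in fp32 even for fp16/bf16 inputs
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert back into half precision if the weights are half precision
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
```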
From 80755c6a82ab9088a0e9ee36e2e9bdb8a1320460 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 4 Feb 2022 13:21:09 -0800
Subject: [PATCH 5/9] cleanup

---
 src/transformers/models/t5/modeling_t5.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index fb6b8724f51dbb..d206c58f4ef381 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -265,11 +265,9 @@ def forward(self, hidden_states):

     T5LayerNorm = FusedRMSNorm  # noqa

-    print("XXX: using FusedRMSNorm")
+    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
 except ImportError:
-    print("XXX: using T5LayerNorm")
-except Exception:
-    print("XXX: using T5LayerNorm: unknown exception")
+    # using the normal T5LayerNorm
     pass


From 2cb5c30bb839403ca85e62fab012829f3b45d00e Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 4 Feb 2022 13:21:22 -0800
Subject: [PATCH 6/9] add doc

---
 docs/source/model_doc/t5.mdx | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/source/model_doc/t5.mdx b/docs/source/model_doc/t5.mdx
index 47bcdc662f0511..66dd48bad0fa28 100644
--- a/docs/source/model_doc/t5.mdx
+++ b/docs/source/model_doc/t5.mdx
@@ -263,6 +263,11 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))


+## Performance
+
+If you'd like a faster performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`.
+
+
 ## Example scripts

 T5 is supported by several example scripts, both for pre-training and fine-tuning.
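With the override in place, a quick way to confirm which implementation a given environment ended up with is to inspect the module-level name after import (a sketch; the exact class path printed depends on the installed apex version):

```python
from transformers.models.t5 import modeling_t5

# points at apex's FusedRMSNorm when apex was importable at load time,
# otherwise at the pure-Python T5LayerNorm defined in modeling_t5
print(modeling_t5.T5LayerNorm)
```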
From 11c79140ef5e915b4d4ad8fa4a4291e8d27a51da Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 4 Feb 2022 13:40:22 -0800
Subject: [PATCH 7/9] add catch all

---
 src/transformers/models/t5/modeling_t5.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index d206c58f4ef381..b664d5fd5cd58a 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -269,6 +269,9 @@ def forward(self, hidden_states):
 except ImportError:
     # using the normal T5LayerNorm
     pass
+except Exception:
+    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
+    pass


 class T5DenseReluDense(nn.Module):

From ab1b410bbdf3dc9fee50c704d31660c234ef18e1 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 4 Feb 2022 13:47:02 -0800
Subject: [PATCH 8/9] Trigger CI

From fbe0eeffe32a4df26b6f2e0623c573bf6a0721b3 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 7 Feb 2022 11:26:31 -0800
Subject: [PATCH 9/9] expand

---
 docs/source/model_doc/t5.mdx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/model_doc/t5.mdx b/docs/source/model_doc/t5.mdx
index 66dd48bad0fa28..dbcfaf1c7dc7fe 100644
--- a/docs/source/model_doc/t5.mdx
+++ b/docs/source/model_doc/t5.mdx
@@ -265,7 +265,7 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

 ## Performance

-If you'd like a faster performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`.
+If you'd like faster training and inference, install [apex](https://github.com/NVIDIA/apex#quick-start) and then the model will automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel which is several times faster than the latter.

 ## Example scripts
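The "several times faster" claim can be sanity-checked with a rough micro-benchmark along these lines (a sketch assuming apex and a CUDA GPU; the sizes, dtypes and measured ratio are illustrative, not part of the patches):

```python
import torch
from torch import nn
from torch.utils.benchmark import Timer
from apex.normalization import FusedRMSNorm


class PythonRMSNorm(nn.Module):
    """Same math as the pure-Python T5LayerNorm above."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(self.weight.dtype)


hidden_size = 1024
x = torch.randn(8, 512, hidden_size, device="cuda", dtype=torch.float16)

for name, module in [("python", PythonRMSNorm(hidden_size)), ("apex", FusedRMSNorm(hidden_size, eps=1e-6))]:
    module = module.to(device="cuda", dtype=torch.float16)
    timer = Timer("module(x)", globals={"module": module, "x": x})
    print(name, timer.timeit(100))
```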