From 7080cfab3689c504cf00f303bea9e800ae3c3649 Mon Sep 17 00:00:00 2001
From: johndolgov
Date: Tue, 18 Aug 2020 13:58:01 +0300
Subject: [PATCH 1/3] xlnet fp16 bug fix

---
 src/transformers/modeling_xlnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
index ddb655656a8c3b..6efbc6c1c02522 100644
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -439,7 +439,7 @@ def forward(
         v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
 
         # positional heads
-        k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
+        k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)
 
         # core attention ops
         attn_vec = self.rel_attn_core(

From 0a9991a666bfaa104b88a80e889b7d6a11b601fc Mon Sep 17 00:00:00 2001
From: johndolgov
Date: Thu, 20 Aug 2020 17:45:28 +0300
Subject: [PATCH 2/3] comment cast added

---
 src/transformers/modeling_xlnet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
index 6efbc6c1c02522..8d916632661c6e 100644
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -439,6 +439,8 @@ def forward(
         v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
 
         # positional heads
+
+        # 16-bit precision tensors support
         k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)
 
         # core attention ops

From 1c18ddf892d1e5adbec12f4878f553a68d960e5f Mon Sep 17 00:00:00 2001
From: Kevin Canwen Xu
Date: Fri, 21 Aug 2020 01:00:19 +0800
Subject: [PATCH 3/3] Update modeling_xlnet.py

---
 src/transformers/modeling_xlnet.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
index 8d916632661c6e..dd42a40c26ce12 100644
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -439,8 +439,7 @@ def forward(
         v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
 
         # positional heads
-
-        # 16-bit precision tensors support
+        # type casting for fp16 support
         k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(self.r.dtype), self.r)
 
         # core attention ops
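
Note (not part of the patch series): the cast is needed because the relative positional encoding `r` is built on the fly (torch.arange / sin / cos) and can remain float32, while `self.r` is an nn.Parameter that becomes float16 once the model is run in half precision, so torch.einsum sees mixed dtypes. Below is a minimal standalone sketch of the mismatch and of the patched line; the tensor names and shapes are illustrative stand-ins, and a CUDA device is assumed for fp16.

# Minimal sketch of the dtype mismatch fixed above (illustrative names/shapes,
# not part of the patch). Assumes PyTorch with a CUDA device for fp16.
import torch

d_model, n_head, d_head, qlen, bsz = 32, 2, 16, 4, 1
device = "cuda"

# Stand-in for self.r: an nn.Parameter that is float16 after model.half()
r_param = torch.randn(d_model, n_head, d_head, dtype=torch.float16, device=device)

# Stand-in for the relative positional encoding r, produced on the fly in float32
r = torch.randn(qlen, bsz, d_model, dtype=torch.float32, device=device)

# Pre-patch line: mixed float32/float16 operands make torch.einsum raise a
# dtype-mismatch RuntimeError on the PyTorch releases current at the time
# k_head_r = torch.einsum("ibh,hnd->ibnd", r, r_param)

# Patched line: cast r to the parameter's dtype before the contraction
k_head_r = torch.einsum("ibh,hnd->ibnd", r.type(r_param.dtype), r_param)
print(k_head_r.shape, k_head_r.dtype)  # torch.Size([4, 1, 2, 16]) torch.float16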