diff --git a/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py b/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py index 1e7a63ed8b..a2bc2d0d07 100644 --- a/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py +++ b/keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py @@ -67,12 +67,12 @@ def positional_embedding(self, pos_seq, inv_freq, bsz=None): def relative_positional_encoding(self, qlen, klen, bsz=None, clamp_len=-1): """create relative positional encoding.""" - freq_seq = ops.arange(0, self.hidden_dim, 2.0) + freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype=self.compute_dtype) inv_freq = 1 / (10000 ** (freq_seq / self.hidden_dim)) beg, end = klen, -qlen - fwd_pos_seq = ops.arange(beg, end, -1.0) + fwd_pos_seq = ops.arange(beg, end, -1.0, dtype=self.compute_dtype) if clamp_len > 0: fwd_pos_seq = ops.clip( fwd_pos_seq, x_min=-clamp_len, x_max=clamp_len