diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py
index 1ef247c57..7b4ad5f75 100644
--- a/keras_nlp/models/llama/llama_decoder.py
+++ b/keras_nlp/models/llama/llama_decoder.py
@@ -97,7 +97,6 @@ def build(self, decoder_sequence_shape):
         self._feedforward_gate_dense = keras.layers.Dense(
             self.intermediate_dim,
-            activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             use_bias=False,
             dtype=self.dtype_policy,
@@ -167,6 +166,17 @@ def call(
         x = self._feedforward_layernorm(x)
 
         gate_output = self._feedforward_gate_dense(x)
+        # Note that we run the activation function in full 32-bit
+        # precision since this is what `torch.nn.functional.silu`
+        # does. Internally, `torch.nn.functional.silu` converts the
+        # inputs to float32, computes SiLU, and converts the outputs
+        # back to compute dtype.
+        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+        gate_output = ops.cast(gate_output, "float32")
+        gate_output = self.activation(gate_output)
+        gate_output = ops.cast(gate_output, self.compute_dtype)
+
         x = self._feedforward_intermediate_dense(x)
 
         x = self._feedforward_output_dense(ops.multiply(x, gate_output))
diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py
index 36b7f5944..3ef91d306 100644
--- a/keras_nlp/models/mistral/mistral_transformer_decoder.py
+++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -102,7 +102,6 @@ def build(self, decoder_sequence_shape):
         self._feedforward_gate_dense = keras.layers.Dense(
             self.intermediate_dim,
-            activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             use_bias=False,
             dtype=self.dtype_policy,
@@ -172,6 +171,17 @@ def call(
         x = self._feedforward_layernorm(x)
 
         gate_output = self._feedforward_gate_dense(x)
+        # Note that we run the activation function in full 32-bit
+        # precision since this is what `torch.nn.functional.silu`
+        # does. Internally, `torch.nn.functional.silu` converts the
+        # inputs to float32, computes SiLU, and converts the outputs
+        # back to compute dtype.
+        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+        gate_output = ops.cast(gate_output, "float32")
+        gate_output = self.activation(gate_output)
+        gate_output = ops.cast(gate_output, self.compute_dtype)
+
         x = self._feedforward_intermediate_dense(x)
 
         x = self._feedforward_output_dense(ops.multiply(x, gate_output))
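
The snippet below is a minimal standalone sketch of the upcast-activate-downcast pattern the diff applies inside the decoder layers, assuming Keras 3 with `keras.ops` available; the layer, shapes, and policy choice here are illustrative only, not the KerasNLP decoder code itself.

import keras
from keras import ops

# Under a mixed precision policy, layer compute dtype is float16 while
# variables remain float32.
keras.mixed_precision.set_global_policy("mixed_float16")

# Gate projection without a fused activation (matching the diff, which
# drops `activation=` from the Dense layer).
gate_dense = keras.layers.Dense(8, use_bias=False)

x = ops.ones((2, 4))
gate_output = gate_dense(x)  # computed in float16 under the policy

# Mirror torch.nn.functional.silu: cast up to float32, apply SiLU,
# then cast back to the layer's compute dtype.
gate_output = ops.cast(gate_output, "float32")
gate_output = ops.silu(gate_output)
gate_output = ops.cast(gate_output, gate_dense.compute_dtype)

Running the activation in float32 keeps the sigmoid inside SiLU numerically stable in half precision, which is why the diff removes the fused `activation=` argument and applies the activation explicitly between the two casts.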