Always run SiLU activation in float32 for LLaMA and Mistral (#1540)
* Fix discrepancy between HF LLaMA and our implementation

* Fix Mistral transformer decoder
tirthasheshpatel authored Apr 1, 2024
1 parent 1286784 commit 3b3acb5
Showing 2 changed files with 22 additions and 2 deletions.
12 changes: 11 additions & 1 deletion keras_nlp/models/llama/llama_decoder.py
@@ -97,7 +97,6 @@ def build(self, decoder_sequence_shape):

self._feedforward_gate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
use_bias=False,
dtype=self.dtype_policy,
@@ -167,6 +166,17 @@ def call(
x = self._feedforward_layernorm(x)
gate_output = self._feedforward_gate_dense(x)

# Note that we run the activation function in full 32-bit
# precision since this is what `torch.nn.functional.silu`
# does. Internally, `torch.nn.functional.silu` converts the
# inputs to float32, computes SiLU, and converts the outputs
# back to compute dtype.
# CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
# CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
gate_output = ops.cast(gate_output, "float32")
gate_output = self.activation(gate_output)
gate_output = ops.cast(gate_output, self.compute_dtype)

x = self._feedforward_intermediate_dense(x)

x = self._feedforward_output_dense(ops.multiply(x, gate_output))
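For context, here is a minimal sketch of the gated-feedforward pattern this diff produces, with the SiLU gate upcast to float32 the way `torch.nn.functional.silu` computes it internally. This is illustrative only and assumes the Keras 3 `keras.ops` API; the function and argument names below are hypothetical, not the actual layer code in `llama_decoder.py`.

# Illustrative sketch only (not the keras_nlp layer); assumes Keras 3.
from keras import ops


def gated_feedforward(x, gate_dense, intermediate_dense, output_dense,
                      compute_dtype="float16"):
    # The gate projection no longer has a fused activation.
    gate_output = gate_dense(x)

    # Upcast, apply SiLU in float32, then cast back to the compute dtype,
    # matching the behavior of `torch.nn.functional.silu`.
    gate_output = ops.cast(gate_output, "float32")
    gate_output = ops.silu(gate_output)
    gate_output = ops.cast(gate_output, compute_dtype)

    # Gate the intermediate projection element-wise, then project back out.
    x = intermediate_dense(x)
    return output_dense(ops.multiply(x, gate_output))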
12 changes: 11 additions & 1 deletion keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -102,7 +102,6 @@ def build(self, decoder_sequence_shape):

self._feedforward_gate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
use_bias=False,
dtype=self.dtype_policy,
@@ -172,6 +171,17 @@ def call(
x = self._feedforward_layernorm(x)
gate_output = self._feedforward_gate_dense(x)

# Note that we run the activation function in full 32-bit
# precision since this is what `torch.nn.functional.silu`
# does. Internally, `torch.nn.functional.silu` converts the
# inputs to float32, computes SiLU, and converts the outputs
# back to compute dtype.
# CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
# CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
gate_output = ops.cast(gate_output, "float32")
gate_output = self.activation(gate_output)
gate_output = ops.cast(gate_output, self.compute_dtype)

x = self._feedforward_intermediate_dense(x)

x = self._feedforward_output_dense(ops.multiply(x, gate_output))
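A quick, hypothetical way to see the discrepancy this change removes is to compare SiLU evaluated directly in half precision against the upcast-compute-downcast path. The snippet below is only a sanity check under the assumption of a Keras 3 backend being available; it is not part of the commit.

# Hypothetical sanity check, assuming Keras 3 with any backend installed.
import numpy as np
from keras import ops

# Random half-precision activations standing in for the gate projection.
x = np.random.uniform(-4.0, 4.0, size=(8, 256)).astype("float16")

silu_half = ops.silu(x)  # SiLU evaluated in float16
silu_full = ops.cast(ops.silu(ops.cast(x, "float32")), "float16")

# Largest element-wise difference between the two paths.
max_diff = ops.max(ops.abs(ops.cast(silu_half, "float32") -
                           ops.cast(silu_full, "float32")))
print(max_diff)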
