Correcting Mistakes In Flip Docs #2140

Merged · 1 commit · May 25, 2022
18 changes: 9 additions & 9 deletions docs/flip/1777-default-dtype.md
@@ -16,7 +16,7 @@ This FLIP proposes to replace the default dtype which is currently fixed to floa

Currently, Linen Modules always produce `module.dtype` (defaults to float32) outputs regardless of input and parameter dtypes. Half-precision types like float16 and bfloat16 are supported by explicitly passing the half-precision type to each Module. The way this is currently implemented is that each Module has a dtype argument with float32 as the default value. The layer guarantees that this dtype will be the return type of the result returned by `__call__`.
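
For illustration, a minimal sketch of the explicit half-precision pattern described above (the layer size, input shape, and PRNG key are arbitrary choices, not taken from the FLIP):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Explicitly request half precision from this layer via its `dtype` attribute.
dense = nn.Dense(features=4, dtype=jnp.bfloat16)

x = jnp.ones((2, 3), dtype=jnp.float32)
variables = dense.init(jax.random.PRNGKey(0), x)

y = dense.apply(variables, x)
print(y.dtype)  # bfloat16: the layer returns the requested dtype
# Parameters are still created in float32 (the separate `param_dtype` default).
print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))
```

Passing a half-precision `dtype` explicitly like this is the status quo the FLIP describes; the proposal concerns what happens when no `dtype` is given.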

The current behavior is problematic and results in silent bugs especially for dtypes that do not fit inside float32 (complex, float64). Also the Linen dtype behavior is significantly different from how NumPy and by extension JAX handle dtypes.
The current behavior is problematic and results in silent bugs, especially for dtypes that do not fit inside float32 (complex, float64). Also, the Linen dtype behavior is significantly different from how NumPy and by extension JAX handle dtypes.


### Dtypes in JAX
@@ -28,15 +28,15 @@ JAX uses a NumPy-inspired [dtype promotion](https://github.com/google/jax/blob/m

## Dtypes in Linen

Beside input arguments, state and in particular parameters could affect dtype promotion. For example: we might feed a float64 input to a Dense layer with float32 parameters. Currently the result would be truncated to float32. If the input is a complex number the result is even worse because the imaginary part will be silently dropped when casting to float32.
Besides input arguments, state and in particular parameters could affect dtype promotion. For example: we might feed a float64 input to a Dense layer with float32 parameters. Currently, the result would be truncated to float32. If the input is a complex number the result is even worse because the imaginary part will be silently dropped when casting to float32.

By using the dtype promotion rules already available in JAX, we can avoid this issue. A public API, `jax.numpy.result_type(*args)`, returns the dtype that JAX would promote the given arguments to, in accordance with the type promotion lattice. For Linen layers the arguments would be the layer inputs together with the parameters. For example, for a linear layer this would be the inputs, kernel, and bias.
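
For example (the dtype combinations below are only illustrative):

```python
import jax.numpy as jnp

# A half-precision input promoted with float32 parameters stays float32.
print(jnp.result_type(jnp.bfloat16, jnp.float32))              # float32
# A complex input is not truncated: the promoted dtype stays complex.
print(jnp.result_type(jnp.complex64, jnp.float32))             # complex64
# For a linear layer the arguments would be the inputs, kernel, and bias.
print(jnp.result_type(jnp.float16, jnp.float32, jnp.float32))  # float32
```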

Note that there is also a `param_dtype` attribute in standard Linen Modules, which likewise defaults to float32. This behavior is left untouched and encodes the common case of having float32 parameters (a short sketch follows the list below).
There are a few reasons why float32 is almost always the correct dtype for parameters:
1. Storing weights in half-precision often leads to underflow during optimization.
2. Double precision is rarely used because it severly slows down modern accelerators (GPU, TPU). Therefore, such a cost should be explicitly opted-in for.
3. Complex Modules are relatively uncommon. Even within complex networks the complex inputs can be projected with a real matrix.
2. Double precision is rarely used because it severely slows down modern accelerators (GPU, TPU). Therefore, such a cost should be explicitly opted-in for.
3. Complex Modules are relatively uncommon. Even within complex networks, the complex inputs can be projected with a real matrix.
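
A rough sketch of these points, assuming the promotion behavior proposed in this FLIP (the layer, shapes, and dtypes are illustrative only):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Leave `dtype` unset so the output follows promotion of inputs and parameters;
# the parameters themselves stay in float32 via the `param_dtype` default.
dense = nn.Dense(features=4)

# A complex input projected with a real float32 kernel (point 3 above).
x = jnp.ones((2, 3), dtype=jnp.complex64)
variables = dense.init(jax.random.PRNGKey(0), x)

print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))  # float32 kernel and bias
print(dense.apply(variables, x).dtype)  # complex64: the imaginary part survives
```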


# Implementation
@@ -99,7 +99,7 @@ BatchNorm layers are often used with a half precision output dtype. However, cal

**Complex number support**

Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require some additional thought to extend to complex number correctly:
Currently, our complex number support is brittle because the default behavior is to truncate the output to the real part. This issue will be fixed by the automatic type promotion proposed in this FLIP. However, some layers require some additional thought to extend to complex numbers correctly:

1. Normalization layers use the complex conjugate to calculate norms instead of normal squaring.
2. Attention: It’s not exactly clear how the dot product and softmax are defined in this case. Raise an error on complex inputs.
@@ -114,7 +114,7 @@ Summarizing the main points from the discussion:
## Consider implicit complex truncation an error

Q:
I I'm wondering if we should always raise an error if one of the xs tree leaves is complex but dtype is not. Users should maybe remove imaginary part by themselves if that's really what they want to do.
I'm wondering if we should always raise an error if one of the xs tree leaves is complex but dtype is not. Users should maybe remove imaginary part by themselves if that's really what they want to do.
(Maybe it's a contrived example, but I can imagine cases where layers have their dtype set by parent modules based on assumptions without complex numbers in mind)

A:
@@ -124,10 +124,10 @@ This is worth considering in a follow-up CL but this might as well be solved in
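
For concreteness, a hypothetical sketch of such a check; the helper name and error message are invented for illustration and are not part of the FLIP or of Flax:

```python
import jax
import jax.numpy as jnp

def _check_no_silent_complex_truncation(xs, dtype):
    # Hypothetical helper: reject complex leaves when the requested dtype
    # cannot represent them, instead of silently dropping the imaginary part.
    if jnp.issubdtype(dtype, jnp.complexfloating):
        return
    for leaf in jax.tree_util.tree_leaves(xs):
        leaf_dtype = jnp.asarray(leaf).dtype
        if jnp.issubdtype(leaf_dtype, jnp.complexfloating):
            raise TypeError(
                f"Input leaf has dtype {leaf_dtype}, which would be truncated "
                f"to {dtype}; drop the imaginary part explicitly if intended."
            )

# Would raise: a complex leaf combined with a real target dtype.
# _check_no_silent_complex_truncation({"x": jnp.array([1 + 2j])}, jnp.float32)
```
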
## Dtype attribute names

Q:
Are the dtype and param_dtype arguments confusion? In particular should dtype perhaps be called output_dtype to make the difference between the two dtypes more explicit?
Are the dtype and param_dtype arguments confusing? In particular, should dtype perhaps be called output_dtype to make the difference between the two dtypes more explicit?

A:
This would be a large and orthogonal change wrt to this proposal so leaving it out for now.
Also this breaks with the standard dtype argument in NumPY/JAX.
This would be a large and orthogonal change wrt to this proposal so leaving it out for now.
Also, this breaks with the standard dtype argument in NumPY/JAX.
Although dtype indeed constrains the output dtype, it is also a hint for the dtype we would like the computation to happen in.
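
To make that last point concrete, a small sketch contrasting an unspecified `dtype` (promotion decides) with an explicit `dtype` hint, assuming the promotion behavior proposed above (shapes and sizes are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 3), dtype=jnp.bfloat16)
key = jax.random.PRNGKey(0)

# dtype=None: the bfloat16 input is promoted with the float32 parameters,
# so the computation and output end up in float32.
auto = nn.Dense(features=4, dtype=None)
print(auto.apply(auto.init(key, x), x).dtype)      # float32

# Explicit dtype: constrains the output dtype and hints that the matmul
# itself should be carried out in bfloat16.
hinted = nn.Dense(features=4, dtype=jnp.bfloat16)
print(hinted.apply(hinted.init(key, x), x).dtype)  # bfloat16
```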