
Reading doc and some discussion #220

Merged: 11 commits, Sep 6, 2024
docs/src/algorithmic_differentiation.md: 37 changes (16 additions, 21 deletions)
@@ -5,7 +5,6 @@ Even if you have worked with AD before, we recommend reading in order to acclima

# Derivatives


A foundation on which all of AD is built is the derivative -- we require a fairly general definition of it, which we build up to here.

_**Scalar-to-Scalar Functions**_
@@ -16,7 +15,7 @@ Its derivative at ``x`` is usually thought of as the scalar ``\alpha \in \RR`` s
```math
\text{d}f = \alpha \, \text{d}x .
```
Loosely speaking, by this notation we mean that for arbitrary small changes ``\text{d} x`` in the input to ``f``, the change in the output ``\text{d} f`` is ``\alpha \, \text{d}x``.
-We refer readers to the first few minutes of the first lecture mentioned above for a more careful explanation.
+We refer readers to the first few minutes of the [first lecture mentioned before](https://ocw.mit.edu/courses/18-s096-matrix-calculus-for-machine-learning-and-beyond-january-iap-2023/resources/ocw_18s096_lecture01-part2_2023jan18_mp4/) for a more careful explanation.

_**Vector-to-Vector Functions**_

@@ -40,7 +39,7 @@ In order to do so, we now introduce a generalised notion of the derivative.

_**Functions Between More General Spaces**_

-In order to avoid the difficulties described above, we consider we consider functions ``f : \mathcal{X} \to \mathcal{Y}``, where ``\mathcal{X}`` and ``\mathcal{Y}`` are _finite_ dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars).
+In order to avoid the difficulties described above, we consider functions ``f : \mathcal{X} \to \mathcal{Y}``, where ``\mathcal{X}`` and ``\mathcal{Y}`` are _finite_ dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars).
This definition includes functions to / from ``\RR``, ``\RR^D``, but also real-valued matrices, and any other "container" for collections of real numbers.
Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.
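
As a small, hedged illustration of the "containers of real numbers" point (the example values are our own), the Euclidean inner product is available for each of these spaces in Julia:
```julia
using LinearAlgebra

# The Euclidean inner product on several "containers" of real numbers:
dot(2.0, 3.0)                     # ℝ: plain multiplication, gives 6.0
dot([1.0, 2.0], [3.0, 4.0])       # ℝ^D: 1*3 + 2*4 = 11.0
dot(ones(2, 2), fill(2.0, 2, 2))  # matrices: ⟨A, B⟩ = tr(AᵀB) = 8.0
```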

@@ -63,9 +62,7 @@ Inputs ``\dot{x}`` should be thought of as "directions", in the directional der

Similarly, if ``\mathcal{X} = \RR^P`` and ``\mathcal{Y} = \RR^Q`` then this operator can be specified in terms of the Jacobian matrix: ``D f [x] (\dot{x}) := J[x] \dot{x}`` -- brackets are used to emphasise that ``D f [x]`` is a function, and is being applied to ``\dot{x}``.[^note_for_geometers]

-The difference from usual is a little bit subtle.
-We do not define the derivative to _be_ ``\alpha`` or ``J[x]``, rather we define it to be "multiply by ``\alpha``" or "multiply by ``J[x]``".
-For the rest of this document we shall use this definition of the derivative.
+To reiterate, for the rest of this document, we define the derivative to be "multiply by ``\alpha``" or "multiply by ``J[x]``", rather than to _be_ ``\alpha`` or ``J[x]``.
So whenever you see the word "derivative", you should think "linear function".
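
As a sketch of this point of view in code (function names are our own), the derivative of ``\sin`` at ``x`` can be represented as the closure "multiply by ``\cos(x)``" rather than the number ``\cos(x)``:
```julia
# The derivative of sin at x, represented as a linear *function*:
# D sin [x](ẋ) := cos(x) ẋ
derivative_of_sin(x) = xdot -> cos(x) * xdot

Df = derivative_of_sin(0.5)  # Df is a function, not a scalar
Df(1.0)                      # ≈ 0.8776 = cos(0.5)
Df(2.0)                      # linearity: twice the above
```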

_**The Chain Rule**_
@@ -75,7 +72,7 @@ Fortunately, it applies to this version of the derivative:
```math
f = g \circ h \implies D f [x] = (D g [h(x)]) \circ (D h [x])
```
-By induction this extends to a collection of ``N`` functions ``f_1, \dots, f_N``:
+By induction, this extends to a collection of ``N`` functions ``f_1, \dots, f_N``:
```math
f := f_N \circ \dots \circ f_1 \implies D f [x] = (D f_N [x_N]) \circ \dots \circ (D f_1 [x_1]),
```
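
A hedged sketch of the two-function case in Julia (our own example, with ``g = \sin`` and ``h = \exp``), composing the derivatives as linear maps:
```julia
# Chain rule: f = g ∘ h, so D f [x] = D g [h(x)] ∘ D h [x].
h(x) = exp(x);  Dh(x) = xdot -> exp(x) * xdot   # D h [x]
g(y) = sin(y);  Dg(y) = ydot -> cos(y) * ydot   # D g [y]

x = 0.7
Df = xdot -> Dg(h(x))(Dh(x)(xdot))  # compose the linear maps
Df(1.0) ≈ cos(exp(x)) * exp(x)      # true: matches the usual chain rule
```
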
@@ -106,7 +103,7 @@ For the interested reader we provide a high-level explanation of _how_ forwards-

_**Another aside: notation**_

-You will have noticed that we typically denote the argument to a derivative with a "dot" over it, e.g. ``\dot{x}``.
+You may have noticed that we typically denote the argument to a derivative with a "dot" over it, e.g. ``\dot{x}``.
This is something that we will do consistently, and we will use the same notation for the outputs of derivatives.
Wherever you see a symbol with a "dot" over it, expect it to be an input or output of a derivative / forwards-mode AD.

@@ -128,7 +125,8 @@ Specifically, the adjoint ``A^\ast`` of linear operator ``A`` is the linear oper
```math
\langle A^\ast \bar{y}, \dot{x} \rangle = \langle \bar{y}, A \dot{x} \rangle.
```
-The relationship between the adjoint and matrix transpose is this: if ``A (x) := J x`` for some matrix ``J``, then ``A^\ast (y) := J^\top y``.
+where ``\langle \cdot, \cdot \rangle`` denotes the inner-product.
+The relationship between the adjoint and matrix transpose is: if ``A (x) := J x`` for some matrix ``J``, then ``A^\ast (y) := J^\top y``.

Moreover, just as ``(A B)^\top = B^\top A^\top`` when ``A`` and ``B`` are matrices, ``(A B)^\ast = B^\ast A^\ast`` when ``A`` and ``B`` are linear operators.
This result follows in short order from the definition of the adjoint operator (and is a good exercise!).
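
Both facts are easy to check numerically. A sketch (operator names are ours), using the Euclidean inner product:
```julia
using LinearAlgebra

J = randn(3, 2)
A(xdot) = J * xdot      # linear operator A(ẋ) := J ẋ
Aadj(ybar) = J' * ybar  # its adjoint: multiply by Jᵀ

xdot, ybar = randn(2), randn(3)
dot(Aadj(ybar), xdot) ≈ dot(ybar, A(xdot))  # defining property: true

# (A ∘ B)^* = B^* ∘ A^*: both sides amount to multiplying by Kᵀ Jᵀ.
K = randn(2, 2)
dot(K' * (J' * ybar), xdot) ≈ dot(ybar, J * (K * xdot))  # true
```
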
@@ -164,7 +162,7 @@ We have introduced some mathematical abstraction in order to simplify the calcul
To this end, we consider differentiating ``f(X) := X^\top X``.
Results for this and similar operations are given by [giles2008extended](@cite).
A similar operation, but which maps from matrices to ``\RR``, is discussed in Lecture 4 part 2 of the MIT course mentioned previously.
-Both [giles2008extended](@cite) and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.
+Both [giles2008extended](@cite) and [Lecture 4 part 2](https://ocw.mit.edu/courses/18-s096-matrix-calculus-for-machine-learning-and-beyond-january-iap-2023/resources/ocw_18s096_lecture04-part2_2023jan26_mp4/) provide approaches to obtaining the derivative of this function.

Following either resource will yield the derivative:
```math
D f [X] (\dot{X}) = \dot{X}^\top X + X^\top \dot{X} ,
```
@@ -194,7 +192,7 @@ D f [X]^\ast (\bar{Y}) = \bar{Y} X^\top + X \bar{Y}.
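
Derivatives like this are straightforward to sanity-check against finite differences. A sketch (variable names ours), using the derivative ``D f [X](\dot{X}) = \dot{X}^\top X + X^\top \dot{X}``:
```julia
f(X) = X' * X
X, Xdot = randn(3, 2), randn(3, 2)

# Directional finite difference vs. the derivative applied to Ẋ:
h = 1e-6
fd = (f(X .+ h .* Xdot) .- f(X)) ./ h
analytic = Xdot' * X + X' * Xdot
maximum(abs.(fd .- analytic)) < 1e-4  # true, up to finite-difference error
```
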

#### AD of a Julia function: a trivial example

-We now turn to differentiating Julia `function`s.
+We now turn to differentiating Julia `function`s (we use `function` to refer to the programming language construct, and function to refer to a more general mathematical concept).
The way that Tapir.jl handles immutable data is very similar to how Zygote / ChainRules do.
For example, consider the Julia function
```julia
f(x) = sin(x)
```
@@ -245,10 +243,9 @@ From here the adjoint can be read off from the first argument to the inner produ
```math
D f [x]^\ast (\bar{f}) = \cos(x) \bar{f}.
```
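
In code, this adjoint is what a reverse rule returns as the pullback. A hedged, ChainRules-style sketch (names are hypothetical; this is not Tapir.jl's actual `rrule!!` API):
```julia
# A hand-written reverse rule for sin:
function sin_and_pullback(x::Float64)
    y = sin(x)
    pullback(fbar) = cos(x) * fbar  # D f [x]^* (f̄) = cos(x) f̄
    return y, pullback
end

y, pb = sin_and_pullback(0.5)
pb(1.0)  # ≈ cos(0.5), the gradient of sin at 0.5
```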


#### AD of a Julia function: a slightly less trivial example

-Now consider the Julia function
+Now consider the Julia `function`
```julia
f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]
```
@@ -259,8 +256,6 @@ g -> (g, (y[2] * g, y[1] * g))

As before, we work through in detail.



_**Step 1: Differentiable Mathematical Model**_

There are a couple of aspects of `f` which require thought:
@@ -299,9 +294,9 @@ D f [x, y]^\ast (\bar{f}) = (\bar{f}, (\bar{f} y_2, \bar{f} y_1))
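
This adjoint can again be checked numerically. A sketch with values of our own choosing:
```julia
f(x, y) = x + y[1] * y[2]

x, y, fbar = 2.0, (3.0, 4.0), 1.0
grads = (fbar, (fbar * y[2], fbar * y[1]))  # (1.0, (4.0, 3.0))

# Compare two components against finite differences:
h = 1e-6
(f(x + h, y) - f(x, y)) / h             # ≈ 1.0 = grads[1]
(f(x, (y[1] + h, y[2])) - f(x, y)) / h  # ≈ 4.0 = grads[2][1]
```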

#### AD with mutable data

-In the previous two examples there was an obvious mathematical model for the Julia function.
+In the previous two examples there was an obvious mathematical model for the Julia `function`.
Indeed this model was sufficiently obvious that it required little explanation.
-This is not always the case though, in particular, Julia functions which modify / mutate their inputs require a little more thought.
+This is not always the case though, in particular, Julia `function`s which modify / mutate their inputs require a little more thought.

Consider the following Julia `function`:
```julia
@@ -315,12 +310,12 @@ So what is an appropriate mathematical model for this `function`?

_**Step 1: Differentiable Mathematical Model**_

-The trick is to distingush between the state of `x` upon _entry_ to / _exit_ from `f!`.
+The trick is to distinguish between the state of `x` upon _entry_ to / _exit_ from `f!`.
In particular, let ``\phi_{\text{f!}} : \RR^N \to \{ \RR^N \times \RR \}`` be given by
```math
\phi_{\text{f!}}(x) = (x \odot x, \sum_{n=1}^N x_n^2)
```
-where ``\odot`` denotes the Hadamard / elementwise product.
+where ``\odot`` denotes the Hadamard / element-wise product (corresponds to line `x .*= x` in the above code).
The point here is that the input to ``\phi_{\text{f!}}`` is the value of `x` upon entry to `f!`, and the value returned from ``\phi_{\text{f!}}`` is a tuple containing both the value of `x` upon exit from `f!` and the value returned by `f!`.
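
The following sketch makes the distinction concrete. The body of `f!` is our reconstruction from the text (it contains the line `x .*= x` referred to above and returns the sum of the squared elements), and ``\phi_{\text{f!}}`` is written as a pure `function`:
```julia
# The mutating `function` being modelled (body reconstructed from the text):
function f!(x::Vector{Float64})
    x .*= x        # overwrite x with x ⊙ x
    return sum(x)
end

# Its pure model φ_f!: input ↦ (state of x upon exit, return value).
phi_f(x) = (x .* x, sum(abs2, x))

x = [1.0, 2.0, 3.0]
model_out = phi_f(x)   # computed from x upon entry: ([1.0, 4.0, 9.0], 14.0)
ret = f!(x)            # 14.0; as a side effect x is now [1.0, 4.0, 9.0]
model_out == (x, ret)  # true: the model captures both effects
```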

The remaining steps are straightforward now that we have the model.
@@ -383,13 +378,13 @@ Forwards-Pass:
2. construct ``D f_n [x_n]^\ast``
3. let ``x_{n+1} = f_n (x_n)``
4. let ``n = n + 1``
-5. if ``n < N + 1`` then go to 2
+5. if ``n < N + 1`` then go to step 2.

Reverse-Pass:
1. let ``\bar{x}_{N+1} = \bar{y}``
2. let ``n = n - 1``
3. let ``\bar{x}_{n} = D f_n [x_n]^\ast (\bar{x}_{n+1})``
-4. if ``n = 1`` return ``\bar{x}_1`` else go to 2.
+4. if ``n = 1`` return ``\bar{x}_1`` else go to step 2.
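
A hedged sketch of both passes in Julia, assuming each ``f_n`` comes with a rule returning its value together with the adjoint ``D f_n [x_n]^\ast`` as a closure (all names are ours):
```julia
# Reverse-mode AD over a chain f = f_N ∘ … ∘ f_1.
# Each rule maps x_n to (f_n(x_n), D f_n [x_n]^* as a closure).
function chain_reverse(rules, x1, ybar)
    adjoints = Function[]
    x = x1
    for rule in rules              # forwards-pass: steps 1 to 5
        x, adj = rule(x)           # construct and store D f_n [x_n]^*
        push!(adjoints, adj)
    end
    xbar = ybar                    # reverse-pass: steps 1 to 4
    for adj in reverse(adjoints)
        xbar = adj(xbar)
    end
    return x, xbar                 # f(x_1) and x̄_1
end

# Example: f(x) = sin(exp(x)), so f'(x) = cos(exp(x)) * exp(x).
rules = [x -> (exp(x), gbar -> exp(x) * gbar),
         x -> (sin(x), gbar -> cos(x) * gbar)]
y, xbar = chain_reverse(rules, 0.3, 1.0)
xbar ≈ cos(exp(0.3)) * exp(0.3)  # true
```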



docs/src/mathematical_interpretation.md: 8 changes (5 additions, 3 deletions)
@@ -123,7 +123,7 @@ Crucially, observe that we distinguish between the state of the arguments before

For our example, the exact form of ``f`` is
```math
-f((x, y, z)) = ((x, y, x \odot y), (2 x \odot y, \sum_{d=1}^D x \odot y))
+f((x, y, z, s)) = ((x, y, x \odot y, \text{Ref}(2 x \odot y)), (2 x \odot y, \sum_{d=1}^D x \odot y))
```
Observe that ``f`` behaves a little like a transition operator, in that the first element of the tuple returned is the updated state of the arguments.

@@ -173,7 +173,8 @@ Consider the usual inner product to derive the adjoint:
```math
\begin{align}
\langle \bar{y}, D f [x] (\dot{x}) \rangle &= \langle (\bar{y}_1, \bar{y}_2), (\dot{x}, D \varphi [x](\dot{x})) \rangle \nonumber \\
-&= \langle \bar{y}_1, \dot{x} \rangle + \langle D \varphi [x]^\ast (\bar{y}_2), \dot{x} \rangle \nonumber \\
+&= \langle \bar{y}_1, \dot{x} \rangle + \langle \bar{y}_2, D \varphi [x](\dot{x}) \rangle \nonumber \\
+&= \langle \bar{y}_1, \dot{x} \rangle + \langle D \varphi [x]^\ast (\bar{y}_2), \dot{x} \rangle \nonumber \quad \text{(by definition of the adjoint)} \\
\end{align}
```
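
The conclusion, ``D f [x]^\ast (\bar{y}) = \bar{y}_1 + D \varphi [x]^\ast (\bar{y}_2)``, can be verified numerically. A sketch with ``\varphi(x) := x \odot x`` as a stand-in (all names ours):
```julia
using LinearAlgebra

phi(x) = x .* x                       # an example φ
dphi(x, xdot) = 2 .* x .* xdot        # D φ [x](ẋ)
dphi_adj(x, y2bar) = 2 .* x .* y2bar  # D φ [x]^*(ȳ₂); self-adjoint here

x, xdot = randn(4), randn(4)
y1bar, y2bar = randn(4), randn(4)

# ⟨ȳ, D f [x](ẋ)⟩ for f(x) = (x, φ(x)) ...
lhs = dot(y1bar, xdot) + dot(y2bar, dphi(x, xdot))
# ... equals ⟨ȳ₁ + D φ [x]^*(ȳ₂), ẋ⟩:
rhs = dot(y1bar .+ dphi_adj(x, y2bar), xdot)
lhs ≈ rhs  # true
```
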
@@ -269,7 +270,7 @@ Consequently, we use the same type to represent both.

_**Representing Gradients**_

-This package assigns to each type in Julia a unique `tangent_type`, to purpose of which is to contain the gradients computed during reverse mode AD.
+This package assigns to each type in Julia a unique `tangent_type`, the purpose of which is to contain the gradients computed during reverse mode AD.
The extended docstring for [`tangent_type`](@ref) provides the best introduction to the types which are used to represent tangents / gradients.

```@docs
tangent_type
```
@@ -340,6 +341,7 @@ where ``\mathbf{1}`` is the vector of length ``N`` in which each element is equa
(Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).

Now that we know what the adjoint is, we'll write down the `rrule!!`, and then explain what is going on in terms of the adjoint.
+Tapir.jl will generate `rrule!!`s of similar effect (mostly) automatically, but it is good to see a manual implementation for understanding.
```julia
function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
dx_fdata = x.tangent[2]
```