Commit 4eddd90

basic type docs for auto diff
1 parent 18121a9 commit 4eddd90

2 files changed: +119 -0 lines changed

src/SUMMARY.md

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@
 - [Installation](./autodiff/installation.md)
 - [How to debug](./autodiff/debugging.md)
 - [Autodiff flags](./autodiff/flags.md)
+- [Type Trees](./autodiff/type-trees.md)

 # Source Code Representation

src/autodiff/type-trees.md

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Type Trees in Enzyme

This document describes type trees as used by Enzyme for automatic differentiation.

## What are Type Trees?

Type trees in Enzyme are a way to represent the types of variables, including their activity (e.g., whether they are active, duplicated, or contain duplicated data) for automatic differentiation. They provide a structured way for Enzyme to understand how to handle different data types during the differentiation process.

## Representing Rust Types as Type Trees

Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types.

The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system.

### Primitive Types

#### Floating-Point Types (`f32`, `f64`)

Consider a Rust reference to a 32-bit floating-point number, `&f32`.

In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function:

```llvm
define internal void @callee(i8* %x) {
start:
  %x.dbg.spill = bitcast i8* %x to float*
  ; ...
  ret void
}
```

When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast:

```llvm
i8* %x: {[-1]:Pointer, [-1,0]:Float@float}
%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float}
```

**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`**

This string is the type tree representation. Let's break it down:

* **`{ ... }`**: This encloses the set of type information for the variable.
* **`[-1]:Pointer`**:
    * `[-1]` is a path of offsets. In Enzyme's type trees, `-1` acts as a wildcard meaning "every offset"; at the top level it describes the value `%x` itself.
    * `Pointer` indicates that the variable `%x` itself is treated as a pointer.
* **`[-1,0]:Float@float`**:
    * `[-1,0]` is a path. It means: follow the pointer described by `[-1]`, then look at byte offset `0` in the memory it points to.
    * `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number.
    * `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`).
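
To make the notation concrete, the following is a minimal Rust sketch that models a type tree as a flat list of (path, kind) entries and prints it in the format shown above. The names `Kind`, `Entry`, `TypeTree`, and `ref_f32_tree` are hypothetical and exist only for illustration; they are not the data structures used by Enzyme or rustc.

```rust
use std::fmt;

/// Kinds used in this sketch. Hypothetical; the variant names only mirror
/// the printed output, they are not Enzyme's or rustc's API.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    Pointer,
    Float32, // printed as `Float@float`
}

impl fmt::Display for Kind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Kind::Pointer => "Pointer",
            Kind::Float32 => "Float@float",
        };
        f.write_str(s)
    }
}

/// One entry of a type tree: a path of offsets and the kind found there.
#[derive(Debug, Clone, PartialEq)]
struct Entry {
    path: Vec<i64>,
    kind: Kind,
}

/// A type tree as a flat list of entries (assumed representation).
#[derive(Debug, Clone, PartialEq)]
struct TypeTree(Vec<Entry>);

impl fmt::Display for TypeTree {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render each entry as `[path]:Kind`, then wrap the list in braces.
        let entries: Vec<String> = self
            .0
            .iter()
            .map(|e| {
                let path: Vec<String> = e.path.iter().map(|i| i.to_string()).collect();
                format!("[{}]:{}", path.join(","), e.kind)
            })
            .collect();
        write!(f, "{{{}}}", entries.join(", "))
    }
}

/// Builds the tree discussed above for a `&f32` argument.
fn ref_f32_tree() -> TypeTree {
    TypeTree(vec![
        Entry { path: vec![-1], kind: Kind::Pointer },
        Entry { path: vec![-1, 0], kind: Kind::Float32 },
    ])
}
```

Calling `ref_f32_tree().to_string()` yields the same string as the analysis output for `%x`, which shows that each entry is simply a path paired with a concrete kind.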

A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`:
```llvm
define internal void @callee(i8* %x) {
start:
  %x.dbg.spill = bitcast i8* %x to double*
  ; ...
  ret void
}
```

And the type tree would be:

```llvm
i8* %x: {[-1]:Pointer, [-1,0]:Float@double}
```
The key difference is `@double`, indicating a double-precision float.

This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision.
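
The relationship between the Rust float type and the subtype in the tree can be sketched as follows. The trait `FloatSubtype` and the functions `ref_tree` and `pointee_size` are hypothetical helpers invented for this illustration; they only reproduce the strings shown above and do not call into Enzyme.

```rust
/// Hypothetical helper trait: maps a Rust float type to the name that
/// appears after `Float@` in the printed type tree.
trait FloatSubtype {
    const NAME: &'static str;
}

impl FloatSubtype for f32 {
    const NAME: &'static str = "float";
}

impl FloatSubtype for f64 {
    const NAME: &'static str = "double";
}

/// Type tree string for a reference `&F`, in the format shown above.
fn ref_tree<F: FloatSubtype>() -> String {
    format!("{{[-1]:Pointer, [-1,0]:Float@{}}}", F::NAME)
}

/// Number of bytes occupied by the pointee.
fn pointee_size<F: FloatSubtype>() -> usize {
    std::mem::size_of::<F>()
}
```

Here `ref_tree::<f32>()` and `ref_tree::<f64>()` differ only in the subtype name, while `pointee_size` reports the 4-byte versus 8-byte width that distinguishes the two precisions.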

### Compound Types

#### Structs

Consider a Rust struct `T` with two `f32` fields, passed by reference as `&T`:

```rust
struct T {
    x: f32,
    y: f32,
}

// And a function taking a reference to it:
// fn callee(t: &T) { /* ... */ }
```

In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`:

```llvm
define internal void @callee(i8* %t) {
start:
  %t.dbg.spill = bitcast i8* %t to { float, float }*
  ; ...
  ret void
}
```

The Enzyme type analysis output for `%t` would be:

```llvm
i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
```

**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`**

* **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer.
* **`[-1,0]:Float@float`**:
    * This describes the first field of the struct (`x`).
    * `[-1,0]` means: from the memory location pointed to by `%t` (`-1`), at offset `0` bytes.
    * `Float@float` indicates this field is an `f32`.
* **`[-1,4]:Float@float`**:
    * This describes the second field of the struct (`y`).
    * `[-1,4]` means: from the memory location pointed to by `%t` (`-1`), at offset `4` bytes.
    * `Float@float` indicates this field is also an `f32`.

The offset `4` follows from the size of the first field (an `f32` occupies 4 bytes). If the first field were an `f64` (8 bytes) instead, the second field would typically sit at `[-1,8]`, assuming the fields are laid out in declaration order. Enzyme uses these byte offsets to pinpoint the exact memory location of each field within the struct.
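
The offsets can be checked directly in Rust. The sketch below is an illustration under one added assumption: both structs are marked `#[repr(C)]` so that fields are laid out in declaration order, which Rust's default representation does not guarantee. The struct `U` and the helpers `offset_of_field`, `t_offsets`, `u_offsets`, and `t_tree` are hypothetical names introduced for this example.

```rust
/// `#[repr(C)]` is an assumption of this sketch: it fixes the field order.
#[repr(C)]
struct T {
    x: f32,
    y: f32,
}

/// Hypothetical variant whose first field is an `f64`.
#[repr(C)]
struct U {
    x: f64,
    y: f32,
}

/// Byte offset of `field` relative to the start of `base`.
fn offset_of_field<S, F>(base: &S, field: &F) -> usize {
    (field as *const F as usize) - (base as *const S as usize)
}

/// Offsets of `x` and `y` inside `T`.
fn t_offsets() -> (usize, usize) {
    let t = T { x: 0.0, y: 0.0 };
    (offset_of_field(&t, &t.x), offset_of_field(&t, &t.y))
}

/// Offsets of `x` and `y` inside `U`.
fn u_offsets() -> (usize, usize) {
    let u = U { x: 0.0, y: 0.0 };
    (offset_of_field(&u, &u.x), offset_of_field(&u, &u.y))
}

/// Type tree string for `&T`, built from the measured offsets.
fn t_tree() -> String {
    let (x, y) = t_offsets();
    format!(
        "{{[-1]:Pointer, [-1,{}]:Float@float, [-1,{}]:Float@float}}",
        x, y
    )
}
```

With this layout, `t_offsets()` gives `(0, 4)` and `u_offsets()` gives `(0, 8)`, matching the offsets discussed above, and `t_tree()` reproduces the analysis output for `%t`.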

This detailed mapping is crucial for Enzyme to correctly track the activity of individual struct fields during automatic differentiation.
