Skip to content

Commit e3a1169

Browse files
committed
updating the docs
1 parent 4eddd90 commit e3a1169

File tree

1 file changed

+125
-80
lines changed

1 file changed

+125
-80
lines changed

src/autodiff/type-trees.md

Lines changed: 125 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,163 @@
1-
# Type Trees in Enzyme
1+
# TypeTrees for Autodiff
22

3-
This document describes type trees as used by Enzyme for automatic differentiation.
3+
## What are TypeTrees?
4+
Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently.
45

5-
## What are Type Trees?
6-
7-
Type trees in Enzyme are a way to represent the types of variables, including their activity (e.g., whether they are active, duplicated, or contain duplicated data) for automatic differentiation. They provide a structured way for Enzyme to understand how to handle different data types during the differentiation process.
6+
## Structure
7+
```rust
8+
TypeTree(Vec<Type>)
89

9-
## Representing Rust Types as Type Trees
10+
Type {
11+
offset: isize, // byte offset (-1 = everywhere)
12+
size: usize, // size in bytes
13+
kind: Kind, // Float, Integer, Pointer, etc.
14+
child: TypeTree // nested structure
15+
}
16+
```
1017

11-
Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types.
18+
## Example: `fn compute(x: &f32, data: &[f32]) -> f32`
1219

13-
The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system.
20+
**Input 0: `x: &f32`**
21+
```rust
22+
TypeTree(vec![Type {
23+
offset: -1, size: 8, kind: Pointer,
24+
child: TypeTree(vec![Type {
25+
offset: -1, size: 4, kind: Float,
26+
child: TypeTree::new()
27+
}])
28+
}])
29+
```
1430

15-
### Primitive Types
31+
**Input 1: `data: &[f32]`**
32+
```rust
33+
TypeTree(vec![Type {
34+
offset: -1, size: 8, kind: Pointer,
35+
child: TypeTree(vec![Type {
36+
offset: -1, size: 4, kind: Float, // -1 = all elements
37+
child: TypeTree::new()
38+
}])
39+
}])
40+
```
1641

17-
#### Floating-Point Types (`f32`, `f64`)
42+
**Output: `f32`**
43+
```rust
44+
TypeTree(vec![Type {
45+
offset: -1, size: 4, kind: Float,
46+
child: TypeTree::new()
47+
}])
48+
```
1849

19-
Consider a Rust reference to a 32-bit floating-point number, `&f32`.
50+
## Why Needed?
51+
- Enzyme can't deduce complex type layouts from LLVM IR
52+
- Prevents slow memory pattern analysis
53+
- Enables correct derivative computation for nested structures
54+
- Tells Enzyme which bytes are differentiable vs metadata
2055

21-
In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function:
56+
## What Enzyme Does With This Information:
2257

58+
Without TypeTrees:
2359
```llvm
24-
define internal void @callee(i8* %x) {
25-
start:
26-
%x.dbg.spill = bitcast i8* %x to float*
27-
; ...
28-
ret void
60+
; Enzyme sees generic LLVM IR:
61+
define float @distance(i8* %p1, i8* %p2) {
62+
; Has to guess what these pointers point to
63+
; Slow analysis of all memory operations
64+
; May miss optimization opportunities
2965
}
3066
```
3167

32-
When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast:
33-
68+
With TypeTrees:
3469
```llvm
35-
i8* %x: {[-1]:Pointer, [-1,0]:Float@float}
36-
%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float}
70+
define "enzyme_type"="{[]:Float@float}" float @distance(
71+
ptr "enzyme_type"="{[]:Pointer}" %p1,
72+
ptr "enzyme_type"="{[]:Pointer}" %p2
73+
) {
74+
; Enzyme knows exact type layout
75+
; Can generate efficient derivative code directly
76+
}
3777
```
3878

39-
**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`**
40-
41-
This string is the type tree representation. Let's break it down:
79+
# TypeTrees - Offset and -1 Explained
4280

43-
* **`{ ... }`**: This encloses the set of type information for the variable.
44-
* **`[-1]:Pointer`**:
45-
* `[-1]` is an index or path. In this context, `-1` often refers to the base memory location or the immediate value pointed to.
46-
* `Pointer` indicates that the variable `%x` itself is treated as a pointer.
47-
* **`[-1,0]:Float@float`**:
48-
* `[-1,0]` is a path. It means: start with the base item `[-1]` (the pointer), and then look at offset `0` from the memory location it points to.
49-
* `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number.
50-
* `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`).
81+
## Type Structure
5182

52-
A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`:
53-
```llvm
54-
define internal void @callee(i8* %x) {
55-
start:
56-
%x.dbg.spill = bitcast i8* %x to double*
57-
; ...
58-
ret void
83+
```rust
84+
Type {
85+
offset: isize, // WHERE this type starts
86+
size: usize, // HOW BIG this type is
87+
kind: Kind, // WHAT KIND of data (Float, Int, Pointer)
88+
child: TypeTree // WHAT'S INSIDE (for pointers/containers)
5989
}
6090
```
6191

62-
And the type tree would be:
92+
## Offset Values
6393

64-
```llvm
65-
i8* %x: {[-1]:Pointer, [-1,0]:Float@double}
66-
```
67-
The key difference is `@double`, indicating a double-precision float.
68-
69-
This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision.
70-
71-
### Compound Types
72-
73-
#### Structs
74-
75-
Consider a Rust struct `T` with two `f32` fields (e.g., a reference `&T`):
94+
### Regular Offset (0, 4, 8, etc.)
95+
**Specific byte position within a structure**
7696

7797
```rust
78-
struct T {
79-
x: f32,
80-
y: f32,
98+
struct Point {
99+
x: f32, // offset 0, size 4
100+
y: f32, // offset 4, size 4
101+
id: i32, // offset 8, size 4
81102
}
82-
83-
// And a function taking a reference to it:
84-
// fn callee(t: &T) { /* ... */ }
85103
```
86104

87-
In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`:
105+
TypeTree for `&Point` (internal representation):
106+
```rust
107+
TypeTree(vec![
108+
Type { offset: 0, size: 4, kind: Float }, // x at byte 0
109+
Type { offset: 4, size: 4, kind: Float }, // y at byte 4
110+
Type { offset: 8, size: 4, kind: Integer } // id at byte 8
111+
])
112+
```
88113

114+
Generates LLVM:
89115
```llvm
90-
define internal void @callee(i8* %t) {
91-
start:
92-
%t.dbg.spill = bitcast i8* %t to { float, float }*
93-
; ...
94-
ret void
95-
}
116+
"enzyme_type"="{[]:Float@float}"
96117
```
97118

98-
The Enzyme type analysis output for `%t` would be:
119+
### Offset -1 (Special: "Everywhere")
120+
**Means "this pattern repeats for ALL elements"**
99121

100-
```llvm
101-
i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}
122+
#### Example 1: Array `[f32; 100]`
123+
```rust
124+
TypeTree(vec![Type {
125+
offset: -1, // ALL positions
126+
size: 4, // each f32 is 4 bytes
127+
kind: Float, // every element is float
128+
}])
102129
```
103130

104-
**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`**
131+
Instead of listing 100 separate Types with offsets `0,4,8,12...396`
105132

106-
* **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer.
107-
* **`[-1,0]:Float@float`**:
108-
* This describes the first field of the struct (`x`).
109-
* `[-1,0]` means: from the memory location pointed to by `%t` (`-1`), at offset `0` bytes.
110-
* `Float@float` indicates this field is an `f32`.
111-
* **`[-1,4]:Float@float`**:
112-
* This describes the second field of the struct (`y`).
113-
* `[-1,4]` means: from the memory location pointed to by `%t` (`-1`), at offset `4` bytes.
114-
* `Float@float` indicates this field is also an `f32`.
133+
#### Example 2: Slice `&[i32]`
134+
```rust
135+
// Pointer to slice data
136+
TypeTree(vec![Type {
137+
offset: -1, size: 8, kind: Pointer,
138+
child: TypeTree(vec![Type {
139+
offset: -1, // ALL slice elements
140+
size: 4, // each i32 is 4 bytes
141+
kind: Integer
142+
}])
143+
}])
144+
```
115145

116-
The offset `4` comes from the size of the first field (`f32` is 4 bytes). If the first field were, for example, an `f64` (8 bytes), the second field might be at offset `[-1,8]`. Enzyme uses these offsets to pinpoint the exact memory location of each field within the struct.
146+
#### Example 3: Mixed Structure
147+
```rust
148+
struct Container {
149+
header: i64, // offset 0
150+
data: [f32; 1000], // offset 8, but elements use -1
151+
}
152+
```
117153

118-
This detailed mapping is crucial for Enzyme to correctly track the activity of individual struct fields during automatic differentiation.
154+
```rust
155+
TypeTree(vec![
156+
Type { offset: 0, size: 8, kind: Integer }, // header
157+
Type { offset: 8, size: 4000, kind: Pointer,
158+
child: TypeTree(vec![Type {
159+
offset: -1, size: 4, kind: Float // ALL array elements
160+
}])
161+
}
162+
])
163+
```

0 commit comments

Comments
 (0)