|  | 
| 1 |  | -# Type Trees in Enzyme | 
|  | 1 | +# TypeTrees for Autodiff | 
| 2 | 2 | 
 | 
| 3 |  | -This document describes type trees as used by Enzyme for automatic differentiation. | 
|  | 3 | +## What are TypeTrees? | 
|  | 4 | +Memory layout descriptors for Enzyme. Tell Enzyme exactly how types are structured in memory so it can compute derivatives efficiently. | 
| 4 | 5 | 
 | 
| 5 |  | -## What are Type Trees? | 
| 6 |  | - | 
| 7 |  | -Type trees in Enzyme are a way to represent the types of variables, including their activity (e.g., whether they are active, duplicated, or contain duplicated data) for automatic differentiation. They provide a structured way for Enzyme to understand how to handle different data types during the differentiation process. | 
|  | 6 | +## Structure | 
|  | 7 | +```rust | 
|  | 8 | +TypeTree(Vec<Type>) | 
| 8 | 9 | 
 | 
| 9 |  | -## Representing Rust Types as Type Trees | 
|  | 10 | +Type { | 
|  | 11 | +    offset: isize,  // byte offset (-1 = everywhere) | 
|  | 12 | +    size: usize,    // size in bytes | 
|  | 13 | +    kind: Kind,     // Float, Integer, Pointer, etc. | 
|  | 14 | +    child: TypeTree // nested structure | 
|  | 15 | +} | 
|  | 16 | +``` | 
| 10 | 17 | 
 | 
| 11 |  | -Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types. | 
|  | 18 | +## Example: `fn compute(x: &f32, data: &[f32]) -> f32` | 
| 12 | 19 | 
 | 
| 13 |  | -The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system. | 
|  | 20 | +**Input 0: `x: &f32`** | 
|  | 21 | +```rust | 
|  | 22 | +TypeTree(vec![Type { | 
|  | 23 | +    offset: -1, size: 8, kind: Pointer, | 
|  | 24 | +    child: TypeTree(vec![Type { | 
|  | 25 | +        offset: -1, size: 4, kind: Float, | 
|  | 26 | +        child: TypeTree::new() | 
|  | 27 | +    }]) | 
|  | 28 | +}]) | 
|  | 29 | +``` | 
| 14 | 30 | 
 | 
| 15 |  | -### Primitive Types | 
|  | 31 | +**Input 1: `data: &[f32]`** | 
|  | 32 | +```rust | 
|  | 33 | +TypeTree(vec![Type { | 
|  | 34 | +    offset: -1, size: 8, kind: Pointer, | 
|  | 35 | +    child: TypeTree(vec![Type { | 
|  | 36 | +        offset: -1, size: 4, kind: Float,  // -1 = all elements | 
|  | 37 | +        child: TypeTree::new() | 
|  | 38 | +    }]) | 
|  | 39 | +}]) | 
|  | 40 | +``` | 
| 16 | 41 | 
 | 
| 17 |  | -#### Floating-Point Types (`f32`, `f64`) | 
|  | 42 | +**Output: `f32`** | 
|  | 43 | +```rust | 
|  | 44 | +TypeTree(vec![Type { | 
|  | 45 | +    offset: -1, size: 4, kind: Float, | 
|  | 46 | +    child: TypeTree::new() | 
|  | 47 | +}]) | 
|  | 48 | +``` | 
| 18 | 49 | 
 | 
| 19 |  | -Consider a Rust reference to a 32-bit floating-point number, `&f32`. | 
|  | 50 | +## Why Needed? | 
|  | 51 | +- Enzyme can't deduce complex type layouts from LLVM IR | 
|  | 52 | +- Prevents slow memory pattern analysis | 
|  | 53 | +- Enables correct derivative computation for nested structures | 
|  | 54 | +- Tells Enzyme which bytes are differentiable vs metadata | 
| 20 | 55 | 
 | 
| 21 |  | -In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function: | 
|  | 56 | +## What Enzyme Does With This Information: | 
| 22 | 57 | 
 | 
|  | 58 | +Without TypeTrees: | 
| 23 | 59 | ```llvm | 
| 24 |  | -define internal void @callee(i8* %x) { | 
| 25 |  | -start: | 
| 26 |  | -  %x.dbg.spill = bitcast i8* %x to float* | 
| 27 |  | -  ; ... | 
| 28 |  | -  ret void | 
|  | 60 | +; Enzyme sees generic LLVM IR: | 
|  | 61 | +define float @distance(i8* %p1, i8* %p2) { | 
|  | 62 | +; Has to guess what these pointers point to | 
|  | 63 | +; Slow analysis of all memory operations | 
|  | 64 | +; May miss optimization opportunities | 
| 29 | 65 | } | 
| 30 | 66 | ``` | 
| 31 | 67 | 
 | 
| 32 |  | -When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast: | 
| 33 |  | - | 
|  | 68 | +With TypeTrees: | 
| 34 | 69 | ```llvm | 
| 35 |  | -i8* %x: {[-1]:Pointer, [-1,0]:Float@float} | 
| 36 |  | -%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float} | 
|  | 70 | +define "enzyme_type"="{[]:Float@float}" float @distance( | 
|  | 71 | +    ptr "enzyme_type"="{[]:Pointer}" %p1,  | 
|  | 72 | +    ptr "enzyme_type"="{[]:Pointer}" %p2 | 
|  | 73 | +) { | 
|  | 74 | +; Enzyme knows exact type layout | 
|  | 75 | +; Can generate efficient derivative code directly | 
|  | 76 | +} | 
| 37 | 77 | ``` | 
| 38 | 78 | 
 | 
| 39 |  | -**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`** | 
| 40 |  | - | 
| 41 |  | -This string is the type tree representation. Let's break it down: | 
|  | 79 | +# TypeTrees - Offset and -1 Explained | 
| 42 | 80 | 
 | 
| 43 |  | -*   **`{ ... }`**: This encloses the set of type information for the variable. | 
| 44 |  | -*   **`[-1]:Pointer`**: | 
| 45 |  | -    *   `[-1]` is an index or path. In this context, `-1` often refers to the base memory location or the immediate value pointed to. | 
| 46 |  | -    *   `Pointer` indicates that the variable `%x` itself is treated as a pointer. | 
| 47 |  | -*   **`[-1,0]:Float@float`**: | 
| 48 |  | -    *   `[-1,0]` is a path. It means: start with the base item `[-1]` (the pointer), and then look at offset `0` from the memory location it points to. | 
| 49 |  | -    *   `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number. | 
| 50 |  | -    *   `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`). | 
|  | 81 | +## Type Structure | 
| 51 | 82 | 
 | 
| 52 |  | -A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`: | 
| 53 |  | -```llvm | 
| 54 |  | -define internal void @callee(i8* %x) { | 
| 55 |  | -start: | 
| 56 |  | -  %x.dbg.spill = bitcast i8* %x to double* | 
| 57 |  | -  ; ... | 
| 58 |  | -  ret void | 
|  | 83 | +```rust | 
|  | 84 | +Type { | 
|  | 85 | +    offset: isize, // WHERE this type starts | 
|  | 86 | +    size: usize,   // HOW BIG this type is | 
|  | 87 | +    kind: Kind,    // WHAT KIND of data (Float, Int, Pointer) | 
|  | 88 | +    child: TypeTree // WHAT'S INSIDE (for pointers/containers) | 
| 59 | 89 | } | 
| 60 | 90 | ``` | 
| 61 | 91 | 
 | 
| 62 |  | -And the type tree would be: | 
|  | 92 | +## Offset Values | 
| 63 | 93 | 
 | 
| 64 |  | -```llvm | 
| 65 |  | -i8* %x: {[-1]:Pointer, [-1,0]:Float@double} | 
| 66 |  | -``` | 
| 67 |  | -The key difference is `@double`, indicating a double-precision float. | 
| 68 |  | - | 
| 69 |  | -This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision. | 
| 70 |  | - | 
| 71 |  | -### Compound Types | 
| 72 |  | - | 
| 73 |  | -#### Structs | 
| 74 |  | - | 
| 75 |  | -Consider a Rust struct `T` with two `f32` fields (e.g., a reference `&T`): | 
|  | 94 | +### Regular Offset (0, 4, 8, etc.) | 
|  | 95 | +**Specific byte position within a structure** | 
| 76 | 96 | 
 | 
| 77 | 97 | ```rust | 
| 78 |  | -struct T { | 
| 79 |  | -    x: f32, | 
| 80 |  | -    y: f32, | 
|  | 98 | +struct Point { | 
|  | 99 | +    x: f32, // offset 0, size 4 | 
|  | 100 | +    y: f32, // offset 4, size 4 | 
|  | 101 | +    id: i32, // offset 8, size 4 | 
| 81 | 102 | } | 
| 82 |  | - | 
| 83 |  | -// And a function taking a reference to it: | 
| 84 |  | -// fn callee(t: &T) { /* ... */ } | 
| 85 | 103 | ``` | 
| 86 | 104 | 
 | 
| 87 |  | -In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`: | 
|  | 105 | +TypeTree for `&Point` (internal representation): | 
|  | 106 | +```rust | 
|  | 107 | +TypeTree(vec![ | 
|  | 108 | +    Type { offset: 0, size: 4, kind: Float },   // x at byte 0 | 
|  | 109 | +    Type { offset: 4, size: 4, kind: Float },   // y at byte 4 | 
|  | 110 | +    Type { offset: 8, size: 4, kind: Integer }  // id at byte 8 | 
|  | 111 | +]) | 
|  | 112 | +``` | 
| 88 | 113 | 
 | 
|  | 114 | +Generates LLVM: | 
| 89 | 115 | ```llvm | 
| 90 |  | -define internal void @callee(i8* %t) { | 
| 91 |  | -start: | 
| 92 |  | -  %t.dbg.spill = bitcast i8* %t to { float, float }* | 
| 93 |  | -  ; ... | 
| 94 |  | -  ret void | 
| 95 |  | -} | 
|  | 116 | +"enzyme_type"="{[]:Float@float}" | 
| 96 | 117 | ``` | 
| 97 | 118 | 
 | 
| 98 |  | -The Enzyme type analysis output for `%t` would be: | 
|  | 119 | +### Offset -1 (Special: "Everywhere") | 
|  | 120 | +**Means "this pattern repeats for ALL elements"** | 
| 99 | 121 | 
 | 
| 100 |  | -```llvm | 
| 101 |  | -i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float} | 
|  | 122 | +#### Example 1: Array `[f32; 100]` | 
|  | 123 | +```rust | 
|  | 124 | +TypeTree(vec![Type { | 
|  | 125 | +    offset: -1, // ALL positions | 
|  | 126 | +    size: 4,    // each f32 is 4 bytes | 
|  | 127 | +    kind: Float, // every element is float | 
|  | 128 | +}]) | 
| 102 | 129 | ``` | 
| 103 | 130 | 
 | 
| 104 |  | -**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`** | 
|  | 131 | +Instead of listing 100 separate Types with offsets `0,4,8,12...396` | 
| 105 | 132 | 
 | 
| 106 |  | -*   **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer. | 
| 107 |  | -*   **`[-1,0]:Float@float`**: | 
| 108 |  | -    *   This describes the first field of the struct (`x`). | 
| 109 |  | -    *   `[-1,0]` means: from the memory location pointed to by `%t` (`-1`), at offset `0` bytes. | 
| 110 |  | -    *   `Float@float` indicates this field is an `f32`. | 
| 111 |  | -*   **`[-1,4]:Float@float`**: | 
| 112 |  | -    *   This describes the second field of the struct (`y`). | 
| 113 |  | -    *   `[-1,4]` means: from the memory location pointed to by `%t` (`-1`), at offset `4` bytes. | 
| 114 |  | -    *   `Float@float` indicates this field is also an `f32`. | 
|  | 133 | +#### Example 2: Slice `&[i32]` | 
|  | 134 | +```rust | 
|  | 135 | +// Pointer to slice data | 
|  | 136 | +TypeTree(vec![Type { | 
|  | 137 | +    offset: -1, size: 8, kind: Pointer, | 
|  | 138 | +    child: TypeTree(vec![Type { | 
|  | 139 | +        offset: -1, // ALL slice elements | 
|  | 140 | +        size: 4,    // each i32 is 4 bytes | 
|  | 141 | +        kind: Integer | 
|  | 142 | +    }]) | 
|  | 143 | +}]) | 
|  | 144 | +``` | 
| 115 | 145 | 
 | 
| 116 |  | -The offset `4` comes from the size of the first field (`f32` is 4 bytes). If the first field were, for example, an `f64` (8 bytes), the second field might be at offset `[-1,8]`. Enzyme uses these offsets to pinpoint the exact memory location of each field within the struct. | 
|  | 146 | +#### Example 3: Mixed Structure | 
|  | 147 | +```rust | 
|  | 148 | +struct Container { | 
|  | 149 | +    header: i64,        // offset 0 | 
|  | 150 | +    data: [f32; 1000],  // offset 8, but elements use -1 | 
|  | 151 | +} | 
|  | 152 | +``` | 
| 117 | 153 | 
 | 
| 118 |  | -This detailed mapping is crucial for Enzyme to correctly track the activity of individual struct fields during automatic differentiation. | 
|  | 154 | +```rust | 
|  | 155 | +TypeTree(vec![ | 
|  | 156 | +    Type { offset: 0, size: 8, kind: Integer }, // header | 
|  | 157 | +    Type { offset: 8, size: 4000, kind: Pointer, | 
|  | 158 | +        child: TypeTree(vec![Type { | 
|  | 159 | +            offset: -1, size: 4, kind: Float // ALL array elements | 
|  | 160 | +        }]) | 
|  | 161 | +    } | 
|  | 162 | +]) | 
|  | 163 | +``` | 
0 commit comments