| 1 | +# Type Trees in Enzyme |
| 2 | + |
| 3 | +This document describes type trees as used by Enzyme for automatic differentiation. |
| 4 | + |
| 5 | +## What are Type Trees? |
| 6 | + |
| 7 | +Type trees are Enzyme's way of describing how a value is laid out in memory: which byte offsets hold pointers, floats, or integers, and what those pointers point to. Enzyme uses this information to decide how each piece of data must be handled during differentiation. Activity (for example, whether an argument is active or duplicated) is specified separately. |
| 8 | + |
| 9 | +## Representing Rust Types as Type Trees |
| 10 | + |
| 11 | +Enzyme needs to understand the structure and properties of Rust types to perform automatic differentiation correctly. This is where type trees come in. They provide a detailed map of a type, including pointer indirections and the underlying concrete data types. |
| 12 | + |
| 13 | +The `-enzyme-rust-type` flag in Enzyme helps in interpreting types more accurately in the context of Rust's memory layout and type system. |
| 14 | + |
| 15 | +### Primitive Types |
| 16 | + |
| 17 | +#### Floating-Point Types (`f32`, `f64`) |
| 18 | + |
| 19 | +Consider a Rust reference to a 32-bit floating-point number, `&f32`. |
| 20 | + |
| 21 | +In LLVM IR, this might be represented, for instance, as an `i8*` (a generic byte pointer) that is then `bitcast` to a `float*`. Consider the following LLVM IR function: |
| 22 | + |
| 23 | +```llvm |
| 24 | +define internal void @callee(i8* %x) { |
| 25 | +start: |
| 26 | + %x.dbg.spill = bitcast i8* %x to float* |
| 27 | + ; ... |
| 28 | + ret void |
| 29 | +} |
| 30 | +``` |
| 31 | + |
| 32 | +When Enzyme analyzes this function (with appropriate flags like `-enzyme-rust-type`), it might produce the following type information for the argument `%x` and the result of the bitcast: |
| 33 | + |
| 34 | +```llvm |
| 35 | +i8* %x: {[-1]:Pointer, [-1,0]:Float@float} |
| 36 | +%x.dbg.spill = bitcast i8* %x to float*: {[-1]:Pointer, [-1,0]:Float@float} |
| 37 | +``` |
| 38 | + |
| 39 | +**Understanding the Type Tree: `{[-1]:Pointer, [-1,0]:Float@float}`** |
| 40 | + |
| 41 | +This string is the type tree representation. Let's break it down: |
| 42 | + |
| 43 | +* **`{ ... }`**: This encloses the set of type information for the variable. |
| 44 | +* **`[-1]:Pointer`**: |
| 45 | + * `[-1]` is the path to this entry. In Enzyme's notation, an index of `-1` is a wildcard meaning "every offset", so the entry applies to the value as a whole rather than to one particular byte. |
| 46 | + * `Pointer` indicates that the variable `%x` itself is treated as a pointer. |
| 47 | +* **`[-1,0]:Float@float`**: |
| 48 | + * `[-1,0]` is a two-step path: the first index (`-1`) selects the pointer itself, and the second index (`0`) is a byte offset into the memory it points to. |
| 49 | + * `Float` is the `CConcreteType` (from `enzyme_ffi.rs`, corresponding to `DT_Float`). It signifies that the data at this location is a floating-point number. |
| 50 | + * `@float` is a subtype or specific variant of `Float`. In this case, it specifies a single-precision float (like Rust's `f32`). |
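
The notation above can be modeled as a list of (path, type) entries. The following is a minimal sketch for illustration only: `Concrete`, `Entry`, `render`, and `float_ref_tree` are hypothetical names and are not part of Enzyme's API or of `enzyme_ffi.rs`.

```rust
use std::fmt;

/// Kinds of data an entry can describe (a small, illustrative subset).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Concrete {
    Pointer,
    Float32, // printed as Float@float
    Float64, // printed as Float@double
}

impl fmt::Display for Concrete {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            Concrete::Pointer => "Pointer",
            Concrete::Float32 => "Float@float",
            Concrete::Float64 => "Float@double",
        };
        f.write_str(s)
    }
}

/// One entry: a path of offsets plus the type found at that path.
struct Entry {
    path: Vec<i64>,
    ty: Concrete,
}

/// Render entries in the `{[path]:Type, ...}` notation shown above.
fn render(entries: &[Entry]) -> String {
    let parts: Vec<String> = entries
        .iter()
        .map(|e| {
            let path: Vec<String> = e.path.iter().map(|i| i.to_string()).collect();
            format!("[{}]:{}", path.join(","), e.ty)
        })
        .collect();
    format!("{{{}}}", parts.join(", "))
}

/// Entries for a reference to a single float of the given kind.
fn float_ref_tree(ty: Concrete) -> Vec<Entry> {
    vec![
        Entry { path: vec![-1], ty: Concrete::Pointer },
        Entry { path: vec![-1, 0], ty },
    ]
}
```

Calling `render(&float_ref_tree(Concrete::Float32))` reproduces the string printed for `%x`. The sketch only mirrors the printed notation; it does not model how Enzyme infers or merges type information.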
| 51 | + |
| 52 | +A reference to an `f64` (e.g., `&f64`) is handled very similarly. The LLVM IR might cast to `double*`: |
| 53 | +```llvm |
| 54 | +define internal void @callee(i8* %x) { |
| 55 | +start: |
| 56 | + %x.dbg.spill = bitcast i8* %x to double* |
| 57 | + ; ... |
| 58 | + ret void |
| 59 | +} |
| 60 | +``` |
| 61 | + |
| 62 | +And the type tree would be: |
| 63 | + |
| 64 | +```llvm |
| 65 | +i8* %x: {[-1]:Pointer, [-1,0]:Float@double} |
| 66 | +``` |
| 67 | +The key difference is `@double`, indicating a double-precision float. |
| 68 | + |
| 69 | +This level of detail allows Enzyme to know, for example, that if `x` is an active variable in differentiation, the floating-point value it points to needs to be handled according to AD rules for its specific precision. |
| 70 | + |
| 71 | +### Compound Types |
| 72 | + |
| 73 | +#### Structs |
| 74 | + |
| 75 | +Consider a Rust struct `T` with two `f32` fields, passed to a function by reference (`&T`): |
| 76 | + |
| 77 | +```rust |
| 78 | +struct T { |
| 79 | + x: f32, |
| 80 | + y: f32, |
| 81 | +} |
| 82 | + |
| 83 | +// And a function taking a reference to it: |
| 84 | +// fn callee(t: &T) { /* ... */ } |
| 85 | +``` |
| 86 | + |
| 87 | +In LLVM IR, a pointer to this struct might be initially represented as `i8*` and then cast to the specific struct type, like `{ float, float }*`: |
| 88 | + |
| 89 | +```llvm |
| 90 | +define internal void @callee(i8* %t) { |
| 91 | +start: |
| 92 | + %t.dbg.spill = bitcast i8* %t to { float, float }* |
| 93 | + ; ... |
| 94 | + ret void |
| 95 | +} |
| 96 | +``` |
| 97 | + |
| 98 | +The Enzyme type analysis output for `%t` would be: |
| 99 | + |
| 100 | +```llvm |
| 101 | +i8* %t: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float} |
| 102 | +``` |
| 103 | + |
| 104 | +**Understanding the Struct Type Tree: `{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float}`** |
| 105 | + |
| 106 | +* **`[-1]:Pointer`**: As before, this indicates that `%t` is a pointer. |
| 107 | +* **`[-1,0]:Float@float`**: |
| 108 | + * This describes the first field of the struct (`x`). |
| 109 | + * `[-1,0]` means: follow the pointer `%t` (the leading `-1`) and look at byte offset `0` of the memory it points to. |
| 110 | + * `Float@float` indicates this field is an `f32`. |
| 111 | +* **`[-1,4]:Float@float`**: |
| 112 | + * This describes the second field of the struct (`y`). |
| 113 | + * `[-1,4]` means: follow the pointer `%t` (the leading `-1`) and look at byte offset `4` of the memory it points to. |
| 114 | + * `Float@float` indicates this field is also an `f32`. |
| 115 | + |
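| 116 | +The offsets are byte offsets within the struct's memory layout. Here `y` sits at offset `4` because `x` occupies the first 4 bytes (`f32` is 4 bytes). If the first field were an `f64` (8 bytes), the second field would typically sit at offset `8`, giving `[-1,8]`. Note that Rust does not guarantee field order for structs without `#[repr(C)]`, so the offsets reflect the layout the compiler actually chose. Enzyme uses these offsets to pinpoint the exact memory location of each field within the struct. |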
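
To check where these offsets come from, the field positions can be measured directly in Rust. This is a small sketch under stated assumptions: `#[repr(C)]` is added here to fix the field order (the struct in this document does not carry it), and `U`, `offset_of_field`, `offsets_t`, and `offsets_u` are hypothetical names introduced for the example.

```rust
// `#[repr(C)]` is an assumption of this sketch: it pins the field order,
// which Rust's default layout does not guarantee.
#[repr(C)]
struct T {
    x: f32,
    y: f32,
}

// Hypothetical variant whose first field is an `f64`.
#[repr(C)]
struct U {
    x: f64,
    y: f32,
}

/// Byte offset of `field` from the start of `base`.
/// `field` must be a reference to a field stored inside `base`.
fn offset_of_field<S, F>(base: &S, field: &F) -> usize {
    (field as *const F as usize) - (base as *const S as usize)
}

/// Offsets of `x` and `y` inside `T`.
fn offsets_t() -> (usize, usize) {
    let t = T { x: 0.0, y: 0.0 };
    (offset_of_field(&t, &t.x), offset_of_field(&t, &t.y))
}

/// Offsets of `x` and `y` inside `U`.
fn offsets_u() -> (usize, usize) {
    let u = U { x: 0.0, y: 0.0 };
    (offset_of_field(&u, &u.x), offset_of_field(&u, &u.y))
}
```

With this layout, `offsets_t()` gives the offsets used in the paths `[-1,0]` and `[-1,4]`, and `offsets_u()` shows the second field moving to offset 8 once the first field is an `f64`.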
| 117 | + |
| 118 | +This detailed mapping tells Enzyme exactly which bytes of the struct hold floating-point data, which it needs in order to differentiate each field correctly. |