-
Notifications
You must be signed in to change notification settings - Fork 12.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
only include rustc_codegen_llvm autodiff changes
- Loading branch information
Showing
24 changed files
with
2,283 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, | ||
/// we create an `AutoDiffItem` which contains the source and target function names. The source | ||
/// is the function to which the autodiff attribute is applied, and the target is the function | ||
/// getting generated by us (with a name given by the user as the first autodiff arg). | ||
use crate::expand::typetree::TypeTree; | ||
use crate::expand::{Decodable, Encodable, HashStable_Generic}; | ||
|
||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum DiffMode { | ||
/// No autodiff is applied (usually used during error handling). | ||
Inactive, | ||
/// The primal function which we will differentiate. | ||
Source, | ||
/// The target function, to be created using forward mode AD. | ||
Forward, | ||
/// The target function, to be created using reverse mode AD. | ||
Reverse, | ||
/// The target function, to be created using forward mode AD. | ||
/// This target function will also be used as a source for higher order derivatives, | ||
/// so compute it before all Forward/Reverse targets and optimize it through llvm. | ||
ForwardFirst, | ||
/// The target function, to be created using reverse mode AD. | ||
/// This target function will also be used as a source for higher order derivatives, | ||
/// so compute it before all Forward/Reverse targets and optimize it through llvm. | ||
ReverseFirst, | ||
} | ||
|
||
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. | ||
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode | ||
/// we add to the previous shadow value. To not surprise users, we picked different names. | ||
/// Dual numbers is also a quite well known name for forward mode AD types. | ||
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum DiffActivity { | ||
/// Implicit or Explicit () return type, so a special case of Const. | ||
None, | ||
/// Don't compute derivatives with respect to this input/output. | ||
Const, | ||
/// Reverse Mode, Compute derivatives for this scalar input/output. | ||
Active, | ||
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute | ||
/// the original return value. | ||
ActiveOnly, | ||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument | ||
/// with it. | ||
Dual, | ||
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument | ||
/// with it. Drop the code which updates the original input/output for maximum performance. | ||
DualOnly, | ||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. | ||
Duplicated, | ||
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. | ||
/// Drop the code which updates the original input for maximum performance. | ||
DuplicatedOnly, | ||
/// All Integers must be Const, but these are used to mark the integer which represents the | ||
/// length of a slice/vec. This is used for safety checks on slices. | ||
FakeActivitySize, | ||
} | ||
/// We generate one of these structs for each `#[autodiff(...)]` attribute. | ||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct AutoDiffItem { | ||
/// The name of the function getting differentiated | ||
pub source: String, | ||
/// The name of the function being generated | ||
pub target: String, | ||
pub attrs: AutoDiffAttrs, | ||
/// Despribe the memory layout of input types | ||
pub inputs: Vec<TypeTree>, | ||
/// Despribe the memory layout of the output type | ||
pub output: TypeTree, | ||
} | ||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct AutoDiffAttrs { | ||
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and | ||
/// e.g. in the [JAX | ||
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). | ||
pub mode: DiffMode, | ||
pub ret_activity: DiffActivity, | ||
pub input_activity: Vec<DiffActivity>, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use std::fmt; | ||
|
||
use crate::expand::{Decodable, Encodable, HashStable_Generic}; | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub enum Kind { | ||
Anything, | ||
Integer, | ||
Pointer, | ||
Half, | ||
Float, | ||
Double, | ||
Unknown, | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct TypeTree(pub Vec<Type>); | ||
|
||
impl TypeTree { | ||
pub fn new() -> Self { | ||
Self(Vec::new()) | ||
} | ||
pub fn all_ints() -> Self { | ||
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }]) | ||
} | ||
pub fn int(size: usize) -> Self { | ||
let mut ints = Vec::with_capacity(size); | ||
for i in 0..size { | ||
ints.push(Type { | ||
offset: i as isize, | ||
size: 1, | ||
kind: Kind::Integer, | ||
child: TypeTree::new(), | ||
}); | ||
} | ||
Self(ints) | ||
} | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct FncTree { | ||
pub args: Vec<TypeTree>, | ||
pub ret: TypeTree, | ||
} | ||
|
||
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] | ||
pub struct Type { | ||
pub offset: isize, | ||
pub size: usize, | ||
pub kind: Kind, | ||
pub child: TypeTree, | ||
} | ||
|
||
impl Type { | ||
pub fn add_offset(self, add: isize) -> Self { | ||
let offset = match self.offset { | ||
-1 => add, | ||
x => add + x, | ||
}; | ||
|
||
Self { size: self.size, kind: self.kind, child: self.child, offset } | ||
} | ||
} | ||
|
||
impl fmt::Display for Type { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
<Self as fmt::Debug>::fmt(self, f) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.