diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs new file mode 100644 index 0000000000000..05714731b9d4d --- /dev/null +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -0,0 +1,283 @@ +//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, +//! we create an [`AutoDiffItem`] which contains the source and target function names. The source +//! is the function to which the autodiff attribute is applied, and the target is the function +//! getting generated by us (with a name given by the user as the first autodiff arg). + +use std::fmt::{self, Display, Formatter}; +use std::str::FromStr; + +use crate::expand::typetree::TypeTree; +use crate::expand::{Decodable, Encodable, HashStable_Generic}; +use crate::ptr::P; +use crate::{Ty, TyKind}; + +/// Forward and Reverse Mode are well known names for automatic differentiation implementations. +/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants +/// are a hack to support higher order derivatives. We need to compute first order derivatives +/// before we compute second order derivatives, otherwise we would differentiate our placeholder +/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations, +/// as it's already done in the C++ and Julia frontend of Enzyme. +/// +/// (FIXME) remove *First variants. +/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and +/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online. +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffMode { + /// No autodiff is applied (used during error handling). + Error, + /// The primal function which we will differentiate. + Source, + /// The target function, to be created using forward mode AD. + Forward, + /// The target function, to be created using reverse mode AD. + Reverse, + /// The target function, to be created using forward mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ForwardFirst, + /// The target function, to be created using reverse mode AD. + /// This target function will also be used as a source for higher order derivatives, + /// so compute it before all Forward/Reverse targets and optimize it through llvm. + ReverseFirst, +} + +/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. +/// However, under forward mode we overwrite the previous shadow value, while for reverse mode +/// we add to the previous shadow value. To not surprise users, we picked different names. +/// Dual numbers is also a quite well known name for forward mode AD types. +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum DiffActivity { + /// Implicit or Explicit () return type, so a special case of Const. + None, + /// Don't compute derivatives with respect to this input/output. + Const, + /// Reverse Mode, Compute derivatives for this scalar input/output. + Active, + /// Reverse Mode, Compute derivatives for this scalar output, but don't compute + /// the original return value. + ActiveOnly, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. + Dual, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. Drop the code which updates the original input/output for maximum performance. + DualOnly, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + Duplicated, + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. + /// Drop the code which updates the original input for maximum performance. + DuplicatedOnly, + /// All Integers must be Const, but these are used to mark the integer which represents the + /// length of a slice/vec. This is used for safety checks on slices. + FakeActivitySize, +} +/// We generate one of these structs for each `#[autodiff(...)]` attribute. +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + /// The name of the function getting differentiated + pub source: String, + /// The name of the function being generated + pub target: String, + pub attrs: AutoDiffAttrs, + /// Describe the memory layout of input types + pub inputs: Vec, + /// Describe the memory layout of the output type + pub output: TypeTree, +} +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffAttrs { + /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and + /// e.g. in the [JAX + /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl DiffMode { + pub fn is_rev(&self) -> bool { + matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst) + } + pub fn is_fwd(&self) -> bool { + matches!(self, DiffMode::Forward | DiffMode::ForwardFirst) + } +} + +impl Display for DiffMode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + DiffMode::Error => write!(f, "Error"), + DiffMode::Source => write!(f, "Source"), + DiffMode::Forward => write!(f, "Forward"), + DiffMode::Reverse => write!(f, "Reverse"), + DiffMode::ForwardFirst => write!(f, "ForwardFirst"), + DiffMode::ReverseFirst => write!(f, "ReverseFirst"), + } + } +} + +/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...). +/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...). +/// Const is valid for all cases and means that we don't compute derivatives wrt. this output. +/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg, +/// but this is too complex to verify here. Also it's just a logic error if users get this wrong. +pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { + if activity == DiffActivity::None { + // Only valid if primal returns (), but we can't check that here. + return true; + } + match mode { + DiffMode::Error => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + activity == DiffActivity::Dual + || activity == DiffActivity::DualOnly + || activity == DiffActivity::Const + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + activity == DiffActivity::Const + || activity == DiffActivity::Active + || activity == DiffActivity::ActiveOnly + } + } +} + +/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value +/// for the given argument, but we generally can't know the size of such a type. +/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated, +/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value +/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent +/// users here from marking scalars as Duplicated, due to type aliases. +pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { + use DiffActivity::*; + // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it. + if matches!(activity, Const) { + return true; + } + if matches!(activity, Dual | DualOnly) { + return true; + } + // FIXME(ZuseZ4) We should make this more robust to also + // handle type aliases. Once that is done, we can be more restrictive here. + if matches!(activity, Active | ActiveOnly) { + return true; + } + matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..)) + && matches!(activity, Duplicated | DuplicatedOnly) +} +pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { + use DiffActivity::*; + return match mode { + DiffMode::Error => false, + DiffMode::Source => false, + DiffMode::Forward | DiffMode::ForwardFirst => { + matches!(activity, Dual | DualOnly | Const) + } + DiffMode::Reverse | DiffMode::ReverseFirst => { + matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) + } + }; +} + +impl Display for DiffActivity { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DiffActivity::None => write!(f, "None"), + DiffActivity::Const => write!(f, "Const"), + DiffActivity::Active => write!(f, "Active"), + DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), + DiffActivity::Dual => write!(f, "Dual"), + DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::Duplicated => write!(f, "Duplicated"), + DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), + DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), + } + } +} + +impl FromStr for DiffMode { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "Error" => Ok(DiffMode::Error), + "Source" => Ok(DiffMode::Source), + "Forward" => Ok(DiffMode::Forward), + "Reverse" => Ok(DiffMode::Reverse), + "ForwardFirst" => Ok(DiffMode::ForwardFirst), + "ReverseFirst" => Ok(DiffMode::ReverseFirst), + _ => Err(()), + } + } +} +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "ActiveOnly" => Ok(DiffActivity::ActiveOnly), + "Const" => Ok(DiffActivity::Const), + "Dual" => Ok(DiffActivity::Dual), + "DualOnly" => Ok(DiffActivity::DualOnly), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), + _ => Err(()), + } + } +} + +impl AutoDiffAttrs { + pub fn has_ret_activity(&self) -> bool { + self.ret_activity != DiffActivity::None + } + pub fn has_active_only_ret(&self) -> bool { + self.ret_activity == DiffActivity::ActiveOnly + } + + pub fn error() -> Self { + AutoDiffAttrs { + mode: DiffMode::Error, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + pub fn source() -> Self { + AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + self.mode != DiffMode::Error + } + + pub fn is_source(&self) -> bool { + self.mode == DiffMode::Source + } + pub fn apply_autodiff(&self) -> bool { + !matches!(self.mode, DiffMode::Error | DiffMode::Source) + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +impl fmt::Display for AutoDiffItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Differentiating {} -> {}", self.source, self.target)?; + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) + } +} diff --git a/compiler/rustc_ast/src/expand/mod.rs b/compiler/rustc_ast/src/expand/mod.rs index 13413281bc7cc..d259677e98e3d 100644 --- a/compiler/rustc_ast/src/expand/mod.rs +++ b/compiler/rustc_ast/src/expand/mod.rs @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident; use crate::MetaItem; pub mod allocator; +pub mod autodiff_attrs; +pub mod typetree; #[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)] pub struct StrippedCfgItem { diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs new file mode 100644 index 0000000000000..9a2dd2e85e0d6 --- /dev/null +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -0,0 +1,90 @@ +//! This module contains the definition of the `TypeTree` and `Type` structs. +//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff +//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the +//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to +//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in +//! the user code, but this is extremely slow and not even always sufficient. As such we lower some +//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to +//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means +//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it +//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere, +//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there +//! will be only pointers, if you dereference these new pointers they will point to array of floats. +//! Generally, it allows byte-specific descriptions. +//! FIXME: This description might be partly inaccurate and should be extended, along with +//! adding documentation to the corresponding Enzyme core code. +//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to +//! provide typetree information. +//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR +//! representations of some types might not be accurate. For example a vector of floats might be +//! represented as a vector of u8s in MIR in some cases. + +use std::fmt; + +use crate::expand::{Decodable, Encodable, HashStable_Generic}; + +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct TypeTree(pub Vec); + +impl TypeTree { + pub fn new() -> Self { + Self(Vec::new()) + } + pub fn all_ints() -> Self { + Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }]) + } + pub fn int(size: usize) -> Self { + let mut ints = Vec::with_capacity(size); + for i in 0..size { + ints.push(Type { + offset: i as isize, + size: 1, + kind: Kind::Integer, + child: TypeTree::new(), + }); + } + Self(ints) + } +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct FncTree { + pub args: Vec, + pub ret: TypeTree, +} + +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_builtin_macros/Cargo.toml b/compiler/rustc_builtin_macros/Cargo.toml index 21b87be4b81d2..ef48486f6f150 100644 --- a/compiler/rustc_builtin_macros/Cargo.toml +++ b/compiler/rustc_builtin_macros/Cargo.toml @@ -3,6 +3,10 @@ name = "rustc_builtin_macros" version = "0.0.0" edition = "2021" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] } + [lib] doctest = false diff --git a/compiler/rustc_builtin_macros/messages.ftl b/compiler/rustc_builtin_macros/messages.ftl index 77cb8dc63c484..6ebc2fd870cca 100644 --- a/compiler/rustc_builtin_macros/messages.ftl +++ b/compiler/rustc_builtin_macros/messages.ftl @@ -69,6 +69,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as builtin_macros_assert_requires_expression = macro requires an expression as an argument .suggestion = try removing semicolon +builtin_macros_autodiff = autodiff must be applied to function +builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode +builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse` +builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode +builtin_macros_autodiff_not_build = this rustc version does not support autodiff +builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found} +builtin_macros_autodiff_ty_activity = {$act} can not be used for this type + +builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}` builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s .label = not applicable here .label2 = not a `struct`, `enum` or `union` diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs new file mode 100644 index 0000000000000..66bb11ca52281 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -0,0 +1,820 @@ +//! This module contains the implementation of the `#[autodiff]` attribute. +//! Currently our linter isn't smart enough to see that each import is used in one of the two +//! configs (autodiff enabled or disabled), so we have to add cfg's to each import. +//! FIXME(ZuseZ4): Remove this once we have a smarter linter. + +#[cfg(llvm_enzyme)] +mod llvm_enzyme { + use std::str::FromStr; + use std::string::String; + + use rustc_ast::expand::autodiff_attrs::{ + AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity, + }; + use rustc_ast::ptr::P; + use rustc_ast::token::{Token, TokenKind}; + use rustc_ast::tokenstream::*; + use rustc_ast::visit::AssocCtxt::*; + use rustc_ast::{ + self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner, + PatKind, TyKind, + }; + use rustc_expand::base::{Annotatable, ExtCtxt}; + use rustc_span::symbol::{Ident, kw, sym}; + use rustc_span::{Span, Symbol}; + use thin_vec::{ThinVec, thin_vec}; + use tracing::{debug, trace}; + + use crate::errors; + + // If we have a default `()` return type or explicitley `()` return type, + // then we often can skip doing some work. + fn has_ret(ty: &FnRetTy) -> bool { + match ty { + FnRetTy::Ty(ty) => !ty.kind.is_unit(), + FnRetTy::Default(_) => false, + } + } + fn first_ident(x: &MetaItemInner) -> rustc_span::symbol::Ident { + let segments = &x.meta_item().unwrap().path.segments; + assert!(segments.len() == 1); + segments[0].ident + } + + fn name(x: &MetaItemInner) -> String { + first_ident(x).name.to_string() + } + + pub(crate) fn from_ast( + ecx: &mut ExtCtxt<'_>, + meta_item: &ThinVec, + has_ret: bool, + ) -> AutoDiffAttrs { + let dcx = ecx.sess.dcx(); + let mode = name(&meta_item[1]); + let Ok(mode) = DiffMode::from_str(&mode) else { + dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode }); + return AutoDiffAttrs::error(); + }; + let mut activities: Vec = vec![]; + let mut errors = false; + for x in &meta_item[2..] { + let activity_str = name(&x); + let res = DiffActivity::from_str(&activity_str); + match res { + Ok(x) => activities.push(x), + Err(_) => { + dcx.emit_err(errors::AutoDiffUnknownActivity { + span: x.span(), + act: activity_str, + }); + errors = true; + } + }; + } + if errors { + return AutoDiffAttrs::error(); + } + + // If a return type exist, we need to split the last activity, + // otherwise we return None as placeholder. + let (ret_activity, input_activity) = if has_ret { + let Some((last, rest)) = activities.split_last() else { + unreachable!( + "should not be reachable because we counted the number of activities previously" + ); + }; + (last, rest) + } else { + (&DiffActivity::None, activities.as_slice()) + }; + + AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() } + } + + /// We expand the autodiff macro to generate a new placeholder function which passes + /// type-checking and can be called by users. The function body of the placeholder function will + /// later be replaced on LLVM-IR level, so the design of the body is less important and for now + /// should just prevent early inlining and optimizations which alter the function signature. + /// The exact signature of the generated function depends on the configuration provided by the + /// user, but here is an example: + /// + /// ``` + /// #[autodiff(cos_box, Reverse, Duplicated, Active)] + /// fn sin(x: &Box) -> f32 { + /// f32::sin(**x) + /// } + /// ``` + /// which becomes expanded to: + /// ``` + /// #[rustc_autodiff] + /// #[inline(never)] + /// fn sin(x: &Box) -> f32 { + /// f32::sin(**x) + /// } + /// #[rustc_autodiff(Reverse, Duplicated, Active)] + /// #[inline(never)] + /// fn cos_box(x: &Box, dx: &mut Box, dret: f32) -> f32 { + /// unsafe { + /// asm!("NOP"); + /// }; + /// ::core::hint::black_box(sin(x)); + /// ::core::hint::black_box((dx, dret)); + /// ::core::hint::black_box(sin(x)) + /// } + /// ``` + /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked + /// in CI. + pub(crate) fn expand( + ecx: &mut ExtCtxt<'_>, + expand_span: Span, + meta_item: &ast::MetaItem, + mut item: Annotatable, + ) -> Vec { + let dcx = ecx.sess.dcx(); + // first get the annotable item: + let (sig, is_impl): (FnSig, bool) = match &item { + Annotatable::Item(ref iitem) => { + let sig = match &iitem.kind { + ItemKind::Fn(box ast::Fn { sig, .. }) => sig, + _ => { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), false) + } + Annotatable::AssocItem(ref assoc_item, _) => { + let sig = match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig, + _ => { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + (sig.clone(), true) + } + _ => { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + let meta_item_vec: ThinVec = match meta_item.kind { + ast::MetaItemKind::List(ref vec) => vec.clone(), + _ => { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + let has_ret = has_ret(&sig.decl.output); + let sig_span = ecx.with_call_site_ctxt(sig.span); + + let (vis, primal) = match &item { + Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()), + Annotatable::AssocItem(ref assoc_item, _) => { + (assoc_item.vis.clone(), assoc_item.ident.clone()) + } + _ => { + dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() }); + return vec![item]; + } + }; + + // create TokenStream from vec elemtents: + // meta_item doesn't have a .tokens field + let comma: Token = Token::new(TokenKind::Comma, Span::default()); + let mut ts: Vec = vec![]; + if meta_item_vec.len() < 2 { + // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no + // input and output args. + dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() }); + return vec![item]; + } else { + for t in meta_item_vec.clone()[1..].iter() { + let val = first_ident(t); + let t = Token::from_ast_ident(val); + ts.push(TokenTree::Token(t, Spacing::Joint)); + ts.push(TokenTree::Token(comma.clone(), Spacing::Alone)); + } + } + if !has_ret { + // We don't want users to provide a return activity if the function doesn't return anything. + // For simplicity, we just add a dummy token to the end of the list. + let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default()); + ts.push(TokenTree::Token(t, Spacing::Joint)); + } + let ts: TokenStream = TokenStream::from_iter(ts); + + let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret); + if !x.is_active() { + // We encountered an error, so we return the original item. + // This allows us to potentially parse other attributes. + return vec![item]; + } + let span = ecx.with_def_site_ctxt(expand_span); + + let n_active: u32 = x + .input_activity + .iter() + .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) + .count() as u32; + let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let new_decl_span = d_sig.span; + let d_body = gen_enzyme_body( + ecx, + &x, + n_active, + &sig, + &d_sig, + primal, + &new_args, + span, + sig_span, + new_decl_span, + idents, + errored, + ); + let d_ident = first_ident(&meta_item_vec[0]); + + // The first element of it is the name of the function to be generated + let asdf = Box::new(ast::Fn { + defaultness: ast::Defaultness::Final, + sig: d_sig, + generics: Generics::default(), + body: Some(d_body), + }); + let mut rustc_ad_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); + + let ts2: Vec = vec![TokenTree::Token( + Token::new(TokenKind::Ident(sym::never, false.into()), span), + Spacing::Joint, + )]; + let never_arg = ast::DelimArgs { + dspan: ast::tokenstream::DelimSpan::from_single(span), + delim: ast::token::Delimiter::Parenthesis, + tokens: ast::tokenstream::TokenStream::from_iter(ts2), + }; + let inline_item = ast::AttrItem { + unsafety: ast::Safety::Default, + path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)), + args: ast::AttrArgs::Delimited(never_arg), + tokens: None, + }; + let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None }); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let attr: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), + id: new_id, + style: ast::AttrStyle::Outer, + span, + }; + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let inline_never: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(inline_never_attr), + id: new_id, + style: ast::AttrStyle::Outer, + span, + }; + + // Don't add it multiple times: + let orig_annotatable: Annotatable = match item { + Annotatable::Item(ref mut iitem) => { + if !iitem.attrs.iter().any(|a| a.id == attr.id) { + iitem.attrs.push(attr.clone()); + } + if !iitem.attrs.iter().any(|a| a.id == inline_never.id) { + iitem.attrs.push(inline_never.clone()); + } + Annotatable::Item(iitem.clone()) + } + Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => { + if !assoc_item.attrs.iter().any(|a| a.id == attr.id) { + assoc_item.attrs.push(attr.clone()); + } + if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) { + assoc_item.attrs.push(inline_never.clone()); + } + Annotatable::AssocItem(assoc_item.clone(), i) + } + _ => { + unreachable!("annotatable kind checked previously") + } + }; + // Now update for d_fn + rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs { + dspan: DelimSpan::dummy(), + delim: rustc_ast::token::Delimiter::Parenthesis, + tokens: ts, + }); + let d_attr: ast::Attribute = ast::Attribute { + kind: ast::AttrKind::Normal(rustc_ad_attr.clone()), + id: new_id, + style: ast::AttrStyle::Outer, + span, + }; + + let d_annotatable = if is_impl { + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let d_fn = P(ast::AssocItem { + attrs: thin_vec![d_attr.clone(), inline_never], + id: ast::DUMMY_NODE_ID, + span, + vis, + ident: d_ident, + kind: assoc_item, + tokens: None, + }); + Annotatable::AssocItem(d_fn, Impl) + } else { + let mut d_fn = ecx.item( + span, + d_ident, + thin_vec![d_attr.clone(), inline_never], + ItemKind::Fn(asdf), + ); + d_fn.vis = vis; + Annotatable::Item(d_fn) + }; + + return vec![orig_annotatable, d_annotatable]; + } + + // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be + // mutable references or ptrs, because Enzyme will write into them. + fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty { + let mut ty = ty.clone(); + match ty.kind { + TyKind::Ptr(ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + TyKind::Ref(_, ref mut mut_ty) => { + mut_ty.mutbl = ast::Mutability::Mut; + } + _ => { + panic!("unsupported type: {:?}", ty); + } + } + ty + } + + /// We only want this function to type-check, since we will replace the body + /// later on llvm level. Using `loop {}` does not cover all return types anymore, + /// so instead we build something that should pass. We also add a inline_asm + /// line, as one more barrier for rustc to prevent inlining of this function. + /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see + /// , so this isn't sufficient. + /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate + /// this function (which should never happen, since it is only a placeholder). + /// Finally, we also add back_box usages of all input arguments, to prevent rustc + /// from optimizing any arguments away. + fn gen_enzyme_body( + ecx: &ExtCtxt<'_>, + x: &AutoDiffAttrs, + n_active: u32, + sig: &ast::FnSig, + d_sig: &ast::FnSig, + primal: Ident, + new_names: &[String], + span: Span, + sig_span: Span, + new_decl_span: Span, + idents: Vec, + errored: bool, + ) -> P { + let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); + let noop = ast::InlineAsm { + asm_macro: ast::AsmMacro::Asm, + template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], + template_strs: Box::new([]), + operands: vec![], + clobber_abis: vec![], + options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, + line_spans: vec![], + }; + let noop_expr = ecx.expr_asm(span, P(noop)); + let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); + let unsf_block = ast::Block { + stmts: thin_vec![ecx.stmt_semi(noop_expr)], + id: ast::DUMMY_NODE_ID, + tokens: None, + rules: unsf, + span, + could_be_bare_literal: false, + }; + let unsf_expr = ecx.expr_block(P(unsf_block)); + let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + let primal_call = gen_primal_call(ecx, span, primal, idents); + let black_box_primal_call = + ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ + primal_call.clone() + ]); + let tup_args = new_names + .iter() + .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) + .collect(); + + let black_box_remaining_args = + ecx.expr_call(sig_span, blackbox_call_expr.clone(), thin_vec![ + ecx.expr_tuple(sig_span, tup_args) + ]); + + let mut body = ecx.block(span, ThinVec::new()); + body.stmts.push(ecx.stmt_semi(unsf_expr)); + + // This uses primal args which won't be available if we errored before + if !errored { + body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); + } + body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); + + if !has_ret(&d_sig.decl.output) { + // there is no return type that we have to match, () works fine. + return body; + } + + // having an active-only return means we'll drop the original return type. + // So that can be treated identical to not having one in the first place. + let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); + + if primal_ret && n_active == 0 && x.mode.is_rev() { + // We only have the primal ret. + body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone())); + return body; + } + + if !primal_ret && n_active == 1 { + // Again no tuple return, so return default float val. + let ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let arg = ty.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, kw::Default]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + body.stmts.push(ecx.stmt_expr(default_call_expr)); + return body; + } + + let mut exprs = ThinVec::>::new(); + if primal_ret { + // We have both primal ret and active floats. + // primal ret is first, by construction. + exprs.push(primal_call.clone()); + } + + // Now construct default placeholder for each active float. + // Is there something nicer than f32::default() and f64::default()? + let d_ret_ty = match d_sig.decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let mut d_ret_ty = match d_ret_ty.kind.clone() { + TyKind::Tup(ref tys) => tys.clone(), + TyKind::Path(_, rustc_ast::Path { segments, .. }) => { + if let [segment] = &segments[..] + && segment.args.is_none() + { + let id = vec![segments[0].ident]; + let kind = TyKind::Path(None, ecx.path(span, id)); + let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }); + thin_vec![ty] + } else { + panic!("Expected tuple or simple path return type"); + } + } + _ => { + // We messed up construction of d_sig + panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty); + } + }; + + if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual { + assert!(d_ret_ty.len() == 2); + // both should be identical, by construction + let arg = d_ret_ty[0].kind.is_simple_path().unwrap(); + let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap(); + assert!(arg == arg2); + let sl: Vec = vec![arg, kw::Default]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + } else if x.mode.is_rev() { + if primal_ret { + // We have extra handling above for the primal ret + d_ret_ty = d_ret_ty[1..].to_vec().into(); + } + + for arg in d_ret_ty.iter() { + let arg = arg.kind.is_simple_path().unwrap(); + let sl: Vec = vec![arg, kw::Default]; + let tmp = ecx.def_site_path(&sl); + let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); + let default_call_expr = + ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); + exprs.push(default_call_expr); + } + } + + let ret: P; + match &exprs[..] { + [] => { + assert!(!has_ret(&d_sig.decl.output)); + // We don't have to match the return type. + return body; + } + [arg] => { + ret = ecx + .expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![arg.clone()]); + } + args => { + let ret_tuple: P = ecx.expr_tuple(span, args.into()); + ret = + ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]); + } + } + assert!(has_ret(&d_sig.decl.output)); + body.stmts.push(ecx.stmt_expr(ret)); + + body + } + + fn gen_primal_call( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + idents: Vec, + ) -> P { + let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; + if has_self { + let args: ThinVec<_> = + idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let self_expr = ecx.expr_self(span); + ecx.expr_method_call(span, self_expr, primal, args.clone()) + } else { + let args: ThinVec<_> = + idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); + let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal)); + ecx.expr_call(span, primal_call_expr, args) + } + } + + // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must + // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. + // Active arguments must be scalars. Their shadow argument is added to the return type (and will be + // zero-initialized by Enzyme). + // Each argument of the primal function (and the return type if existing) must be annotated with an + // activity. + // + // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or + // both), we emit an error and return the original signature. This allows us to continue parsing. + fn gen_enzyme_decl( + ecx: &ExtCtxt<'_>, + sig: &ast::FnSig, + x: &AutoDiffAttrs, + span: Span, + ) -> (ast::FnSig, Vec, Vec, bool) { + let dcx = ecx.sess.dcx(); + let has_ret = has_ret(&sig.decl.output); + let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; + let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 }; + if sig_args != num_activities { + dcx.emit_err(errors::AutoDiffInvalidNumberActivities { + span, + expected: sig_args, + found: num_activities, + }); + // This is not the right signature, but we can continue parsing. + return (sig.clone(), vec![], vec![], true); + } + assert!(sig.decl.inputs.len() == x.input_activity.len()); + assert!(has_ret == x.has_ret_activity()); + let mut d_decl = sig.decl.clone(); + let mut d_inputs = Vec::new(); + let mut new_inputs = Vec::new(); + let mut idents = Vec::new(); + let mut act_ret = ThinVec::new(); + + // We have two loops, a first one just to check the activities and types and possibly report + // multiple errors in one compilation session. + let mut errors = false; + for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { + if !valid_input_activity(x.mode, *activity) { + dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct { + span, + mode: x.mode.to_string(), + act: activity.to_string(), + }); + errors = true; + } + if !valid_ty_for_activity(&arg.ty, *activity) { + dcx.emit_err(errors::AutoDiffInvalidTypeForActivity { + span: arg.ty.span, + act: activity.to_string(), + }); + errors = true; + } + } + if errors { + // This is not the right signature, but we can continue parsing. + return (sig.clone(), new_inputs, idents, true); + } + let unsafe_activities = x + .input_activity + .iter() + .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly)); + for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) { + d_inputs.push(arg.clone()); + match activity { + DiffActivity::Active => { + act_ret.push(arg.ty.clone()); + } + DiffActivity::ActiveOnly => { + // We will add the active scalar to the return type. + // This is handled later. + } + DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => { + let mut shadow_arg = arg.clone(); + // We += into the shadow in reverse mode. + shadow_arg.ty = P(assure_mut_ref(&arg.ty)); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + debug!("{:#?}", &shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("d{}", old_name); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingMode::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg); + } + DiffActivity::Dual | DiffActivity::DualOnly => { + let mut shadow_arg = arg.clone(); + let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { + ident.name + } else { + debug!("{:#?}", &shadow_arg.pat); + panic!("not an ident?"); + }; + let name: String = format!("b{}", old_name); + new_inputs.push(name.clone()); + let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span); + shadow_arg.pat = P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingMode::NONE, ident, None), + span: shadow_arg.pat.span, + tokens: shadow_arg.pat.tokens.clone(), + }); + d_inputs.push(shadow_arg); + } + DiffActivity::Const => { + // Nothing to do here. + } + DiffActivity::None | DiffActivity::FakeActivitySize => { + panic!("Should not happen"); + } + } + if let PatKind::Ident(_, ident, _) = arg.pat.kind { + idents.push(ident.clone()); + } else { + panic!("not an ident?"); + } + } + + let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly; + if active_only_ret { + assert!(x.mode.is_rev()); + } + + // If we return a scalar in the primal and the scalar is active, + // then add it as last arg to the inputs. + if x.mode.is_rev() { + match x.ret_activity { + DiffActivity::Active | DiffActivity::ActiveOnly => { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + let name = "dret".to_string(); + let ident = Ident::from_str_and_span(&name, ty.span); + let shadow_arg = ast::Param { + attrs: ThinVec::new(), + ty: ty.clone(), + pat: P(ast::Pat { + id: ast::DUMMY_NODE_ID, + kind: PatKind::Ident(BindingMode::NONE, ident, None), + span: ty.span, + tokens: None, + }), + id: ast::DUMMY_NODE_ID, + span: ty.span, + is_placeholder: false, + }; + d_inputs.push(shadow_arg); + new_inputs.push(name); + } + _ => {} + } + } + d_decl.inputs = d_inputs.into(); + + if x.mode.is_fwd() { + if let DiffActivity::Dual = x.ret_activity { + let ty = match d_decl.output { + FnRetTy::Ty(ref ty) => ty.clone(), + FnRetTy::Default(span) => { + panic!("Did not expect Default ret ty: {:?}", span); + } + }; + // Dual can only be used for f32/f64 ret. + // In that case we return now a tuple with two floats. + let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]); + let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); + d_decl.output = FnRetTy::Ty(ty); + } + if let DiffActivity::DualOnly = x.ret_activity { + // No need to change the return type, + // we will just return the shadow in place + // of the primal return. + } + } + + // If we use ActiveOnly, drop the original return value. + d_decl.output = + if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() }; + + trace!("act_ret: {:?}", act_ret); + + // If we have an active input scalar, add it's gradient to the + // return type. This might require changing the return type to a + // tuple. + if act_ret.len() > 0 { + let ret_ty = match d_decl.output { + FnRetTy::Ty(ref ty) => { + if !active_only_ret { + act_ret.insert(0, ty.clone()); + } + let kind = TyKind::Tup(act_ret); + P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }) + } + FnRetTy::Default(span) => { + if act_ret.len() == 1 { + act_ret[0].clone() + } else { + let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect()); + P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None }) + } + } + }; + d_decl.output = FnRetTy::Ty(ret_ty); + } + + let mut d_header = sig.header.clone(); + if unsafe_activities { + d_header.safety = rustc_ast::Safety::Unsafe(span); + } + let d_sig = FnSig { header: d_header, decl: d_decl, span }; + trace!("Generated signature: {:?}", d_sig); + (d_sig, new_inputs, idents, false) + } +} + +#[cfg(not(llvm_enzyme))] +mod ad_fallback { + use rustc_ast::ast; + use rustc_expand::base::{Annotatable, ExtCtxt}; + use rustc_span::Span; + + use crate::errors; + pub(crate) fn expand( + ecx: &mut ExtCtxt<'_>, + _expand_span: Span, + meta_item: &ast::MetaItem, + item: Annotatable, + ) -> Vec { + ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span }); + return vec![item]; + } +} + +#[cfg(not(llvm_enzyme))] +pub(crate) use ad_fallback::expand; +#[cfg(llvm_enzyme)] +pub(crate) use llvm_enzyme::expand; diff --git a/compiler/rustc_builtin_macros/src/errors.rs b/compiler/rustc_builtin_macros/src/errors.rs index 639c2aa231cb1..f8e65661e52e2 100644 --- a/compiler/rustc_builtin_macros/src/errors.rs +++ b/compiler/rustc_builtin_macros/src/errors.rs @@ -145,6 +145,78 @@ pub(crate) struct AllocMustStatics { pub(crate) span: Span, } +#[cfg(llvm_enzyme)] +pub(crate) use autodiff::*; + +#[cfg(llvm_enzyme)] +mod autodiff { + use super::*; + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_missing_config)] + pub(crate) struct AutoDiffMissingConfig { + #[primary_span] + pub(crate) span: Span, + } + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_unknown_activity)] + pub(crate) struct AutoDiffUnknownActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, + } + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_ty_activity)] + pub(crate) struct AutoDiffInvalidTypeForActivity { + #[primary_span] + pub(crate) span: Span, + pub(crate) act: String, + } + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_number_activities)] + pub(crate) struct AutoDiffInvalidNumberActivities { + #[primary_span] + pub(crate) span: Span, + pub(crate) expected: usize, + pub(crate) found: usize, + } + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_mode_activity)] + pub(crate) struct AutoDiffInvalidApplicationModeAct { + #[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, + pub(crate) act: String, + } + + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_mode)] + pub(crate) struct AutoDiffInvalidMode { + #[primary_span] + pub(crate) span: Span, + pub(crate) mode: String, + } + + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff)] + pub(crate) struct AutoDiffInvalidApplication { + #[primary_span] + pub(crate) span: Span, + } +} + +#[cfg(not(llvm_enzyme))] +pub(crate) use ad_fallback::*; +#[cfg(not(llvm_enzyme))] +mod ad_fallback { + use super::*; + #[derive(Diagnostic)] + #[diag(builtin_macros_autodiff_not_build)] + pub(crate) struct AutoDiffSupportNotBuild { + #[primary_span] + pub(crate) span: Span, + } +} + #[derive(Diagnostic)] #[diag(builtin_macros_concat_bytes_invalid)] pub(crate) struct ConcatBytesInvalid { diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index ebe5e2b544292..377d7f542cf46 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -5,6 +5,7 @@ #![allow(internal_features)] #![allow(rustc::diagnostic_outside_of_impl)] #![allow(rustc::untranslatable_diagnostic)] +#![cfg_attr(not(bootstrap), feature(autodiff))] #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")] #![doc(rust_logo)] #![feature(assert_matches)] @@ -29,6 +30,7 @@ use crate::deriving::*; mod alloc_error_handler; mod assert; +mod autodiff; mod cfg; mod cfg_accessible; mod cfg_eval; @@ -106,6 +108,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { register_attr! { alloc_error_handler: alloc_error_handler::expand, + autodiff: autodiff::expand, bench: test::expand_bench, cfg_accessible: cfg_accessible::Expander, cfg_eval: cfg_eval::expand, diff --git a/compiler/rustc_expand/src/build.rs b/compiler/rustc_expand/src/build.rs index b5945759d43a0..743a9854f7964 100644 --- a/compiler/rustc_expand/src/build.rs +++ b/compiler/rustc_expand/src/build.rs @@ -220,6 +220,10 @@ impl<'a> ExtCtxt<'a> { self.stmt_local(local, span) } + pub fn stmt_semi(&self, expr: P) -> ast::Stmt { + ast::Stmt { id: ast::DUMMY_NODE_ID, span: expr.span, kind: ast::StmtKind::Semi(expr) } + } + pub fn stmt_local(&self, local: P, span: Span) -> ast::Stmt { ast::Stmt { id: ast::DUMMY_NODE_ID, kind: ast::StmtKind::Let(local), span } } @@ -287,6 +291,25 @@ impl<'a> ExtCtxt<'a> { self.expr(sp, ast::ExprKind::Paren(e)) } + pub fn expr_method_call( + &self, + span: Span, + expr: P, + ident: Ident, + args: ThinVec>, + ) -> P { + let seg = ast::PathSegment::from_ident(ident); + self.expr( + span, + ast::ExprKind::MethodCall(Box::new(ast::MethodCall { + seg, + receiver: expr, + args, + span, + })), + ) + } + pub fn expr_call( &self, span: Span, @@ -295,6 +318,12 @@ impl<'a> ExtCtxt<'a> { ) -> P { self.expr(span, ast::ExprKind::Call(expr, args)) } + pub fn expr_loop(&self, sp: Span, block: P) -> P { + self.expr(sp, ast::ExprKind::Loop(block, None, sp)) + } + pub fn expr_asm(&self, sp: Span, expr: P) -> P { + self.expr(sp, ast::ExprKind::InlineAsm(expr)) + } pub fn expr_call_ident( &self, span: Span, diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 17827b4e43b37..477760a459757 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -752,6 +752,11 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ template!(NameValueStr: "transparent|semitransparent|opaque"), ErrorFollowing, EncodeCrossCrate::Yes, "used internally for testing macro hygiene", ), + rustc_attr!( + rustc_autodiff, Normal, + template!(Word, List: r#""...""#), DuplicatesOk, + EncodeCrossCrate::No, INTERNAL_UNSTABLE + ), // ========================================================================== // Internal attributes, Diagnostics related: diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl index c11c38500345a..e5a14f6a15658 100644 --- a/compiler/rustc_passes/messages.ftl +++ b/compiler/rustc_passes/messages.ftl @@ -49,6 +49,10 @@ passes_attr_crate_level = passes_attr_only_in_functions = `{$attr}` attribute can only be used on functions +passes_autodiff_attr = + `#[autodiff]` should be applied to a function + .label = not a function + passes_both_ffi_const_and_pure = `#[ffi_const]` function cannot be `#[ffi_pure]` diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 44a62383e6eed..7ce29260e3676 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -243,6 +243,9 @@ impl<'tcx> CheckAttrVisitor<'tcx> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + [sym::autodiff, ..] => { + self.check_autodiff(hir_id, attr, span, target) + } [sym::coroutine, ..] => { self.check_coroutine(attr, target); } @@ -2345,6 +2348,18 @@ impl<'tcx> CheckAttrVisitor<'tcx> { self.dcx().emit_err(errors::RustcPubTransparent { span, attr_span }); } } + + /// Checks if `#[autodiff]` is applied to an item other than a function item. + fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { + debug!("check_autodiff"); + match target { + Target::Fn => {} + _ => { + self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span }); + self.abort.set(true); + } + } + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index f9186d3089ab3..6dc3dfba58f06 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -20,6 +20,14 @@ use crate::lang_items::Duplicate; #[diag(passes_incorrect_do_not_recommend_location)] pub(crate) struct IncorrectDoNotRecommendLocation; +#[derive(Diagnostic)] +#[diag(passes_autodiff_attr)] +pub(crate) struct AutoDiffAttr { + #[primary_span] + #[label] + pub attr_span: Span, +} + #[derive(LintDiagnostic)] #[diag(passes_outer_crate_level_attr)] pub(crate) struct OuterCrateLevelAttr; diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index cc3bda99a117b..6f62b4f82d76e 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -481,6 +481,8 @@ symbols! { audit_that, augmented_assignments, auto_traits, + autodiff, + autodiff_fallback, automatically_derived, avx, avx512_target_feature, @@ -544,6 +546,7 @@ symbols! { cfg_accessible, cfg_attr, cfg_attr_multi, + cfg_autodiff_fallback, cfg_boolean_literals, cfg_doctest, cfg_eval, @@ -1002,6 +1005,7 @@ symbols! { hashset_iter_ty, hexagon_target_feature, hidden, + hint, homogeneous_aggregate, host, html_favicon_url, @@ -1654,6 +1658,7 @@ symbols! { rustc_allow_incoherent_impl, rustc_allowed_through_unstable_modules, rustc_attrs, + rustc_autodiff, rustc_box, rustc_builtin_macro, rustc_capture_analysis, diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs index b21618e28a48e..6c2859051ff97 100644 --- a/library/core/src/lib.rs +++ b/library/core/src/lib.rs @@ -273,6 +273,15 @@ pub mod assert_matches { pub use crate::macros::{assert_matches, debug_assert_matches}; } +// We don't export this through #[macro_export] for now, to avoid breakage. +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "124509")] +/// Unstable module containing the unstable `autodiff` macro. +pub mod autodiff { + #[unstable(feature = "autodiff", issue = "124509")] + pub use crate::macros::builtin::autodiff; +} + #[unstable(feature = "cfg_match", issue = "115585")] pub use crate::macros::cfg_match; diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index aa0646846e43e..b5e5b58f7051f 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1539,6 +1539,24 @@ pub(crate) mod builtin { ($file:expr $(,)?) => {{ /* compiler built-in */ }}; } + /// Automatic Differentiation macro which allows generating a new function to compute + /// the derivative of a given function. It may only be applied to a function. + /// The expected usage syntax is + /// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]` + /// where: + /// NAME is a string that represents a valid function name. + /// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst. + /// INPUT_ACTIVITIES consists of one valid activity for each input parameter. + /// OUTPUT_ACTIVITY must not be set if we implicitely return nothing (or explicitely return + /// `-> ()`. Otherwise it must be set to one of the allowed activities. + #[unstable(feature = "autodiff", issue = "124509")] + #[allow_internal_unstable(rustc_attrs)] + #[rustc_builtin_macro] + #[cfg(not(bootstrap))] + pub macro autodiff($item:item) { + /* compiler built-in */ + } + /// Asserts that a boolean expression is `true` at runtime. /// /// This will invoke the [`panic!`] macro if the provided expression cannot be diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 65a9aa66c7cc6..35ed761759bd7 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -267,6 +267,7 @@ #![allow(unused_features)] // // Features: +#![cfg_attr(not(bootstrap), feature(autodiff))] #![cfg_attr(test, feature(internal_output_capture, print_internals, update_panic_count, rt))] #![cfg_attr( all(target_vendor = "fortanix", target_env = "sgx"), @@ -627,7 +628,13 @@ pub mod simd { #[doc(inline)] pub use crate::std_float::StdFloat; } - +#[cfg(not(bootstrap))] +#[unstable(feature = "autodiff", issue = "124509")] +/// This module provides support for automatic differentiation. +pub mod autodiff { + /// This macro handles automatic differentiation. + pub use core::autodiff::autodiff; +} #[stable(feature = "futures_api", since = "1.36.0")] pub mod task { //! Types and Traits for working with asynchronous tasks. diff --git a/tests/pretty/autodiff_forward.pp b/tests/pretty/autodiff_forward.pp new file mode 100644 index 0000000000000..23c3b5b34a82a --- /dev/null +++ b/tests/pretty/autodiff_forward.pp @@ -0,0 +1,107 @@ +#![feature(prelude_import)] +#![no_std] +//@ needs-enzyme + +#![feature(autodiff)] +#[prelude_import] +use ::std::prelude::rust_2015::*; +#[macro_use] +extern crate std; +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:autodiff_forward.pp + +// Test that forward mode ad macros are expanded correctly. + +use std::autodiff::autodiff; + +#[rustc_autodiff] +#[inline(never)] +pub fn f1(x: &[f64], y: f64) -> f64 { + + + + // Not the most interesting derivative, but who are we to judge + + // We want to be sure that the same function can be differentiated in different ways + + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(Forward, Dual, Const, Dual,)] +#[inline(never)] +pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f1(x, y)); + ::core::hint::black_box((bx,)); + ::core::hint::black_box((f1(x, y), f64::default())) +} +#[rustc_autodiff] +#[inline(never)] +pub fn f2(x: &[f64], y: f64) -> f64 { + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(Forward, Dual, Const, Const,)] +#[inline(never)] +pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f2(x, y)); + ::core::hint::black_box((bx,)); + ::core::hint::black_box(f2(x, y)) +} +#[rustc_autodiff] +#[inline(never)] +pub fn f3(x: &[f64], y: f64) -> f64 { + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(ForwardFirst, Dual, Const, Const,)] +#[inline(never)] +pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f3(x, y)); + ::core::hint::black_box((bx,)); + ::core::hint::black_box(f3(x, y)) +} +#[rustc_autodiff] +#[inline(never)] +pub fn f4() {} +#[rustc_autodiff(Forward, None)] +#[inline(never)] +pub fn df4() { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f4()); + ::core::hint::black_box(()); +} +#[rustc_autodiff] +#[inline(never)] +#[rustc_autodiff] +#[inline(never)] +#[rustc_autodiff] +#[inline(never)] +pub fn f5(x: &[f64], y: f64) -> f64 { + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(Forward, Const, Dual, Const,)] +#[inline(never)] +pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f5(x, y)); + ::core::hint::black_box((by,)); + ::core::hint::black_box(f5(x, y)) +} +#[rustc_autodiff(Forward, Dual, Const, Const,)] +#[inline(never)] +pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f5(x, y)); + ::core::hint::black_box((bx,)); + ::core::hint::black_box(f5(x, y)) +} +#[rustc_autodiff(Reverse, Duplicated, Const, Active,)] +#[inline(never)] +pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f5(x, y)); + ::core::hint::black_box((dx, dret)); + ::core::hint::black_box(f5(x, y)) +} +fn main() {} diff --git a/tests/pretty/autodiff_forward.rs b/tests/pretty/autodiff_forward.rs new file mode 100644 index 0000000000000..35108d0d6f111 --- /dev/null +++ b/tests/pretty/autodiff_forward.rs @@ -0,0 +1,39 @@ +//@ needs-enzyme + +#![feature(autodiff)] +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:autodiff_forward.pp + +// Test that forward mode ad macros are expanded correctly. + +use std::autodiff::autodiff; + +#[autodiff(df1, Forward, Dual, Const, Dual)] +pub fn f1(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +#[autodiff(df2, Forward, Dual, Const, Const)] +pub fn f2(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +#[autodiff(df3, ForwardFirst, Dual, Const, Const)] +pub fn f3(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +// Not the most interesting derivative, but who are we to judge +#[autodiff(df4, Forward)] +pub fn f4() {} + +// We want to be sure that the same function can be differentiated in different ways +#[autodiff(df5_rev, Reverse, Duplicated, Const, Active)] +#[autodiff(df5_x, Forward, Dual, Const, Const)] +#[autodiff(df5_y, Forward, Const, Dual, Const)] +pub fn f5(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +fn main() {} diff --git a/tests/pretty/autodiff_reverse.pp b/tests/pretty/autodiff_reverse.pp new file mode 100644 index 0000000000000..a98d3782c7034 --- /dev/null +++ b/tests/pretty/autodiff_reverse.pp @@ -0,0 +1,86 @@ +#![feature(prelude_import)] +#![no_std] +//@ needs-enzyme + +#![feature(autodiff)] +#[prelude_import] +use ::std::prelude::rust_2015::*; +#[macro_use] +extern crate std; +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:autodiff_reverse.pp + +// Test that reverse mode ad macros are expanded correctly. + +use std::autodiff::autodiff; + +#[rustc_autodiff] +#[inline(never)] +pub fn f1(x: &[f64], y: f64) -> f64 { + + // Not the most interesting derivative, but who are we to judge + + + // What happens if we already have Reverse in type (enum variant decl) and value (enum variant + // constructor) namespace? > It's expected to work normally. + + + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(Reverse, Duplicated, Const, Active,)] +#[inline(never)] +pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f1(x, y)); + ::core::hint::black_box((dx, dret)); + ::core::hint::black_box(f1(x, y)) +} +#[rustc_autodiff] +#[inline(never)] +pub fn f2() {} +#[rustc_autodiff(Reverse, None)] +#[inline(never)] +pub fn df2() { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f2()); + ::core::hint::black_box(()); +} +#[rustc_autodiff] +#[inline(never)] +pub fn f3(x: &[f64], y: f64) -> f64 { + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(ReverseFirst, Duplicated, Const, Active,)] +#[inline(never)] +pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f3(x, y)); + ::core::hint::black_box((dx, dret)); + ::core::hint::black_box(f3(x, y)) +} +enum Foo { Reverse, } +use Foo::Reverse; +#[rustc_autodiff] +#[inline(never)] +pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } +#[rustc_autodiff(Reverse, Const, None)] +#[inline(never)] +pub fn df4(x: f32) { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f4(x)); + ::core::hint::black_box(()); +} +#[rustc_autodiff] +#[inline(never)] +pub fn f5(x: *const f32, y: &f32) { + ::core::panicking::panic("not implemented") +} +#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)] +#[inline(never)] +pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) { + unsafe { asm!("NOP", options(pure, nomem)); }; + ::core::hint::black_box(f5(x, y)); + ::core::hint::black_box((dx, dy)); +} +fn main() {} diff --git a/tests/pretty/autodiff_reverse.rs b/tests/pretty/autodiff_reverse.rs new file mode 100644 index 0000000000000..657201caa9401 --- /dev/null +++ b/tests/pretty/autodiff_reverse.rs @@ -0,0 +1,40 @@ +//@ needs-enzyme + +#![feature(autodiff)] +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:autodiff_reverse.pp + +// Test that reverse mode ad macros are expanded correctly. + +use std::autodiff::autodiff; + +#[autodiff(df1, Reverse, Duplicated, Const, Active)] +pub fn f1(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +// Not the most interesting derivative, but who are we to judge +#[autodiff(df2, Reverse)] +pub fn f2() {} + +#[autodiff(df3, ReverseFirst, Duplicated, Const, Active)] +pub fn f3(x: &[f64], y: f64) -> f64 { + unimplemented!() +} + +enum Foo { Reverse } +use Foo::Reverse; +// What happens if we already have Reverse in type (enum variant decl) and value (enum variant +// constructor) namespace? > It's expected to work normally. +#[autodiff(df4, Reverse, Const)] +pub fn f4(x: f32) { + unimplemented!() +} + +#[autodiff(df5, Reverse, DuplicatedOnly, Duplicated)] +pub fn f5(x: *const f32, y: &f32) { + unimplemented!() +} + +fn main() {} diff --git a/tests/ui/autodiff/autodiff_illegal.rs b/tests/ui/autodiff/autodiff_illegal.rs new file mode 100644 index 0000000000000..c0548d2bbb8ff --- /dev/null +++ b/tests/ui/autodiff/autodiff_illegal.rs @@ -0,0 +1,160 @@ +//@ needs-enzyme + +#![feature(autodiff)] +//@ pretty-mode:expanded +//@ pretty-compare-only +//@ pp-exact:autodiff_illegal.pp + +// Test that invalid ad macros give nice errors and don't ICE. + +use std::autodiff::autodiff; + +// We can't use Duplicated on scalars +#[autodiff(df1, Reverse, Duplicated)] +pub fn f1(x: f64) { +//~^ ERROR Duplicated can not be used for this type + unimplemented!() +} + +// Too many activities +#[autodiff(df3, Reverse, Duplicated, Const)] +pub fn f3(x: f64) { +//~^^ ERROR expected 1 activities, but found 2 + unimplemented!() +} + +// To few activities +#[autodiff(df4, Reverse)] +pub fn f4(x: f64) { +//~^^ ERROR expected 1 activities, but found 0 + unimplemented!() +} + +// We can't use Dual in Reverse mode +#[autodiff(df5, Reverse, Dual)] +pub fn f5(x: f64) { +//~^^ ERROR Dual can not be used in Reverse Mode + unimplemented!() +} + +// We can't use Duplicated in Forward mode +#[autodiff(df6, Forward, Duplicated)] +pub fn f6(x: f64) { +//~^^ ERROR Duplicated can not be used in Forward Mode +//~^^ ERROR Duplicated can not be used for this type + unimplemented!() +} + +fn dummy() { + + #[autodiff(df7, Forward, Dual)] + let mut x = 5; + //~^ ERROR autodiff must be applied to function + + #[autodiff(df7, Forward, Dual)] + x = x + 3; + //~^^ ERROR attributes on expressions are experimental [E0658] + //~^^ ERROR autodiff must be applied to function + + #[autodiff(df7, Forward, Dual)] + let add_one_v2 = |x: u32| -> u32 { x + 1 }; + //~^ ERROR autodiff must be applied to function +} + +// Malformed, where args? +#[autodiff] +pub fn f7(x: f64) { +//~^ ERROR autodiff must be applied to function + unimplemented!() +} + +// Malformed, where args? +#[autodiff()] +pub fn f8(x: f64) { +//~^ ERROR autodiff requires at least a name and mode + unimplemented!() +} + +// Invalid attribute syntax +#[autodiff = ""] +pub fn f9(x: f64) { +//~^ ERROR autodiff must be applied to function + unimplemented!() +} + +fn fn_exists() {} + +// We colide with an already existing function +#[autodiff(fn_exists, Reverse, Active)] +pub fn f10(x: f64) { +//~^^ ERROR the name `fn_exists` is defined multiple times [E0428] + unimplemented!() +} + +// Malformed, missing a mode +#[autodiff(df11)] +pub fn f11() { +//~^ ERROR autodiff requires at least a name and mode + unimplemented!() +} + +// Invalid Mode +#[autodiff(df12, Debug)] +pub fn f12() { +//~^^ ERROR unknown Mode: `Debug`. Use `Forward` or `Reverse` + unimplemented!() +} + +// Invalid, please pick one Mode +// or use two autodiff macros. +#[autodiff(df13, Forward, Reverse)] +pub fn f13() { +//~^^ ERROR did not recognize Activity: `Reverse` + unimplemented!() +} + +struct Foo {} + +// We can't handle Active structs, because that would mean (in the general case), that we would +// need to allocate and initialize arbitrary user types. We have Duplicated/Dual input args for +// that. FIXME: Give a nicer error and suggest to the user to have a `&mut Foo` input instead. +#[autodiff(df14, Reverse, Active, Active)] +fn f14(x: f32) -> Foo { + unimplemented!() +} + +type MyFloat = f32; + +// We would like to support type alias to f32/f64 in argument type in the future, +// but that requires us to implement our checks at a later stage +// like THIR which has type information available. +#[autodiff(df15, Reverse, Active, Active)] +fn f15(x: MyFloat) -> f32 { +//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433] + unimplemented!() +} + +// We would like to support type alias to f32/f64 in return type in the future +#[autodiff(df16, Reverse, Active, Active)] +fn f16(x: f32) -> MyFloat { + unimplemented!() +} + +#[repr(transparent)] +struct F64Trans { inner: f64 } + +// We would like to support `#[repr(transparent)]` f32/f64 wrapper in return type in the future +#[autodiff(df17, Reverse, Active, Active)] +fn f17(x: f64) -> F64Trans { + unimplemented!() +} + +// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future +#[autodiff(df18, Reverse, Active, Active)] +fn f18(x: F64Trans) -> f64 { + //~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433] + unimplemented!() +} + + +fn main() {} diff --git a/tests/ui/autodiff/autodiff_illegal.stderr b/tests/ui/autodiff/autodiff_illegal.stderr new file mode 100644 index 0000000000000..3a7242b2f5d95 --- /dev/null +++ b/tests/ui/autodiff/autodiff_illegal.stderr @@ -0,0 +1,152 @@ +error[E0658]: attributes on expressions are experimental + --> $DIR/autodiff_illegal.rs:54:5 + | +LL | #[autodiff(df7, Forward, Dual)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: see issue #15701 for more information + = help: add `#![feature(stmt_expr_attributes)]` to the crate attributes to enable + = note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date + +error: Duplicated can not be used for this type + --> $DIR/autodiff_illegal.rs:14:14 + | +LL | pub fn f1(x: f64) { + | ^^^ + +error: expected 1 activities, but found 2 + --> $DIR/autodiff_illegal.rs:20:1 + | +LL | #[autodiff(df3, Reverse, Duplicated, Const)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: expected 1 activities, but found 0 + --> $DIR/autodiff_illegal.rs:27:1 + | +LL | #[autodiff(df4, Reverse)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: Dual can not be used in Reverse Mode + --> $DIR/autodiff_illegal.rs:34:1 + | +LL | #[autodiff(df5, Reverse, Dual)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: Duplicated can not be used in Forward Mode + --> $DIR/autodiff_illegal.rs:41:1 + | +LL | #[autodiff(df6, Forward, Duplicated)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: Duplicated can not be used for this type + --> $DIR/autodiff_illegal.rs:42:14 + | +LL | pub fn f6(x: f64) { + | ^^^ + +error: autodiff must be applied to function + --> $DIR/autodiff_illegal.rs:51:5 + | +LL | let mut x = 5; + | ^^^^^^^^^^^^^^ + +error: autodiff must be applied to function + --> $DIR/autodiff_illegal.rs:55:5 + | +LL | x = x + 3; + | ^ + +error: autodiff must be applied to function + --> $DIR/autodiff_illegal.rs:60:5 + | +LL | let add_one_v2 = |x: u32| -> u32 { x + 1 }; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: autodiff must be applied to function + --> $DIR/autodiff_illegal.rs:66:1 + | +LL | / pub fn f7(x: f64) { +LL | | +LL | | unimplemented!() +LL | | } + | |_^ + +error: autodiff requires at least a name and mode + --> $DIR/autodiff_illegal.rs:73:1 + | +LL | / pub fn f8(x: f64) { +LL | | +LL | | unimplemented!() +LL | | } + | |_^ + +error: autodiff must be applied to function + --> $DIR/autodiff_illegal.rs:80:1 + | +LL | / pub fn f9(x: f64) { +LL | | +LL | | unimplemented!() +LL | | } + | |_^ + +error[E0428]: the name `fn_exists` is defined multiple times + --> $DIR/autodiff_illegal.rs:88:1 + | +LL | fn fn_exists() {} + | -------------- previous definition of the value `fn_exists` here +... +LL | #[autodiff(fn_exists, Reverse, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `fn_exists` redefined here + | + = note: `fn_exists` must be defined only once in the value namespace of this module + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: autodiff requires at least a name and mode + --> $DIR/autodiff_illegal.rs:96:1 + | +LL | / pub fn f11() { +LL | | +LL | | unimplemented!() +LL | | } + | |_^ + +error: unknown Mode: `Debug`. Use `Forward` or `Reverse` + --> $DIR/autodiff_illegal.rs:102:18 + | +LL | #[autodiff(df12, Debug)] + | ^^^^^ + +error: did not recognize Activity: `Reverse` + --> $DIR/autodiff_illegal.rs:110:27 + | +LL | #[autodiff(df13, Forward, Reverse)] + | ^^^^^^^ + +error[E0433]: failed to resolve: use of undeclared type `MyFloat` + --> $DIR/autodiff_illegal.rs:131:1 + | +LL | #[autodiff(df15, Reverse, Active, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat` + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0433]: failed to resolve: use of undeclared type `F64Trans` + --> $DIR/autodiff_illegal.rs:153:1 + | +LL | #[autodiff(df18, Reverse, Active, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans` + | + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) + +error: aborting due to 19 previous errors + +Some errors have detailed explanations: E0428, E0433, E0658. +For more information about an error, try `rustc --explain E0428`. diff --git a/tests/ui/autodiff/auxiliary/my_macro.rs b/tests/ui/autodiff/auxiliary/my_macro.rs new file mode 100644 index 0000000000000..417199611cca6 --- /dev/null +++ b/tests/ui/autodiff/auxiliary/my_macro.rs @@ -0,0 +1,12 @@ +//@ force-host +//@ no-prefer-dynamic +#![crate_type = "proc-macro"] + +extern crate proc_macro; +use proc_macro::TokenStream; + +#[proc_macro_attribute] +#[macro_use] +pub fn autodiff(_attr: TokenStream, item: TokenStream) -> TokenStream { + item // identity proc-macro +} diff --git a/tests/ui/autodiff/visibility.rs b/tests/ui/autodiff/visibility.rs new file mode 100644 index 0000000000000..6a4851de2dc02 --- /dev/null +++ b/tests/ui/autodiff/visibility.rs @@ -0,0 +1,17 @@ +//@ ignore-enzyme +//@ revisions: std_autodiff no_std_autodiff +//@[no_std_autodiff] check-pass +//@ aux-build: my_macro.rs +#![crate_type = "lib"] +#![feature(autodiff)] + +#[cfg(std_autodiff)] +use std::autodiff::autodiff; + +extern crate my_macro; +use my_macro::autodiff; // bring `autodiff` in scope + +#[autodiff] +//[std_autodiff]~^^^ ERROR the name `autodiff` is defined multiple times +//[std_autodiff]~^^ ERROR this rustc version does not support autodiff +fn foo() {} diff --git a/tests/ui/autodiff/visibility.std_autodiff.stderr b/tests/ui/autodiff/visibility.std_autodiff.stderr new file mode 100644 index 0000000000000..720c9a00170e9 --- /dev/null +++ b/tests/ui/autodiff/visibility.std_autodiff.stderr @@ -0,0 +1,24 @@ +error[E0252]: the name `autodiff` is defined multiple times + --> $DIR/visibility.rs:12:5 + | +LL | use std::autodiff::autodiff; + | ----------------------- previous import of the macro `autodiff` here +... +LL | use my_macro::autodiff; // bring `autodiff` in scope + | ^^^^^^^^^^^^^^^^^^ `autodiff` reimported here + | + = note: `autodiff` must be defined only once in the macro namespace of this module +help: you can use `as` to change the binding name of the import + | +LL | use my_macro::autodiff as other_autodiff; // bring `autodiff` in scope + | +++++++++++++++++ + +error: this rustc version does not support autodiff + --> $DIR/visibility.rs:14:1 + | +LL | #[autodiff] + | ^^^^^^^^^^^ + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0252`. diff --git a/tests/ui/feature-gates/feature-gate-autodiff-use.has_support.stderr b/tests/ui/feature-gates/feature-gate-autodiff-use.has_support.stderr new file mode 100644 index 0000000000000..36a017dd53c22 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff-use.has_support.stderr @@ -0,0 +1,23 @@ +error[E0658]: use of unstable library feature 'autodiff' + --> $DIR/feature-gate-autodiff-use.rs:13:3 + | +LL | #[autodiff(dfoo, Reverse)] + | ^^^^^^^^ + | + = note: see issue #124509 for more information + = help: add `#![feature(autodiff)]` to the crate attributes to enable + = note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date + +error[E0658]: use of unstable library feature 'autodiff' + --> $DIR/feature-gate-autodiff-use.rs:9:5 + | +LL | use std::autodiff::autodiff; + | ^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: see issue #124509 for more information + = help: add `#![feature(autodiff)]` to the crate attributes to enable + = note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date + +error: aborting due to 2 previous errors + +For more information about this error, try `rustc --explain E0658`. diff --git a/tests/ui/feature-gates/feature-gate-autodiff-use.no_support.stderr b/tests/ui/feature-gates/feature-gate-autodiff-use.no_support.stderr new file mode 100644 index 0000000000000..4b767f824c809 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff-use.no_support.stderr @@ -0,0 +1,29 @@ +error[E0658]: use of unstable library feature 'autodiff' + --> $DIR/feature-gate-autodiff-use.rs:13:3 + | +LL | #[autodiff(dfoo, Reverse)] + | ^^^^^^^^ + | + = note: see issue #124509 for more information + = help: add `#![feature(autodiff)]` to the crate attributes to enable + = note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date + +error: this rustc version does not support autodiff + --> $DIR/feature-gate-autodiff-use.rs:13:1 + | +LL | #[autodiff(dfoo, Reverse)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0658]: use of unstable library feature 'autodiff' + --> $DIR/feature-gate-autodiff-use.rs:9:5 + | +LL | use std::autodiff::autodiff; + | ^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: see issue #124509 for more information + = help: add `#![feature(autodiff)]` to the crate attributes to enable + = note: this compiler was built on YYYY-MM-DD; consider upgrading it if it is out of date + +error: aborting due to 3 previous errors + +For more information about this error, try `rustc --explain E0658`. diff --git a/tests/ui/feature-gates/feature-gate-autodiff-use.rs b/tests/ui/feature-gates/feature-gate-autodiff-use.rs new file mode 100644 index 0000000000000..2276a79d6e2d8 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff-use.rs @@ -0,0 +1,17 @@ +//@ revisions: has_support no_support +//@[no_support] ignore-enzyme +//@[has_support] needs-enzyme + +// This checks that without enabling the autodiff feature, we can't import std::autodiff::autodiff; + +#![crate_type = "lib"] + +use std::autodiff::autodiff; +//[has_support]~^ ERROR use of unstable library feature 'autodiff' +//[no_support]~^^ ERROR use of unstable library feature 'autodiff' + +#[autodiff(dfoo, Reverse)] +//[has_support]~^ ERROR use of unstable library feature 'autodiff' [E0658] +//[no_support]~^^ ERROR use of unstable library feature 'autodiff' [E0658] +//[no_support]~| ERROR this rustc version does not support autodiff +fn foo() {} diff --git a/tests/ui/feature-gates/feature-gate-autodiff.has_support.stderr b/tests/ui/feature-gates/feature-gate-autodiff.has_support.stderr new file mode 100644 index 0000000000000..c25cf7d337371 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff.has_support.stderr @@ -0,0 +1,13 @@ +error: cannot find attribute `autodiff` in this scope + --> $DIR/feature-gate-autodiff.rs:9:3 + | +LL | #[autodiff(dfoo, Reverse)] + | ^^^^^^^^ + | +help: consider importing this attribute macro + | +LL + use std::autodiff::autodiff; + | + +error: aborting due to 1 previous error + diff --git a/tests/ui/feature-gates/feature-gate-autodiff.no_support.stderr b/tests/ui/feature-gates/feature-gate-autodiff.no_support.stderr new file mode 100644 index 0000000000000..c25cf7d337371 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff.no_support.stderr @@ -0,0 +1,13 @@ +error: cannot find attribute `autodiff` in this scope + --> $DIR/feature-gate-autodiff.rs:9:3 + | +LL | #[autodiff(dfoo, Reverse)] + | ^^^^^^^^ + | +help: consider importing this attribute macro + | +LL + use std::autodiff::autodiff; + | + +error: aborting due to 1 previous error + diff --git a/tests/ui/feature-gates/feature-gate-autodiff.rs b/tests/ui/feature-gates/feature-gate-autodiff.rs new file mode 100644 index 0000000000000..4249b229a6985 --- /dev/null +++ b/tests/ui/feature-gates/feature-gate-autodiff.rs @@ -0,0 +1,12 @@ +//@ revisions: has_support no_support +//@[no_support] ignore-enzyme +//@[has_support] needs-enzyme + +#![crate_type = "lib"] + +// This checks that without the autodiff feature enabled, we can't use it. + +#[autodiff(dfoo, Reverse)] +//[has_support]~^ ERROR cannot find attribute `autodiff` in this scope +//[no_support]~^^ ERROR cannot find attribute `autodiff` in this scope +fn foo() {}