diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 84fc6ebbc3172..029c43e0ba82e 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -15,6 +15,7 @@ use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; +use rustc_session::config::CrateType; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::callconv::PassMode; @@ -1136,8 +1137,17 @@ fn codegen_autodiff<'ll, 'tcx>( if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable); } - if tcx.sess.lto() != rustc_session::config::Lto::Fat { - let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutLto); + + let ct = tcx.crate_types(); + let lto = tcx.sess.lto(); + if ct.len() == 1 && ct.contains(&CrateType::Executable) { + if lto != rustc_session::config::Lto::Fat { + let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutLto); + } + } else { + if lto != rustc_session::config::Lto::Fat && !tcx.sess.opts.cg.linker_plugin_lto.enabled() { + let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutLto); + } } let fn_args = instance.args; diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 7fc9fb9cca2d7..69248cf91f241 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -34,6 +34,14 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { return true; } + // FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 + if tcx.has_attr(def_id, sym::autodiff_forward) + || tcx.has_attr(def_id, sym::autodiff_reverse) + || tcx.has_attr(def_id, sym::rustc_autodiff) + { + return true; + } + if tcx.has_attr(def_id, sym::rustc_intrinsic) { // Intrinsic fallback bodies are always cross-crate inlineable. // To ensure that the MIR inliner doesn't cluelessly try to inline fallback diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs index 13868cca944a2..e3646596e75e6 100644 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -7,6 +7,8 @@ use crate::collector::{MonoItems, create_fn_mono_item}; // mono so this does not interfere in `autodiff` intrinsics // codegen process. If they are unused, LLVM will remove them when // compiling with O3. +// FIXME(autodiff): Remove this whole file, as per discussion in +// https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 pub(crate) fn collect_autodiff_fn<'tcx>( tcx: TyCtxt<'tcx>, instance: ty::Instance<'tcx>, diff --git a/tests/run-make/autodiff/rlib/dep.rs b/tests/run-make/autodiff/rlib/dep.rs new file mode 100644 index 0000000000000..6fc9b2c2473f4 --- /dev/null +++ b/tests/run-make/autodiff/rlib/dep.rs @@ -0,0 +1,7 @@ +pub fn f(x: f64, y: f64) -> f64 { + 2.0 * x + y +} + +pub fn g(x: f64) -> f64 { + 2.0 * x +} diff --git a/tests/run-make/autodiff/rlib/lib.rs b/tests/run-make/autodiff/rlib/lib.rs new file mode 100644 index 0000000000000..b459fed643a1c --- /dev/null +++ b/tests/run-make/autodiff/rlib/lib.rs @@ -0,0 +1,13 @@ +#![feature(autodiff)] +extern crate simple_dep; +use std::autodiff::*; + +#[inline(never)] +pub fn f2(x: f64) -> f64 { + x.sin() +} + +#[autodiff_forward(df1_lib, Dual, Dual)] +pub fn _f1(x: f64) -> f64 { + simple_dep::f(x, x) * f2(x) +} diff --git a/tests/run-make/autodiff/rlib/main.rs b/tests/run-make/autodiff/rlib/main.rs new file mode 100644 index 0000000000000..a3e5fcde0381b --- /dev/null +++ b/tests/run-make/autodiff/rlib/main.rs @@ -0,0 +1,8 @@ +extern crate foo; + +fn main() { + //dbg!("Running main.rs"); + let enzyme_y1_lib = foo::df1_lib(1.5, 1.0); + println!("output1: {:?}", enzyme_y1_lib.0); + println!("output2: {:?}", enzyme_y1_lib.1); +} diff --git a/tests/run-make/autodiff/rlib/rmake.rs b/tests/run-make/autodiff/rlib/rmake.rs new file mode 100644 index 0000000000000..59eaa836864c7 --- /dev/null +++ b/tests/run-make/autodiff/rlib/rmake.rs @@ -0,0 +1,66 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use run_make_support::{cwd, run, rustc}; + +fn main() { + // Build the dependency crate. + rustc() + .input("dep.rs") + .arg("-Zautodiff=Enable") + .arg("--edition=2024") + .arg("-Copt-level=3") + .arg("--crate-name=simple_dep") + .arg("-Clinker-plugin-lto") + .arg("--crate-type=lib") + .emit("dep-info,metadata,link") + .run(); + + let cwd = cwd(); + let cwd_str = cwd.to_string_lossy(); + + let mydep = format!("-Ldependency={cwd_str}"); + + let simple_dep_rlib = + format!("--extern=simple_dep={}", cwd.join("libsimple_dep.rlib").to_string_lossy()); + + // Build the main library that depends on `simple_dep`. + rustc() + .input("lib.rs") + .arg("-Zautodiff=Enable") + .arg("--edition=2024") + .arg("-Copt-level=3") + .arg("--crate-name=foo") + .arg("-Clinker-plugin-lto") + .arg("--crate-type=lib") + .emit("dep-info,metadata,link") + .arg(&mydep) + .arg(&simple_dep_rlib) + .run(); + + let foo_rlib = format!("--extern=foo={}", cwd.join("libfoo.rlib").to_string_lossy()); + + // Build the final binary linking both rlibs. + rustc() + .input("main.rs") + .arg("-Zautodiff=Enable") + .arg("--edition=2024") + .arg("-Copt-level=3") + .arg("--crate-name=foo") + .arg("-Clto=fat") + .arg("--crate-type=bin") + .emit("dep-info,link") + .arg(&mydep) + .arg(&foo_rlib) + .arg(&simple_dep_rlib) + .run(); + + // Run the binary and check its output. + let binary = run("foo"); + assert!(binary.status().success(), "binary failed to run"); + + let binary_out = binary.stdout(); + let output = String::from_utf8_lossy(&binary_out); + assert!(output.contains("output1: 4.488727439718245")); + assert!(output.contains("output2: 3.3108023673168265")); +}