Skip to content

Commit

Permalink
Abstract typetrees in separate structure to make indepdent of LLVM
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenz Schmidt committed Feb 22, 2023
1 parent d6cc01b commit f2372c0
Show file tree
Hide file tree
Showing 10 changed files with 283 additions and 45 deletions.
34 changes: 19 additions & 15 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::llvm_util;
use crate::type_::Type;
use crate::LlvmCodegenBackend;
use crate::ModuleLlvm;
use crate::typetree::to_enzyme_typetree;
use llvm::{EnzymeLogicRef, EnzymeTypeAnalysisRef, CreateTypeAnalysis, CreateEnzymeLogic, LLVMSetValueName2, LLVMGetModuleContext, LLVMAddFunction, BasicBlock, LLVMGetElementType, LLVMAppendBasicBlockInContext, LLVMCountParams, LLVMTypeOf, LLVMCreateBuilderInContext, LLVMPositionBuilderAtEnd, LLVMBuildExtractValue, LLVMBuildRet, LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMBuildCall, LLVMGetParams, LLVMDeleteFunction, LLVMCountStructElementTypes, LLVMGetReturnType, enzyme_rust_forward_diff, enzyme_rust_reverse_diff, LLVMVoidTypeInContext};
//use llvm::LLVMRustGetNamedValue;
use rustc_codegen_ssa::back::link::ensure_removed;
Expand All @@ -37,7 +38,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo};
use tracing::debug;

use libc::{c_char, c_int, c_uint, c_void, size_t};
use std::ffi::CString;
use std::ffi::{CString, CStr};
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -579,7 +580,7 @@ unsafe fn create_wrapper<'a>(
outer_params,
inner_params,
c_inner_fnc_name,
)
)
}


Expand Down Expand Up @@ -641,31 +642,36 @@ pub(crate) unsafe fn extract_return_type<'a>(
// As unsafe as it can be.
#[allow(unused_variables)]
#[allow(unused)]
pub(crate) unsafe fn enzyme_ad(llmod: &llvm::Module, llcx: &llvm::Context, item: AutoDiffItem, typetree: DiffTypeTree) -> Result<(), FatalError> {
pub(crate) unsafe fn enzyme_ad(llmod: &llvm::Module, llcx: &llvm::Context, item: AutoDiffItem) -> Result<(), FatalError> {
let autodiff_mode = item.attrs.mode;
let rust_name = item.source;
let rust_name2 = &item.target;

let args_activity = item.attrs.input_activity.clone();
let ret_activity: DiffActivity = item.attrs.ret_activity;

// get target and source function
let name = CString::new(rust_name.to_owned()).unwrap();
let name2 = CString::new(rust_name2.clone()).unwrap();
let src_fnc_tmp = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr());
let target_fnc_tmp = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr());
assert!(src_fnc_tmp.is_some());
assert!(target_fnc_tmp.is_some());
let fnc_todiff = src_fnc_tmp.unwrap();
let target_fnc = target_fnc_tmp.unwrap();
let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap();
let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap();

// create enzyme typetrees
let llvm_data_layout = unsafe{ llvm::LLVMGetDataLayoutStr(&*llmod) };
let llvm_data_layout = std::str::from_utf8(unsafe {CStr::from_ptr(llvm_data_layout)}.to_bytes())
.expect("got a non-UTF8 data-layout from LLVM");

let input_tts = item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect();
let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx);

let opt = 1;
let ret_primary_ret = false;
let diff_primary_ret = false;
let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8);
let type_analysis: EnzymeTypeAnalysisRef = CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0);
let mut res: &Value = match item.attrs.mode {
DiffMode::Forward => enzyme_rust_forward_diff(logic_ref, type_analysis, fnc_todiff, args_activity, ret_activity, ret_primary_ret, typetree),
DiffMode::Reverse => enzyme_rust_reverse_diff(logic_ref, type_analysis, fnc_todiff, args_activity, ret_activity, ret_primary_ret, diff_primary_ret, typetree),
DiffMode::Forward => enzyme_rust_forward_diff(logic_ref, type_analysis, src_fnc, args_activity, ret_activity, ret_primary_ret, input_tts, output_tt),
DiffMode::Reverse => enzyme_rust_reverse_diff(logic_ref, type_analysis, src_fnc, args_activity, ret_activity, ret_primary_ret, diff_primary_ret, input_tts, output_tt),
_ => unreachable!(),
};
let f_type = LLVMTypeOf(res);
Expand Down Expand Up @@ -698,7 +704,7 @@ pub(crate) unsafe fn differentiate(
module: &ModuleCodegen<ModuleLlvm>,
_cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
typetrees: FxHashMap<String, DiffTypeTree>,
_typetrees: FxHashMap<String, DiffTypeTree>,
config: &ModuleConfig,
) -> Result<(), FatalError> {

Expand All @@ -712,9 +718,7 @@ pub(crate) unsafe fn differentiate(
}

for item in diff_items {
let tt = typetrees.get(&item.source).unwrap().clone();

let res = enzyme_ad(llmod, llcx, item, tt);
let res = enzyme_ad(llmod, llcx, item);
assert!(res.is_ok());
}

Expand Down
11 changes: 7 additions & 4 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mod coverageinfo;
mod debuginfo;
mod declare;
mod intrinsic;
mod typetree;

// The following is a work around that replaces `pub mod llvm;` and that fixes issue 53912.
#[path = "llvm/mod.rs"]
Expand Down Expand Up @@ -448,7 +449,7 @@ pub fn get_enzyme_typetree<'tcx>(id: Ty<'tcx>, llvm_data_layout: &str, tcx: TyCt
let inner_id = id.builtin_deref(true).unwrap().ty;
let inner_tt = get_enzyme_typetree(inner_id, llvm_data_layout, tcx, llcx, depth+1);

tt.merge(inner_tt).only(-1)
let tt = tt.merge(inner_tt).only(-1);
println!("{:depth$} add indirection {}", "", tt);

return tt;
Expand Down Expand Up @@ -497,14 +498,14 @@ pub fn get_enzyme_typetree<'tcx>(id: Ty<'tcx>, llvm_data_layout: &str, tcx: TyCt
println!("{:depth$} -> {}", "", inner_tt);
field_tt.push(inner_tt);

if field_ty.is_adt() {
if field_ty.is_integral() {
field_sizes.push(1);
} else {
let param_env_and = ParamEnvAnd {
param_env: ParamEnv::empty(),
value: field_ty,
};
field_sizes.push(tcx.layout_of(param_env_and).unwrap().size.bytes());
} else {
field_sizes.push(1);
}
}
//dbg!(offsets);
Expand All @@ -520,6 +521,8 @@ pub fn get_enzyme_typetree<'tcx>(id: Ty<'tcx>, llvm_data_layout: &str, tcx: TyCt
let tt = tt.clone();
//let tt = tt.only(offset.bytes_usize() as isize);
let tt = tt.shift(llvm_data_layout, 0, size, offset.bytes_usize() as usize);
dbg!(&offset, &size);
dbg!(&tt);

ret_tt = ret_tt.merge(tt);
}
Expand Down
17 changes: 9 additions & 8 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use std::ffi::{CStr, CString};
use std::marker::PhantomData;

use super::RustString;
use crate::DiffTypeTree;

pub type Bool = c_uint;

Expand Down Expand Up @@ -1010,7 +1009,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
input_diffactivity: Vec<DiffActivity>,
ret_diffactivity: DiffActivity,
mut ret_primary_ret: bool,
typetree: DiffTypeTree,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
) -> &Value{

let ret_activity = cdiffe_from(ret_diffactivity);
Expand All @@ -1034,14 +1034,14 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
ret_primary_ret = false;
}

let mut args_tree = typetree.input_tt.into_iter()
let mut args_tree = input_tts.iter()
.map(|x| x.inner).collect::<Vec<_>>();
//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_activity.len()];
let ret = typetree.ret_tt;
let ret = output_tt.clone();

let kv_tmp = IntList {
data: std::ptr::null_mut(),
Expand Down Expand Up @@ -1083,10 +1083,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
ret_activity: DiffActivity,
mut ret_primary_ret: bool,
diff_primary_ret: bool,
typetree: DiffTypeTree,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
) -> &Value{

dbg!(&typetree);
dbg!(&input_tts);
let ret_activity = cdiffe_from(ret_activity);
assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF);
let input_activity: Vec<CDIFFE_TYPE> = input_activity.iter().map(|&x| cdiffe_from(x)).collect();
Expand All @@ -1103,15 +1104,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
ret_primary_ret = false;
}

let mut args_tree = typetree.input_tt.iter()
let mut args_tree = input_tts.iter()
.map(|x| x.inner).collect::<Vec<_>>();

//let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()];

// We don't support volatile / extern / (global?) values.
// Just because I didn't had time to test them, and it seems less urgent.
let args_uncacheable = vec![0; input_activity.len()];
let ret = typetree.ret_tt;
let ret = output_tt.clone();
let kv_tmp = IntList {
data: std::ptr::null_mut(),
size: 0,
Expand Down
22 changes: 22 additions & 0 deletions compiler/rustc_codegen_llvm/src/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use rustc_middle::middle::typetree::{TypeTree, Kind};
use crate::llvm;

pub fn to_enzyme_typetree(tree: TypeTree, llvm_data_layout: &str, llcx: &llvm::Context) -> llvm::TypeTree {
tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| {
let inner_tt = if x.child.0.is_empty() {
let scalar = match x.kind {
Kind::Integer => llvm::CConcreteType::DT_Integer,
Kind::Float => llvm::CConcreteType::DT_Float,
Kind::Double => llvm::CConcreteType::DT_Double,
Kind::Pointer => llvm::CConcreteType::DT_Pointer,
_ => unreachable!(),
};

llvm::TypeTree::from_type(scalar, llcx)
} else {
to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx)
};

obj.merge(inner_tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize))
})
}
7 changes: 6 additions & 1 deletion compiler/rustc_middle/src/middle/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::str::FromStr;
use crate::middle::typetree::TypeTree;

#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)]
Expand Down Expand Up @@ -73,10 +74,12 @@ impl AutoDiffAttrs {
}
}

pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
pub fn into_item(self, source: String, target: String, inputs: Vec<TypeTree>, output: TypeTree) -> AutoDiffItem {
AutoDiffItem {
source,
target,
inputs,
output,
attrs: self,
}
}
Expand All @@ -87,4 +90,6 @@ pub struct AutoDiffItem {
pub source: String,
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/middle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub mod privacy;
pub mod region;
pub mod resolve_lifetime;
pub mod stability;
pub mod typetree;

pub fn provide(providers: &mut crate::ty::query::Providers) {
limits::provide(providers);
Expand Down
Loading

0 comments on commit f2372c0

Please sign in to comment.