Skip to content

Commit 0f05703

Browse files
committed
model offload C++ structs through Rust structs
1 parent f70c6f4 commit 0f05703

File tree

1 file changed

+96
-75
lines changed

1 file changed

+96
-75
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 96 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>(
1818
// The offload memory transfer type for each kernel
1919
let mut memtransfer_types = vec![];
2020
let mut region_ids = vec![];
21-
let offload_entry_ty = add_tgt_offload_entry(&cx);
21+
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
2222
for num in 0..9 {
2323
let kernel = cx.get_function(&format!("kernel_{num}"));
2424
if let Some(kernel) = kernel {
@@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
5252
// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
5353
// offloaded was defined.
5454
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
55-
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
5655
let unknown_txt = ";unknown;unknown;0;0;;";
5756
let c_entry_name = CString::new(unknown_txt).unwrap();
5857
let c_val = c_entry_name.as_bytes_with_nul();
@@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7776
at_one
7877
}
7978

80-
pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
81-
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
82-
let tptr = cx.type_ptr();
83-
let ti64 = cx.type_i64();
84-
let ti32 = cx.type_i32();
85-
let ti16 = cx.type_i16();
86-
// For each kernel to run on the gpu, we will later generate one entry of this type.
87-
// copied from LLVM
88-
// typedef struct {
79+
struct TgtOffloadEntry {
8980
// uint64_t Reserved;
9081
// uint16_t Version;
9182
// uint16_t Kind;
@@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
9586
// uint64_t Size; Size of the entry info (0 if it is a function)
9687
// uint64_t Data;
9788
// void *AuxAddr;
98-
// } __tgt_offload_entry;
99-
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
100-
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
101-
offload_entry_ty
10289
}
10390

104-
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
105-
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
106-
let tptr = cx.type_ptr();
107-
let ti64 = cx.type_i64();
108-
let ti32 = cx.type_i32();
109-
let tarr = cx.type_array(ti32, 3);
91+
impl TgtOffloadEntry {
92+
pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
93+
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
94+
let tptr = cx.type_ptr();
95+
let ti64 = cx.type_i64();
96+
let ti32 = cx.type_i32();
97+
let ti16 = cx.type_i16();
98+
// For each kernel to run on the gpu, we will later generate one entry of this type.
99+
// The field layout below is copied from LLVM's `__tgt_offload_entry` struct.
100+
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
101+
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
102+
offload_entry_ty
103+
}
104+
105+
fn new<'ll>(
106+
cx: &'ll SimpleCx<'_>,
107+
region_id: &'ll Value,
108+
llglobal: &'ll Value,
109+
) -> Vec<&'ll Value> {
110+
let reserved = cx.get_const_i64(0);
111+
let version = cx.get_const_i16(1);
112+
let kind = cx.get_const_i16(1);
113+
let flags = cx.get_const_i32(0);
114+
let size = cx.get_const_i64(0);
115+
let data = cx.get_const_i64(0);
116+
let aux_addr = cx.const_null(cx.type_ptr());
117+
vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
118+
}
119+
}
110120

111-
// Taken from the LLVM APITypes.h declaration:
112-
//struct KernelArgsTy {
121+
// Taken from the LLVM APITypes.h declaration:
122+
struct KernelArgsTy {
113123
// uint32_t Version = 0; // Version of this struct for ABI compatibility.
114124
// uint32_t NumArgs = 0; // Number of arguments in each input pointer.
115125
// void **ArgBasePtrs =
@@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
120130
// void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
121131
// void **ArgMappers = nullptr; // User-defined mappers, possibly null.
122132
// uint64_t Tripcount =
123-
// 0; // Tripcount for the teams / distribute loop, 0 otherwise.
124-
// struct {
133+
// 0; // Tripcount for the teams / distribute loop, 0 otherwise.
134+
// struct {
125135
// uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
126136
// uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
127137
// uint64_t Unused : 62;
@@ -131,12 +141,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
131141
// // The number of threads (for x,y,z dimension).
132142
// uint32_t ThreadLimit[3] = {0, 0, 0};
133143
// uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
134-
//};
135-
let kernel_elements =
136-
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
144+
}
145+
146+
impl KernelArgsTy {
147+
const OFFLOAD_VERSION: u64 = 3;
148+
const FLAGS: u64 = 0;
149+
const TRIPCOUNT: u64 = 0;
150+
fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
151+
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
152+
let tptr = cx.type_ptr();
153+
let ti64 = cx.type_i64();
154+
let ti32 = cx.type_i32();
155+
let tarr = cx.type_array(ti32, 3);
156+
157+
let kernel_elements =
158+
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
159+
160+
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
161+
kernel_arguments_ty
162+
}
137163

138-
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
139-
kernel_arguments_ty
164+
fn new<'ll>(
165+
cx: &'ll SimpleCx<'_>,
166+
num_args: u64,
167+
memtransfer_types: &[&'ll Value],
168+
geps: [&'ll Value; 3],
169+
) -> [(Align, &'ll Value); 13] {
170+
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
171+
let eight = Align::EIGHT;
172+
let mut values = vec![];
173+
values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)));
174+
values.push((four, cx.get_const_i32(num_args)));
175+
values.push((eight, geps[0]));
176+
values.push((eight, geps[1]));
177+
values.push((eight, geps[2]));
178+
values.push((eight, memtransfer_types[0]));
179+
// The next two are ArgNames (debug info) and ArgMappers (user-defined mappers). FIXME(offload): set them
180+
values.push((eight, cx.const_null(cx.type_ptr())));
181+
values.push((eight, cx.const_null(cx.type_ptr())));
182+
values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)));
183+
values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS)));
184+
let ti32 = cx.type_i32();
185+
let ci32_0 = cx.get_const_i32(0);
186+
values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
187+
values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
188+
values.push((four, cx.get_const_i32(0)));
189+
values.try_into().expect("tgt_kernel_arguments construction failed")
190+
}
140191
}
141192

142193
fn gen_tgt_data_mappers<'ll>(
@@ -242,19 +293,10 @@ fn gen_define_handling<'ll>(
242293
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
243294
llvm::set_alignment(llglobal, Align::ONE);
244295
llvm::set_section(llglobal, c".llvm.rodata.offloading");
245-
246-
// Not actively used yet, for calling real kernels
247296
let name = format!(".offloading.entry.kernel_{num}");
248297

249298
// See the __tgt_offload_entry documentation above.
250-
let reserved = cx.get_const_i64(0);
251-
let version = cx.get_const_i16(1);
252-
let kind = cx.get_const_i16(1);
253-
let flags = cx.get_const_i32(0);
254-
let size = cx.get_const_i64(0);
255-
let data = cx.get_const_i64(0);
256-
let aux_addr = cx.const_null(cx.type_ptr());
257-
let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
299+
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
258300

259301
let initializer = crate::common::named_struct(offload_entry_ty, &elems);
260302
let c_name = CString::new(name).unwrap();
@@ -316,7 +358,7 @@ fn gen_call_handling<'ll>(
316358
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
317359
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
318360

319-
let tgt_kernel_decl = gen_tgt_kernel_global(&cx);
361+
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
320362
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
321363

322364
let main_fn = cx.get_function("main");
@@ -404,19 +446,19 @@ fn gen_call_handling<'ll>(
404446
a1: &'ll Value,
405447
a2: &'ll Value,
406448
a4: &'ll Value,
407-
) -> (&'ll Value, &'ll Value, &'ll Value) {
449+
) -> [&'ll Value; 3] {
408450
let i32_0 = cx.get_const_i32(0);
409451

410452
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
411453
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
412454
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
413-
(gep1, gep2, gep3)
455+
[gep1, gep2, gep3]
414456
}
415457

416458
fn generate_mapper_call<'a, 'll>(
417459
builder: &mut SBuilder<'a, 'll>,
418460
cx: &'ll SimpleCx<'ll>,
419-
geps: (&'ll Value, &'ll Value, &'ll Value),
461+
geps: [&'ll Value; 3],
420462
o_type: &'ll Value,
421463
fn_to_call: &'ll Value,
422464
fn_ty: &'ll Type,
@@ -427,7 +469,7 @@ fn gen_call_handling<'ll>(
427469
let i64_max = cx.get_const_i64(u64::MAX);
428470
let num_args = cx.get_const_i32(num_args);
429471
let args =
430-
vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
472+
vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
431473
builder.call(fn_ty, fn_to_call, &args, None);
432474
}
433475

@@ -436,36 +478,20 @@ fn gen_call_handling<'ll>(
436478
let o = memtransfer_types[0];
437479
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
438480
generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
481+
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
439482

440483
// Step 3)
441-
let mut values = vec![];
442-
let offload_version = cx.get_const_i32(3);
443-
values.push((4, offload_version));
444-
values.push((4, cx.get_const_i32(num_args)));
445-
values.push((8, geps.0));
446-
values.push((8, geps.1));
447-
values.push((8, geps.2));
448-
values.push((8, memtransfer_types[0]));
449-
// The next two are debug infos. FIXME(offload) set them
450-
values.push((8, cx.const_null(cx.type_ptr())));
451-
values.push((8, cx.const_null(cx.type_ptr())));
452-
values.push((8, cx.get_const_i64(0)));
453-
values.push((8, cx.get_const_i64(0)));
454-
let ti32 = cx.type_i32();
455-
let ci32_0 = cx.get_const_i32(0);
456-
values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
457-
values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
458-
values.push((4, cx.get_const_i32(0)));
459-
484+
// Here we fill the KernelArgsTy, see the documentation above
460485
for (i, value) in values.iter().enumerate() {
461486
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
462-
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
487+
builder.store(value.1, ptr, value.0);
463488
}
464489

465490
let args = vec![
466491
s_ident_t,
467-
// MAX == -1
468-
cx.get_const_i64(u64::MAX),
492+
// FIXME(offload): Give users a way to select which GPU to use.
493+
cx.get_const_i64(u64::MAX), // MAX == -1.
494+
// FIXME(offload): Don't hardcode the numbers of threads in the future.
469495
cx.get_const_i32(2097152),
470496
cx.get_const_i32(256),
471497
region_ids[0],
@@ -480,19 +506,14 @@ fn gen_call_handling<'ll>(
480506
}
481507

482508
// Step 4)
483-
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
484-
485509
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
486510
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
487511

488512
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
489513

490514
drop(builder);
515+
// FIXME(offload): The issue is that we currently add a call to the GPU version of the function,
516+
// and then delete the call to the CPU version. In the future, we should use an intrinsic which
517+
// directly resolves to a call to the GPU version.
491518
unsafe { llvm::LLVMDeleteFunction(called) };
492-
493-
// With this we generated the following begin and end mappers. We could easily generate the
494-
// update mapper in an update.
495-
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
496-
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
497-
// call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
498519
}

0 commit comments

Comments
 (0)