Skip to content

Commit 5bb815a

Browse files
committed
model offload C++ structs through Rust structs
1 parent b56d555 commit 5bb815a

File tree

1 file changed

+97
-75
lines changed

1 file changed

+97
-75
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 97 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub(crate) fn handle_gpu_code<'ll>(
1818
// The offload memory transfer type for each kernel
1919
let mut memtransfer_types = vec![];
2020
let mut region_ids = vec![];
21-
let offload_entry_ty = add_tgt_offload_entry(&cx);
21+
let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
2222
for num in 0..9 {
2323
let kernel = cx.get_function(&format!("kernel_{num}"));
2424
if let Some(kernel) = kernel {
@@ -52,7 +52,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
5252
// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
5353
// offloaded was defined.
5454
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
55-
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
5655
let unknown_txt = ";unknown;unknown;0;0;;";
5756
let c_entry_name = CString::new(unknown_txt).unwrap();
5857
let c_val = c_entry_name.as_bytes_with_nul();
@@ -77,15 +76,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
7776
at_one
7877
}
7978

80-
pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
81-
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
82-
let tptr = cx.type_ptr();
83-
let ti64 = cx.type_i64();
84-
let ti32 = cx.type_i32();
85-
let ti16 = cx.type_i16();
86-
// For each kernel to run on the gpu, we will later generate one entry of this type.
87-
// copied from LLVM
88-
// typedef struct {
79+
struct TgtOffloadEntry {
8980
// uint64_t Reserved;
9081
// uint16_t Version;
9182
// uint16_t Kind;
@@ -95,21 +86,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
9586
// uint64_t Size; Size of the entry info (0 if it is a function)
9687
// uint64_t Data;
9788
// void *AuxAddr;
98-
// } __tgt_offload_entry;
99-
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
100-
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
101-
offload_entry_ty
10289
}
10390

104-
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
105-
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
106-
let tptr = cx.type_ptr();
107-
let ti64 = cx.type_i64();
108-
let ti32 = cx.type_i32();
109-
let tarr = cx.type_array(ti32, 3);
91+
impl TgtOffloadEntry {
92+
pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
93+
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
94+
let tptr = cx.type_ptr();
95+
let ti64 = cx.type_i64();
96+
let ti32 = cx.type_i32();
97+
let ti16 = cx.type_i16();
98+
// For each kernel to run on the gpu, we will later generate one entry of this type.
99+
// copied from LLVM
100+
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
101+
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
102+
offload_entry_ty
103+
}
104+
105+
fn new<'ll>(
106+
cx: &'ll SimpleCx<'_>,
107+
region_id: &'ll Value,
108+
llglobal: &'ll Value,
109+
) -> [&'ll Value; 9] {
110+
let reserved = cx.get_const_i64(0);
111+
let version = cx.get_const_i16(1);
112+
let kind = cx.get_const_i16(1);
113+
let flags = cx.get_const_i32(0);
114+
let size = cx.get_const_i64(0);
115+
let data = cx.get_const_i64(0);
116+
let aux_addr = cx.const_null(cx.type_ptr());
117+
[reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
118+
}
119+
}
110120

111-
// Taken from the LLVM APITypes.h declaration:
112-
//struct KernelArgsTy {
121+
// Taken from the LLVM APITypes.h declaration:
122+
struct KernelArgsTy {
113123
// uint32_t Version = 0; // Version of this struct for ABI compatibility.
114124
// uint32_t NumArgs = 0; // Number of arguments in each input pointer.
115125
// void **ArgBasePtrs =
@@ -120,8 +130,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
120130
// void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
121131
// void **ArgMappers = nullptr; // User-defined mappers, possibly null.
122132
// uint64_t Tripcount =
123-
// 0; // Tripcount for the teams / distribute loop, 0 otherwise.
124-
// struct {
133+
// 0; // Tripcount for the teams / distribute loop, 0 otherwise.
134+
// struct {
125135
// uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
126136
// uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
127137
// uint64_t Unused : 62;
@@ -131,12 +141,54 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
131141
// // The number of threads (for x,y,z dimension).
132142
// uint32_t ThreadLimit[3] = {0, 0, 0};
133143
// uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
134-
//};
135-
let kernel_elements =
136-
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
144+
}
145+
146+
impl KernelArgsTy {
147+
const OFFLOAD_VERSION: u64 = 3;
148+
const FLAGS: u64 = 0;
149+
const TRIPCOUNT: u64 = 0;
150+
fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
151+
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
152+
let tptr = cx.type_ptr();
153+
let ti64 = cx.type_i64();
154+
let ti32 = cx.type_i32();
155+
let tarr = cx.type_array(ti32, 3);
156+
157+
let kernel_elements =
158+
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
159+
160+
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
161+
kernel_arguments_ty
162+
}
137163

138-
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
139-
kernel_arguments_ty
164+
fn new<'ll>(
165+
cx: &'ll SimpleCx<'_>,
166+
num_args: u64,
167+
memtransfer_types: &[&'ll Value],
168+
geps: [&'ll Value; 3],
169+
) -> [(Align, &'ll Value); 13] {
170+
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
171+
let eight = Align::EIGHT;
172+
173+
let ti32 = cx.type_i32();
174+
let ci32_0 = cx.get_const_i32(0);
175+
[
176+
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
177+
(four, cx.get_const_i32(num_args)),
178+
(eight, geps[0]),
179+
(eight, geps[1]),
180+
(eight, geps[2]),
181+
(eight, memtransfer_types[0]),
182+
// The next two are debug info. FIXME(offload): set them
183+
(eight, cx.const_null(cx.type_ptr())), // dbg
184+
(eight, cx.const_null(cx.type_ptr())), // dbg
185+
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
186+
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
187+
(four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
188+
(four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
189+
(four, cx.get_const_i32(0)),
190+
]
191+
}
140192
}
141193

142194
fn gen_tgt_data_mappers<'ll>(
@@ -245,19 +297,10 @@ fn gen_define_handling<'ll>(
245297
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
246298
llvm::set_alignment(llglobal, Align::ONE);
247299
llvm::set_section(llglobal, c".llvm.rodata.offloading");
248-
249-
// Not actively used yet, for calling real kernels
250300
let name = format!(".offloading.entry.kernel_{num}");
251301

252302
// See the __tgt_offload_entry documentation above.
253-
let reserved = cx.get_const_i64(0);
254-
let version = cx.get_const_i16(1);
255-
let kind = cx.get_const_i16(1);
256-
let flags = cx.get_const_i32(0);
257-
let size = cx.get_const_i64(0);
258-
let data = cx.get_const_i64(0);
259-
let aux_addr = cx.const_null(cx.type_ptr());
260-
let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
303+
let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
261304

262305
let initializer = crate::common::named_struct(offload_entry_ty, &elems);
263306
let c_name = CString::new(name).unwrap();
@@ -319,7 +362,7 @@ fn gen_call_handling<'ll>(
319362
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
320363
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
321364

322-
let tgt_kernel_decl = gen_tgt_kernel_global(&cx);
365+
let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
323366
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
324367

325368
let main_fn = cx.get_function("main");
@@ -407,19 +450,19 @@ fn gen_call_handling<'ll>(
407450
a1: &'ll Value,
408451
a2: &'ll Value,
409452
a4: &'ll Value,
410-
) -> (&'ll Value, &'ll Value, &'ll Value) {
453+
) -> [&'ll Value; 3] {
411454
let i32_0 = cx.get_const_i32(0);
412455

413456
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
414457
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
415458
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
416-
(gep1, gep2, gep3)
459+
[gep1, gep2, gep3]
417460
}
418461

419462
fn generate_mapper_call<'a, 'll>(
420463
builder: &mut SBuilder<'a, 'll>,
421464
cx: &'ll SimpleCx<'ll>,
422-
geps: (&'ll Value, &'ll Value, &'ll Value),
465+
geps: [&'ll Value; 3],
423466
o_type: &'ll Value,
424467
fn_to_call: &'ll Value,
425468
fn_ty: &'ll Type,
@@ -430,7 +473,7 @@ fn gen_call_handling<'ll>(
430473
let i64_max = cx.get_const_i64(u64::MAX);
431474
let num_args = cx.get_const_i32(num_args);
432475
let args =
433-
vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
476+
vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
434477
builder.call(fn_ty, fn_to_call, &args, None);
435478
}
436479

@@ -439,36 +482,20 @@ fn gen_call_handling<'ll>(
439482
let o = memtransfer_types[0];
440483
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
441484
generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
485+
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
442486

443487
// Step 3)
444-
let mut values = vec![];
445-
let offload_version = cx.get_const_i32(3);
446-
values.push((4, offload_version));
447-
values.push((4, cx.get_const_i32(num_args)));
448-
values.push((8, geps.0));
449-
values.push((8, geps.1));
450-
values.push((8, geps.2));
451-
values.push((8, memtransfer_types[0]));
452-
// The next two are debug infos. FIXME(offload) set them
453-
values.push((8, cx.const_null(cx.type_ptr())));
454-
values.push((8, cx.const_null(cx.type_ptr())));
455-
values.push((8, cx.get_const_i64(0)));
456-
values.push((8, cx.get_const_i64(0)));
457-
let ti32 = cx.type_i32();
458-
let ci32_0 = cx.get_const_i32(0);
459-
values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
460-
values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
461-
values.push((4, cx.get_const_i32(0)));
462-
488+
// Here we fill the KernelArgsTy, see the documentation above
463489
for (i, value) in values.iter().enumerate() {
464490
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
465-
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
491+
builder.store(value.1, ptr, value.0);
466492
}
467493

468494
let args = vec![
469495
s_ident_t,
470-
// MAX == -1
471-
cx.get_const_i64(u64::MAX),
496+
// FIXME(offload): give users a way to select which GPU to use.
497+
cx.get_const_i64(u64::MAX), // MAX == -1.
498+
// FIXME(offload): Don't hardcode the numbers of threads in the future.
472499
cx.get_const_i32(2097152),
473500
cx.get_const_i32(256),
474501
region_ids[0],
@@ -483,19 +510,14 @@ fn gen_call_handling<'ll>(
483510
}
484511

485512
// Step 4)
486-
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
487-
488513
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
489514
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
490515

491516
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
492517

493518
drop(builder);
519+
// FIXME(offload): The issue is that right now we add a call to the GPU version of the function,
520+
// and then delete the call to the CPU version. In the future, we should use an intrinsic which
521+
// directly resolves to a call to the GPU version.
494522
unsafe { llvm::LLVMDeleteFunction(called) };
495-
496-
// With this we generated the following begin and end mappers. We could easily generate the
497-
// update mapper in an update.
498-
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
499-
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
500-
// call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
501523
}

0 commit comments

Comments
 (0)