Skip to content

Commit 685adca

Browse files
committed
Add auto-bitcasts from/to x86amx and i32x256 for AMX intrinsics
1 parent 2c12b4a commit 685adca

File tree

15 files changed

+281
-36
lines changed

15 files changed

+281
-36
lines changed

compiler/rustc_codegen_gcc/src/type_of.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Write;
22

3-
use gccjit::{Struct, Type};
3+
use gccjit::{RValue, Struct, Type};
44
use rustc_abi as abi;
55
use rustc_abi::Primitive::*;
66
use rustc_abi::{
@@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
373373
unimplemented!();
374374
}
375375

376-
fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
376+
fn fn_decl_backend_type(
377+
&self,
378+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
379+
_fn_ptr: RValue<'gcc>,
380+
) -> Type<'gcc> {
377381
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
378382
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
379383
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::cmp;
44
use libc::c_uint;
55
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
66
use rustc_codegen_ssa::MemFlags;
7+
use rustc_codegen_ssa::common::TypeKind;
78
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
89
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
910
use rustc_codegen_ssa::traits::*;
@@ -308,7 +309,9 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309
}
309310

310311
pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
311-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
312+
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
313+
fn llvm_param_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type>;
314+
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type;
312315
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
313316
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;
314317

@@ -325,26 +328,29 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325328
}
326329

327330
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
328-
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
331+
/// Returns the LLVM type of this ABI's return value, selected by `self.ret.mode`.
///
/// Ignored and indirect returns both yield `cx.type_void()`; direct and pair returns use
/// the layout's immediate LLVM type; cast returns use the cast target's LLVM type.
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
332+
match &self.ret.mode {
333+
PassMode::Ignore => cx.type_void(),
334+
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
335+
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
336+
// An indirect return has no LLVM-level return value; `llvm_param_types` pushes a
// leading `cx.type_ptr()` parameter for this mode instead.
PassMode::Indirect { .. } => cx.type_void(),
337+
}
338+
}
339+
340+
fn llvm_param_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type> {
341+
let indirect_return = matches!(self.ret.mode, PassMode::Indirect { .. });
342+
329343
// Ignore "extra" args from the call site for C variadic functions.
330344
// Only the "fixed" args are part of the LLVM function signature.
331345
let args =
332346
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };
333347

334-
// This capacity calculation is approximate.
335-
let mut llargument_tys = Vec::with_capacity(
336-
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
337-
);
348+
let mut llargument_tys =
349+
Vec::with_capacity(args.len() + if indirect_return { 1 } else { 0 });
338350

339-
let llreturn_ty = match &self.ret.mode {
340-
PassMode::Ignore => cx.type_void(),
341-
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
342-
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
343-
PassMode::Indirect { .. } => {
344-
llargument_tys.push(cx.type_ptr());
345-
cx.type_void()
346-
}
347-
};
351+
if indirect_return {
352+
llargument_tys.push(cx.type_ptr());
353+
}
348354

349355
for arg in args {
350356
// Note that the exact number of arguments pushed here is carefully synchronized with
@@ -391,6 +397,36 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
391397
llargument_tys.push(llarg_ty);
392398
}
393399

400+
llargument_tys
401+
}
402+
403+
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type {
404+
let is_llvm_intrinsic = llvm::Intrinsic::try_from(name).is_ok();
405+
406+
// todo(sayantn): find a way to get the actual LLVM signature of intrinsics
407+
408+
let amx_intrinsic =
409+
is_llvm_intrinsic && name.starts_with(b"llvm.x86.") && name.ends_with(b".internal");
410+
let adjust_ty = |ty| {
411+
// Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
412+
// todo(sayantn): this should work with all 1024-byte vectors, not just `i32x256`
413+
if amx_intrinsic && cx.type_kind(ty) == TypeKind::Vector && cx.vector_length(ty) == 256
414+
{
415+
let element_ty = cx.element_type(ty);
416+
if cx.type_kind(element_ty) == TypeKind::Integer && cx.int_width(element_ty) == 32 {
417+
return cx.type_x86amx();
418+
}
419+
}
420+
ty
421+
};
422+
423+
let llreturn_ty = adjust_ty(self.llvm_return_type(cx));
424+
425+
let mut llargument_tys = self.llvm_param_types(cx);
426+
for llargument_ty in &mut llargument_tys {
427+
*llargument_ty = adjust_ty(&llargument_ty);
428+
}
429+
394430
if self.c_variadic {
395431
cx.type_variadic_func(&llargument_tys, llreturn_ty)
396432
} else {

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 123 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
6767
) -> &'ll Value {
6868
debug!("call {:?} with args ({:?})", llfn, args);
6969

70-
let args = self.check_call("call", llty, llfn, args);
70+
let args = self.cast_arguments("call", llty, llfn, args);
7171
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
7272
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
7373
if let Some(funclet_bundle) = funclet_bundle {
@@ -97,10 +97,89 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
9797
GenericBuilder { llbuilder, cx: scx }
9898
}
9999

100+
/// Calls `LLVMGetIntrinsicDeclaration` for `intrinsic` in this builder's module
/// (`self.cx.llmod()`) and returns the resulting value.
///
/// `param_types` is forwarded as the intrinsic's type-parameter list (pointer plus length).
///
/// # Panics
///
/// Panics if `param_types.len()` does not fit in the FFI length parameter.
pub(crate) fn declare_intrinsic(
101+
&mut self,
102+
intrinsic: llvm::Intrinsic,
103+
param_types: &[&'ll Type],
104+
) -> &'ll Value {
105+
// NOTE(review): the pointer and the length both come from the same `param_types`
// slice, which is borrowed for the whole call.
unsafe {
106+
llvm::LLVMGetIntrinsicDeclaration(
107+
self.cx.llmod(),
108+
intrinsic.id(),
109+
param_types.as_ptr(),
110+
param_types.len().try_into().unwrap(),
111+
)
112+
}
113+
}
114+
115+
/// Calls `LLVMIntrinsicGetType` for `intrinsic` in this builder's LLVM context
/// (`self.cx.llcx()`) and returns the resulting type.
///
/// `param_types` is forwarded as the intrinsic's type-parameter list (pointer plus length).
///
/// # Panics
///
/// Panics if `param_types.len()` does not fit in the FFI length parameter.
pub(crate) fn intrinsic_type(
116+
&mut self,
117+
intrinsic: llvm::Intrinsic,
118+
param_types: &[&'ll Type],
119+
) -> &'ll Type {
120+
// NOTE(review): the pointer and the length both come from the same `param_types`
// slice, which is borrowed for the whole call.
unsafe {
121+
llvm::LLVMIntrinsicGetType(
122+
self.cx.llcx(),
123+
intrinsic.id(),
124+
param_types.as_ptr(),
125+
param_types.len().try_into().unwrap(),
126+
)
127+
}
128+
}
129+
100130
/// Builds a bitcast of `val` to `dest_ty` with `LLVMBuildBitCast` on this builder and
/// returns the resulting value. The instruction is given the `UNNAMED` name.
pub(crate) fn bitcast(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
101131
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
102132
}
103133

134+
/// Builds a call to the `llvm.x86.cast.vector.to.tile` intrinsic with `val` as its only
/// argument and returns the call's value.
///
/// The intrinsic's type and declaration are both requested with the type of `val` as the
/// single type parameter.
///
/// # Panics
///
/// Panics if the type of `val` is not a vector type, or if the intrinsic name is not
/// recognised by `llvm::Intrinsic::try_from`.
pub(crate) fn cast_vector_to_tile(&mut self, val: &'ll Value) -> &'ll Value {
135+
let vector_type = self.cx.val_ty(val);
136+
137+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
138+
139+
let intrinsic =
140+
llvm::Intrinsic::try_from(b"llvm.x86.cast.vector.to.tile".as_ref()).unwrap();
141+
let fn_ty = self.intrinsic_type(intrinsic, &[vector_type]);
142+
let f = self.declare_intrinsic(intrinsic, &[vector_type]);
143+
unsafe {
144+
llvm::LLVMBuildCallWithOperandBundles(
145+
self.llbuilder,
146+
fn_ty,
147+
f,
148+
// One argument: `val` itself, passed as a one-element array.
[val].as_ptr().cast(),
149+
1,
150+
// NOTE(review): presumably the (empty) operand-bundle list and its length — confirm
// against the FFI declaration.
[].as_ptr(),
151+
0,
152+
// Empty C string used as the name argument.
c"".as_ptr(),
153+
)
154+
}
155+
}
156+
157+
/// Builds a call to the `llvm.x86.cast.tile.to.vector` intrinsic with `val` as its only
/// argument and returns the call's value.
///
/// The intrinsic's type and declaration are both requested with `vector_type` as the
/// single type parameter.
///
/// # Panics
///
/// Panics if the type of `val` is not `self.cx.type_x86amx()`, if `vector_type` is not a
/// vector type, or if the intrinsic name is not recognised by `llvm::Intrinsic::try_from`.
pub(crate) fn cast_tile_to_vector(
158+
&mut self,
159+
val: &'ll Value,
160+
vector_type: &'ll Type,
161+
) -> &'ll Value {
162+
assert!(self.cx.val_ty(val) == self.cx.type_x86amx());
163+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
164+
165+
let intrinsic =
166+
llvm::Intrinsic::try_from(b"llvm.x86.cast.tile.to.vector".as_ref()).unwrap();
167+
let fn_ty = self.intrinsic_type(intrinsic, &[vector_type]);
168+
let f = self.declare_intrinsic(intrinsic, &[vector_type]);
169+
unsafe {
170+
llvm::LLVMBuildCallWithOperandBundles(
171+
self.llbuilder,
172+
fn_ty,
173+
f,
174+
// One argument: `val` itself, passed as a one-element array.
[val].as_ptr().cast(),
175+
1,
176+
// NOTE(review): presumably the (empty) operand-bundle list and its length — confirm
// against the FFI declaration.
[].as_ptr(),
177+
0,
178+
// Empty C string used as the name argument.
c"".as_ptr(),
179+
)
180+
}
181+
}
182+
104183
/// Calls `LLVMBuildRetVoid` on this builder; the value it returns is discarded.
pub(crate) fn ret_void(&mut self) {
105184
llvm::LLVMBuildRetVoid(self.llbuilder);
106185
}
@@ -349,7 +428,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
349428
) -> &'ll Value {
350429
debug!("invoke {:?} with args ({:?})", llfn, args);
351430

352-
let args = self.check_call("invoke", llty, llfn, args);
431+
let args = self.cast_arguments("invoke", llty, llfn, args);
353432
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
354433
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
355434
if let Some(funclet_bundle) = funclet_bundle {
@@ -381,8 +460,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
381460
};
382461
if let Some(fn_abi) = fn_abi {
383462
fn_abi.apply_attrs_callsite(self, invoke);
463+
self.cast_return(fn_abi, llfn, invoke)
464+
} else {
465+
invoke
384466
}
385-
invoke
386467
}
387468

388469
fn unreachable(&mut self) {
@@ -1404,7 +1485,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14041485
) -> &'ll Value {
14051486
debug!("call {:?} with args ({:?})", llfn, args);
14061487

1407-
let args = self.check_call("call", llty, llfn, args);
1488+
let args = self.cast_arguments("call", llty, llfn, args);
14081489
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
14091490
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
14101491
if let Some(funclet_bundle) = funclet_bundle {
@@ -1434,8 +1515,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14341515
};
14351516
if let Some(fn_abi) = fn_abi {
14361517
fn_abi.apply_attrs_callsite(self, call);
1518+
self.cast_return(fn_abi, llfn, call)
1519+
} else {
1520+
call
14371521
}
1438-
call
14391522
}
14401523

14411524
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
@@ -1596,7 +1679,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15961679
ret.expect("LLVM does not have support for catchret")
15971680
}
15981681

1599-
fn check_call<'b>(
1682+
fn cast_arguments<'b>(
16001683
&mut self,
16011684
typ: &str,
16021685
fn_ty: &'ll Type,
@@ -1627,7 +1710,11 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16271710
Expected {:?} for param {}, got {:?}; injecting bitcast",
16281711
llfn, expected_ty, i, actual_ty
16291712
);
1630-
self.bitcast(actual_val, expected_ty)
1713+
if self.cx.type_kind(expected_ty) == TypeKind::X86_AMX {
1714+
self.cast_vector_to_tile(actual_val)
1715+
} else {
1716+
self.bitcast(actual_val, expected_ty)
1717+
}
16311718
} else {
16321719
actual_val
16331720
}
@@ -1708,6 +1795,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17081795
self.call(self.type_func(&[src_ty], dest_ty), None, None, f, &[val], None, None)
17091796
}
17101797

1798+
/// Converts the value returned by a call so that its type matches the return type that
/// `fn_abi` expects.
///
/// The expected type is `fn_abi.llvm_return_type(self.cx)`. If `ret` already has that type
/// it is returned unchanged. Otherwise a debug message is logged and a cast is emitted:
/// `cast_tile_to_vector` when the actual type is `x86amx`, a plain `bitcast` in every
/// other case. `llfn` is only used in the debug message.
fn cast_return(
1799+
&mut self,
1800+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
1801+
llfn: &'ll Value,
1802+
ret: &'ll Value,
1803+
) -> &'ll Value {
1804+
let expected_ty = fn_abi.llvm_return_type(self.cx);
1805+
let actual_ty = self.cx.val_ty(ret);
1806+
1807+
if expected_ty != actual_ty {
1808+
debug!(
1809+
"type mismatch in function call of {:?}. \
1810+
Expected {:?} for return value, got {:?}; injecting bitcast",
1811+
llfn, expected_ty, actual_ty
1812+
);
1813+
// NOTE(review): `cast_tile_to_vector` asserts that its target type is a vector, so
// this branch assumes `expected_ty` is a vector type whenever the callee returns
// `x86amx` — confirm against the callers.
if self.cx.type_kind(actual_ty) == TypeKind::X86_AMX {
1814+
self.cast_tile_to_vector(ret, expected_ty)
1815+
} else {
1816+
self.bitcast(ret, expected_ty)
1817+
}
1818+
} else {
1819+
ret
1820+
}
1821+
}
1822+
17111823
pub(crate) fn landing_pad(
17121824
&mut self,
17131825
ty: &'ll Type,
@@ -1737,7 +1849,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17371849
) -> &'ll Value {
17381850
debug!("invoke {:?} with args ({:?})", llfn, args);
17391851

1740-
let args = self.check_call("callbr", llty, llfn, args);
1852+
let args = self.cast_arguments("callbr", llty, llfn, args);
17411853
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
17421854
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
17431855
if let Some(funclet_bundle) = funclet_bundle {
@@ -1770,8 +1882,10 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17701882
};
17711883
if let Some(fn_abi) = fn_abi {
17721884
fn_abi.apply_attrs_callsite(self, callbr);
1885+
self.cast_return(fn_abi, llfn, callbr)
1886+
} else {
1887+
callbr
17731888
}
1774-
callbr
17751889
}
17761890

17771891
// Emits CFI pointer type membership tests.

compiler/rustc_codegen_llvm/src/callee.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub(crate) fn get_fn<'ll, 'tcx>(cx: &CodegenCx<'ll, 'tcx>, instance: Instance<'t
3636
llfn
3737
} else {
3838
let instance_def_id = instance.def_id();
39+
3940
let llfn = if tcx.sess.target.arch == "x86"
4041
&& let Some(dllimport) = crate::common::get_dllimport(tcx, instance_def_id, sym)
4142
{

compiler/rustc_codegen_llvm/src/declare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
158158
fn_abi.llvm_cconv(self),
159159
llvm::UnnamedAddr::Global,
160160
llvm::Visibility::Default,
161-
fn_abi.llvm_type(self),
161+
fn_abi.llvm_type(self, name.as_ref()),
162162
);
163163
fn_abi.apply_attrs_llfn(self, llfn, instance);
164164

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ fn gen_fn<'a, 'll, 'tcx>(
11021102
codegen: &mut dyn FnMut(Builder<'a, 'll, 'tcx>),
11031103
) -> (&'ll Type, &'ll Value) {
11041104
let fn_abi = cx.fn_abi_of_fn_ptr(rust_fn_sig, ty::List::empty());
1105-
let llty = fn_abi.llvm_type(cx);
1105+
let llty = fn_abi.llvm_type(cx, name.as_ref());
11061106
let llfn = cx.declare_fn(name, fn_abi, None);
11071107
cx.set_frame_pointer_type(llfn);
11081108
cx.apply_target_cpu_attr(llfn);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,9 @@ unsafe extern "C" {
10721072
pub(crate) fn LLVMPointerTypeInContext(C: &Context, AddressSpace: c_uint) -> &Type;
10731073
pub(crate) fn LLVMVectorType(ElementType: &Type, ElementCount: c_uint) -> &Type;
10741074

1075+
// Special X86 Type for AMX
1076+
pub(crate) fn LLVMX86AMXTypeInContext(C: &Context) -> &Type;
1077+
10751078
pub(crate) fn LLVMGetElementType(Ty: &Type) -> &Type;
10761079
pub(crate) fn LLVMGetVectorSize(VectorTy: &Type) -> c_uint;
10771080

@@ -1195,6 +1198,24 @@ unsafe extern "C" {
11951198
// Operations on functions
11961199
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
11971200

1201+
// Operations on LLVM intrinsics
1202+
pub(crate) fn LLVMLookupIntrinsicID(Name: *const c_char, NameLen: c_uint) -> c_uint;
1203+
pub(crate) fn LLVMGetIntrinsicID(Fn: &Value) -> c_uint;
1204+
pub(crate) fn LLVMGetIntrinsicDeclaration<'a>(
1205+
M: &'a Module,
1206+
ID: c_uint,
1207+
ParamTypes: *const &'a Type,
1208+
ParamCount: c_uint,
1209+
) -> &'a Value;
1210+
pub(crate) fn LLVMIntrinsicGetType<'a>(
1211+
C: &'a Context,
1212+
ID: c_uint,
1213+
ParamTypes: *const &'a Type,
1214+
ParamCount: c_uint,
1215+
) -> &'a Type;
1216+
pub(crate) fn LLVMIntrinsicGetName(ID: c_uint, NameLength: *mut c_uint) -> *const c_char;
1217+
pub(crate) fn LLVMIntrinsicIsOverloaded(ID: c_uint) -> Bool;
1218+
11981219
// Operations on parameters
11991220
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
12001221
pub(crate) safe fn LLVMCountParams(Fn: &Value) -> c_uint;

0 commit comments

Comments
 (0)