diff --git a/src/abi_aarch64.cpp b/src/abi_aarch64.cpp index b71328d274e9b5..c53b02a2c9cb07 100644 --- a/src/abi_aarch64.cpp +++ b/src/abi_aarch64.cpp @@ -16,11 +16,53 @@ namespace { typedef bool AbiState; static const AbiState default_abi_state = 0; +static Type *get_llvm_vectype(jl_datatype_t *dt) +{ + // Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt) + // `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields > 0` + size_t nfields = dt->nfields; + assert(nfields > 0); + if (nfields < 2) + return nullptr; + static Type *T_vec64 = VectorType::get(T_int32, 2); + static Type *T_vec128 = VectorType::get(T_int32, 4); + Type *lltype; + // Short vector should be either 8 bytes or 16 bytes. + // Note that there are only two distinct fundamental types for + // short vectors so we normalize them to <2 x i32> and <4 x i32> + switch (dt->size) { + case 8: + lltype = T_vec64; + break; + case 16: + lltype = T_vec128; + break; + default: + return nullptr; + } + // Since `dt` is pointer free and has no padding and is 8 or 16 in size, + // `ft0` must be concrete, immutable with no padding and we don't need + // to check if its size is legal since it is included in + // the homogeneity check. + jl_datatype_t *ft0 = (jl_datatype_t*)jl_field_type(dt, 0); + // `ft0` should be a `VecElement` type and the true element type + // should be a `bitstype` + if (ft0->name != jl_vecelement_typename || + ((jl_datatype_t*)jl_field_type(ft0, 0))->nfields) + return nullptr; + for (int i = 1; i < nfields; i++) { + if (jl_field_type(dt, i) != (jl_value_t*)ft0) { + // Not homogeneous + return nullptr; + } + } + return lltype; +} + static Type *get_llvm_fptype(jl_datatype_t *dt) { // Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt) - if (dt->mutabl || jl_datatype_nfields(dt) != 0) - return NULL; + // `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields == 0` Type *lltype; // Check size first since it's cheaper. 
switch (dt->size) { @@ -37,9 +79,17 @@ static Type *get_llvm_fptype(jl_datatype_t *dt) lltype = T_float128; break; default: - return NULL; + return nullptr; } - return jl_is_floattype((jl_value_t*)dt) ? lltype : NULL; + return jl_is_floattype((jl_value_t*)dt) ? lltype : nullptr; +} + +static Type *get_llvm_fp_or_vectype(jl_datatype_t *dt) +{ + // Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt) + if (dt->mutabl || !dt->pointerfree || dt->haspadding) + return nullptr; + return dt->nfields ? get_llvm_vectype(dt) : get_llvm_fptype(dt); } struct ElementType { @@ -50,8 +100,6 @@ struct ElementType { // Whether a type is a homogeneous floating-point aggregates (HFA) or a // homogeneous short-vector aggregates (HVA). Returns the element type. -// We only handle HFA of HP, SP, DP and QP here since these are the only ones we -// have (no vectors). // An Homogeneous Aggregate is a Composite Type where all of the Fundamental // Data Types of the members that compose the type are the same. // Note that it is the fundamental types that are important and not the member @@ -62,6 +110,7 @@ static bool isHFAorHVA(jl_datatype_t *dt, size_t dsz, size_t &nele, ElementType // dt is a pointerfree type, (all members are isbits) // dsz == dt->size > 0 // 0 <= nele <= 3 + // dt has no padding // We ignore zero sized member here. This isn't really consistent with // GCC for zero-sized array members. 
GCC seems to treat structs with @@ -83,6 +132,14 @@ static bool isHFAorHVA(jl_datatype_t *dt, size_t dsz, size_t &nele, ElementType dt = (jl_datatype_t*)jl_field_type(dt, i); continue; } + if (Type *vectype = get_llvm_vectype(dt)) { + if ((ele.sz && dsz != ele.sz) || (ele.type && ele.type != vectype)) + return false; + ele.type = vectype; + ele.sz = dsz; + nele++; + return true; + } // Otherwise, process each members for (;i < nfields;i++) { size_t fieldsz = jl_field_size(dt, i); @@ -183,9 +240,7 @@ static Type *classify_arg(jl_value_t *ty, bool *fpreg, bool *onstack, // the argument is allocated to the least significant bits of register // v[NSRN]. The NSRN is incremented by one. The argument has now been // allocated. - // Note that this is missing QP float as well as short vector types since we - // don't really have those types. - if (get_llvm_fptype(dt)) { + if (get_llvm_fp_or_vectype(dt)) { *fpreg = true; return NULL; } @@ -323,7 +378,7 @@ Type *preferred_llvm_type(jl_value_t *ty, bool) if (!jl_is_datatype(ty) || jl_is_abstracttype(ty)) return NULL; jl_datatype_t *dt = (jl_datatype_t*)ty; - if (Type *fptype = get_llvm_fptype(dt)) + if (Type *fptype = get_llvm_fp_or_vectype(dt)) return fptype; bool fpreg = false; bool onstack = false; diff --git a/src/alloc.c b/src/alloc.c index fda70983dd7951..f3993cfc582abe 100644 --- a/src/alloc.c +++ b/src/alloc.c @@ -843,7 +843,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_uninitialized_datatype(size_t nfields, int8_t // For sake of Ahead-Of-Time (AOT) compilation, this routine has to work // without LLVM being available. unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) { - if (!is_vecelement_type(t)) + if (!jl_is_vecelement_type(t)) return 0; // LLVM 3.7 and 3.8 either crash or generate wrong code for many // SIMD vector sizes N. 
It seems the rule is that N can have at @@ -859,7 +859,7 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) { return 0; // nfields has more than two 1s assert(jl_datatype_nfields(t)==1); jl_value_t *ty = jl_field_type(t, 0); - if( !jl_is_bitstype(ty) ) + if (!jl_is_bitstype(ty)) // LLVM requires that a vector element be a primitive type. // LLVM allows pointer types as vector elements, but until a // motivating use case comes up for Julia, we reject pointers. diff --git a/src/ccalltest.c b/src/ccalltest.c index 1046fea33b86b1..7533aafc4490ee 100644 --- a/src/ccalltest.c +++ b/src/ccalltest.c @@ -449,4 +449,30 @@ JL_DLLEXPORT struct_aa64_2 test_aa64_fp16_2(int v1, float v2, return x; } +#include <arm_neon.h> + +JL_DLLEXPORT int64x2_t test_aa64_vec_1(int32x2_t v1, float _v2, int32x2_t v3) +{ + int v2 = (int)_v2; + return vmovl_s32(v1 * v2 + v3); +} + +// This is a homogeneous short vector aggregate +typedef struct { + int8x8_t v1; + float32x2_t v2; +} struct_aa64_3; + +// This is NOT a homogeneous short vector aggregate +typedef struct { + float32x2_t v2; + int16x8_t v1; +} struct_aa64_4; + +JL_DLLEXPORT struct_aa64_3 test_aa64_vec_2(struct_aa64_3 v1, struct_aa64_4 v2) +{ + struct_aa64_3 x = {v1.v1 + vmovn_s16(v2.v1), v1.v2 - v2.v2}; + return x; +} + #endif diff --git a/src/cgutils.cpp b/src/cgutils.cpp index c0ceda7fcb1aa3..62c70f327e1d02 100644 --- a/src/cgutils.cpp +++ b/src/cgutils.cpp @@ -385,7 +385,7 @@ static Type *julia_struct_to_llvm(jl_value_t *jt, bool *isboxed) latypes.push_back(lty); } if (!isTuple) { - if (is_vecelement_type(jt)) + if (jl_is_vecelement_type(jt)) // VecElement type is unwrapped in LLVM jst->struct_decl = latypes[0]; else @@ -1101,7 +1101,7 @@ static jl_cgval_t emit_getfield_knownidx(const jl_cgval_t &strct, unsigned idx, } else if (strct.ispointer()) { // something stack allocated Value *addr; - if (is_vecelement_type((jl_value_t*)jt)) + if (jl_is_vecelement_type((jl_value_t*)jt)) // VecElement types are unwrapped in LLVM. 
addr = strct.V; else @@ -1678,7 +1678,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg // or instead initialize the stack buffer with stores bool init_as_value = false; if (lt->isVectorTy() || - is_vecelement_type(ty) || + jl_is_vecelement_type(ty) || type_is_ghost(lt)) // maybe also check the size ? init_as_value = true; @@ -1714,7 +1714,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg strct = builder.CreateInsertValue(strct, fval, ArrayRef<unsigned>(&idx,1)); else { // Must be a VecElement type, which comes unwrapped in LLVM. - assert(is_vecelement_type(ty)); + assert(jl_is_vecelement_type(ty)); strct = fval; } } diff --git a/src/julia.h b/src/julia.h index f7e5ec96e8dfe2..adcf356ca3c1b0 100644 --- a/src/julia.h +++ b/src/julia.h @@ -954,7 +954,7 @@ STATIC_INLINE int jl_is_tuple_type(void *t) ((jl_datatype_t*)(t))->name == jl_tuple_typename); } -STATIC_INLINE int is_vecelement_type(jl_value_t* t) +STATIC_INLINE int jl_is_vecelement_type(jl_value_t* t) { return (jl_is_datatype(t) && ((jl_datatype_t*)(t))->name == jl_vecelement_typename); diff --git a/test/ccall.jl b/test/ccall.jl index d90d6840e33d80..1ffe0edcc70f98 100644 --- a/test/ccall.jl +++ b/test/ccall.jl @@ -540,6 +540,17 @@ immutable Struct_AA64_2 v2::Float64 end +# This is a homogeneous short vector aggregate +immutable Struct_AA64_3 + v1::VecReg{8,Int8} + v2::VecReg{2,Float32} +end +# This is NOT a homogeneous short vector aggregate +immutable Struct_AA64_4 + v2::VecReg{2,Float32} + v1::VecReg{8,Int16} +end + if Sys.ARCH === :x86_64 function test_sse(a1::V4xF32,a2::V4xF32,a3::V4xF32,a4::V4xF32) @@ -590,4 +601,38 @@ elseif Sys.ARCH === :aarch64 expected = Struct_AA64_2(v4 / 2 + 1, v1 * 2 + v2 * 4 - v3) @test res === expected end + for v1_1 in 1:4, v1_2 in -2:2, v2 in -4:-1, v3_1 in 3:5, v3_2 in 6:8 + res = ccall((:test_aa64_vec_1, libccalltest), + VecReg{2,Int64}, + (VecReg{2,Int32}, Float32, VecReg{2,Int32}), + (VecElement(Int32(v1_1)), 
VecElement(Int32(v1_2))), + v2, (VecElement(Int32(v3_1)), VecElement(Int32(v3_2)))) + expected = (VecElement(v1_1 * v2 + v3_1), VecElement(v1_2 * v2 + v3_2)) + @test res === expected + end + for v1_11 in 1:4, v1_12 in -2:2, v1_21 in 1:4, v1_22 in -2:2, + v2_11 in 1:4, v2_12 in -2:2, v2_21 in 1:4, v2_22 in -2:2 + v1 = Struct_AA64_3((VecElement(Int8(v1_11)), VecElement(Int8(v1_12)), + VecElement(Int8(0)), VecElement(Int8(0)), + VecElement(Int8(0)), VecElement(Int8(0)), + VecElement(Int8(0)), VecElement(Int8(0))), + (VecElement(Float32(v1_21)), + VecElement(Float32(v1_22)))) + v2 = Struct_AA64_4((VecElement(Float32(v2_21)), + VecElement(Float32(v2_22))), + (VecElement(Int16(v2_11)), VecElement(Int16(v2_12)), + VecElement(Int16(0)), VecElement(Int16(0)), + VecElement(Int16(0)), VecElement(Int16(0)), + VecElement(Int16(0)), VecElement(Int16(0)))) + res = ccall((:test_aa64_vec_2, libccalltest), + Struct_AA64_3, (Struct_AA64_3, Struct_AA64_4), v1, v2) + expected = Struct_AA64_3((VecElement(Int8(v1_11 + v2_11)), + VecElement(Int8(v1_12 + v2_12)), + VecElement(Int8(0)), VecElement(Int8(0)), + VecElement(Int8(0)), VecElement(Int8(0)), + VecElement(Int8(0)), VecElement(Int8(0))), + (VecElement(Float32(v1_21 - v2_21)), + VecElement(Float32(v1_22 - v2_22)))) + @test res === expected + end end