Skip to content

Commit

Permalink
Support vector type in AArch64 C abi
Browse files Browse the repository at this point in the history
[ci skip]
  • Loading branch information
yuyichao committed May 29, 2016
1 parent b0ce3c7 commit 44d4ece
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 18 deletions.
3 changes: 2 additions & 1 deletion doc/manual/calling-c-and-fortran-code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,8 @@ In the future, some of these restrictions may be reduced or eliminated.
SIMD Values
~~~~~~~~~~~

Note: This feature is currently implemented on 64-bit x86 platforms only.
Note: This feature is currently implemented on 64-bit x86
and AArch64 platforms only.

If a C/C++ routine has an argument or return value that is a native
SIMD type, the corresponding Julia type is a homogeneous tuple
Expand Down
75 changes: 65 additions & 10 deletions src/abi_aarch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,53 @@ namespace {
typedef bool AbiState;
static const AbiState default_abi_state = 0;

// Return the LLVM type used to pass `dt` as an AArch64 short vector, or
// `nullptr` if `dt` is not a short vector.
//
// `dt` qualifies when it has at least two fields, is exactly 8 or 16 bytes
// in size, and every field is the same `VecElement` type wrapping a type
// that itself has no fields. The result is normalized to `<2 x i32>` for
// 8-byte vectors and `<4 x i32>` for 16-byte vectors regardless of the
// actual element type.
static Type *get_llvm_vectype(jl_datatype_t *dt)
{
    // Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
    // `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields > 0`
    size_t nfields = dt->nfields;
    assert(nfields > 0);
    // A single-field type is never treated as a vector.
    if (nfields < 2)
        return nullptr;
    static Type *T_vec64 = VectorType::get(T_int32, 2);
    static Type *T_vec128 = VectorType::get(T_int32, 4);
    Type *lltype;
    // Short vector should be either 8 bytes or 16 bytes.
    // Note that there are only two distinct fundamental types for
    // short vectors so we normalize them to <2 x i32> and <4 x i32>
    switch (dt->size) {
    case 8:
        lltype = T_vec64;
        break;
    case 16:
        lltype = T_vec128;
        break;
    default:
        return nullptr;
    }
    // Since `dt` is pointer free and has no padding and is 8 or 16 in size,
    // `ft0` must be concrete, immutable with no padding and we don't need
    // to check if its size is legal since it is included in
    // the homogeneity check.
    jl_datatype_t *ft0 = (jl_datatype_t*)jl_field_type(dt, 0);
    // `ft0` should be a `VecElement` type and the true element type
    // should be a `bitstype` (i.e. it has no fields of its own).
    if (ft0->name != jl_vecelement_typename ||
        ((jl_datatype_t*)jl_field_type(ft0, 0))->nfields)
        return nullptr;
    // Use an unsigned index so the comparison with `nfields` (a `size_t`)
    // does not mix signed and unsigned operands.
    for (size_t i = 1; i < nfields; i++) {
        if (jl_field_type(dt, i) != (jl_value_t*)ft0) {
            // Not homogeneous
            return nullptr;
        }
    }
    return lltype;
}

static Type *get_llvm_fptype(jl_datatype_t *dt)
{
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
if (dt->mutabl || jl_datatype_nfields(dt) != 0)
return NULL;
// `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields == 0`
Type *lltype;
// Check size first since it's cheaper.
switch (dt->size) {
Expand All @@ -37,9 +79,17 @@ static Type *get_llvm_fptype(jl_datatype_t *dt)
lltype = T_float128;
break;
default:
return NULL;
return nullptr;
}
return jl_is_floattype((jl_value_t*)dt) ? lltype : NULL;
return jl_is_floattype((jl_value_t*)dt) ? lltype : nullptr;
}

// Classify `dt` as either a floating point scalar or a short vector and
// return the matching LLVM type; returns `nullptr` when it is neither.
static Type *get_llvm_fp_or_vectype(jl_datatype_t *dt)
{
    // Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
    // Only immutable, pointer-free types without padding can qualify.
    const bool eligible = !dt->mutabl && dt->pointerfree && !dt->haspadding;
    if (!eligible)
        return nullptr;
    // Field-less types are candidates for floating point scalars;
    // anything with fields can only be a short vector.
    if (dt->nfields == 0)
        return get_llvm_fptype(dt);
    return get_llvm_vectype(dt);
}

struct ElementType {
Expand All @@ -50,8 +100,6 @@ struct ElementType {

// Whether a type is a homogeneous floating-point aggregate (HFA) or a
// homogeneous short-vector aggregate (HVA). Returns the element type.
// We only handle HFA of HP, SP, DP and QP here since these are the only ones we
// have (no vectors).
// A Homogeneous Aggregate is a Composite Type where all of the Fundamental
// Data Types of the members that compose the type are the same.
// Note that it is the fundamental types that are important and not the member
Expand All @@ -62,6 +110,7 @@ static bool isHFAorHVA(jl_datatype_t *dt, size_t dsz, size_t &nele, ElementType
// dt is a pointerfree type, (all members are isbits)
// dsz == dt->size > 0
// 0 <= nele <= 3
// dt has no padding

// We ignore zero sized member here. This isn't really consistent with
// GCC for zero-sized array members. GCC seems to treat structs with
Expand All @@ -83,6 +132,14 @@ static bool isHFAorHVA(jl_datatype_t *dt, size_t dsz, size_t &nele, ElementType
dt = (jl_datatype_t*)jl_field_type(dt, i);
continue;
}
if (Type *vectype = get_llvm_vectype(dt)) {
if ((ele.sz && dsz != ele.sz) || (ele.type && ele.type != vectype))
return false;
ele.type = vectype;
ele.sz = dsz;
nele++;
return true;
}
// Otherwise, process each member
for (;i < nfields;i++) {
size_t fieldsz = jl_field_size(dt, i);
Expand Down Expand Up @@ -183,9 +240,7 @@ static Type *classify_arg(jl_value_t *ty, bool *fpreg, bool *onstack,
// the argument is allocated to the least significant bits of register
// v[NSRN]. The NSRN is incremented by one. The argument has now been
// allocated.
// Note that this is missing QP float as well as short vector types since we
// don't really have those types.
if (get_llvm_fptype(dt)) {
if (get_llvm_fp_or_vectype(dt)) {
*fpreg = true;
return NULL;
}
Expand Down Expand Up @@ -323,7 +378,7 @@ Type *preferred_llvm_type(jl_value_t *ty, bool)
if (!jl_is_datatype(ty) || jl_is_abstracttype(ty))
return NULL;
jl_datatype_t *dt = (jl_datatype_t*)ty;
if (Type *fptype = get_llvm_fptype(dt))
if (Type *fptype = get_llvm_fp_or_vectype(dt))
return fptype;
bool fpreg = false;
bool onstack = false;
Expand Down
4 changes: 2 additions & 2 deletions src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_uninitialized_datatype(size_t nfields, int8_t
// For sake of Ahead-Of-Time (AOT) compilation, this routine has to work
// without LLVM being available.
unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) {
if (!is_vecelement_type(t))
if (!jl_is_vecelement_type(t))
return 0;
// LLVM 3.7 and 3.8 either crash or generate wrong code for many
// SIMD vector sizes N. It seems the rule is that N can have at
Expand All @@ -859,7 +859,7 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) {
return 0; // nfields has more than two 1s
assert(jl_datatype_nfields(t)==1);
jl_value_t *ty = jl_field_type(t, 0);
if( !jl_is_bitstype(ty) )
if (!jl_is_bitstype(ty))
// LLVM requires that a vector element be a primitive type.
// LLVM allows pointer types as vector elements, but until a
// motivating use case comes up for Julia, we reject pointers.
Expand Down
26 changes: 26 additions & 0 deletions src/ccalltest.c
Original file line number Diff line number Diff line change
Expand Up @@ -449,4 +449,30 @@ JL_DLLEXPORT struct_aa64_2 test_aa64_fp16_2(int v1, float v2,
return x;
}

#include <arm_neon.h>

// Computes `v1 * (int)_v2 + v3` lane-wise on 32-bit lanes and returns the
// result widened to two 64-bit lanes.
JL_DLLEXPORT int64x2_t test_aa64_vec_1(int32x2_t v1, float _v2, int32x2_t v3)
{
    // The scalar is converted to an integer with a C cast before use.
    const int scale = (int)_v2;
    const int32x2_t scaled = v1 * scale;
    const int32x2_t sum = scaled + v3;
    // Sign-extend each 32-bit lane to 64 bits.
    return vmovl_s32(sum);
}

// This is a homogeneous short vector aggregate: both members are
// 64-bit short vectors (8 x 8-bit and 2 x 32-bit lanes), even though
// their lane types differ.
typedef struct {
int8x8_t v1;
float32x2_t v2;
} struct_aa64_3;

// This is NOT a homogeneous short vector aggregate: the members are short
// vectors of different sizes (64-bit `float32x2_t` vs 128-bit `int16x8_t`).
typedef struct {
float32x2_t v2;
int16x8_t v1;
} struct_aa64_4;

// Combines the two aggregates member by member: the integer vectors are
// added and the float vectors are subtracted (`v1 - v2`).
JL_DLLEXPORT struct_aa64_3 test_aa64_vec_2(struct_aa64_3 v1, struct_aa64_4 v2)
{
    struct_aa64_3 result;
    // Narrow the eight 16-bit lanes of `v2.v1` to 8 bits so they can be
    // added lane-wise to the 8 x 8-bit vector in `v1.v1`.
    result.v1 = v1.v1 + vmovn_s16(v2.v1);
    result.v2 = v1.v2 - v2.v2;
    return result;
}

#endif
8 changes: 4 additions & 4 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ static Type *julia_struct_to_llvm(jl_value_t *jt, bool *isboxed)
latypes.push_back(lty);
}
if (!isTuple) {
if (is_vecelement_type(jt))
if (jl_is_vecelement_type(jt))
// VecElement type is unwrapped in LLVM
jst->struct_decl = latypes[0];
else
Expand Down Expand Up @@ -1101,7 +1101,7 @@ static jl_cgval_t emit_getfield_knownidx(const jl_cgval_t &strct, unsigned idx,
}
else if (strct.ispointer()) { // something stack allocated
Value *addr;
if (is_vecelement_type((jl_value_t*)jt))
if (jl_is_vecelement_type((jl_value_t*)jt))
// VecElement types are unwrapped in LLVM.
addr = strct.V;
else
Expand Down Expand Up @@ -1678,7 +1678,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg
// or instead initialize the stack buffer with stores
bool init_as_value = false;
if (lt->isVectorTy() ||
is_vecelement_type(ty) ||
jl_is_vecelement_type(ty) ||
type_is_ghost(lt)) // maybe also check the size ?
init_as_value = true;

Expand Down Expand Up @@ -1714,7 +1714,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg
strct = builder.CreateInsertValue(strct, fval, ArrayRef<unsigned>(&idx,1));
else {
// Must be a VecElement type, which comes unwrapped in LLVM.
assert(is_vecelement_type(ty));
assert(jl_is_vecelement_type(ty));
strct = fval;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ STATIC_INLINE int jl_is_tuple_type(void *t)
((jl_datatype_t*)(t))->name == jl_tuple_typename);
}

STATIC_INLINE int is_vecelement_type(jl_value_t* t)
STATIC_INLINE int jl_is_vecelement_type(jl_value_t* t)
{
return (jl_is_datatype(t) &&
((jl_datatype_t*)(t))->name == jl_vecelement_typename);
Expand Down
45 changes: 45 additions & 0 deletions test/ccall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,17 @@ immutable Struct_AA64_2
v2::Float64
end

# This is a homogeneous short vector aggregate
# (counterpart of `struct_aa64_3` in src/ccalltest.c)
immutable Struct_AA64_3
v1::VecReg{8,Int8}
v2::VecReg{2,Float32}
end
# This is NOT a homogeneous short vector aggregate
# (counterpart of `struct_aa64_4` in src/ccalltest.c)
immutable Struct_AA64_4
v2::VecReg{2,Float32}
v1::VecReg{8,Int16}
end

if Sys.ARCH === :x86_64

function test_sse(a1::V4xF32,a2::V4xF32,a3::V4xF32,a4::V4xF32)
Expand Down Expand Up @@ -590,4 +601,38 @@ elseif Sys.ARCH === :aarch64
expected = Struct_AA64_2(v4 / 2 + 1, v1 * 2 + v2 * 4 - v3)
@test res === expected
end
# Short vectors passed and returned directly: two 2 x Int32 vector arguments
# with a Float32 scalar in between, and a 2 x Int64 vector as the result.
# NOTE(review): the scalar presumably checks that mixing floating point and
# vector arguments does not disturb argument assignment -- confirm.
for v1_1 in 1:4, v1_2 in -2:2, v2 in -4:-1, v3_1 in 3:5, v3_2 in 6:8
res = ccall((:test_aa64_vec_1, libccalltest),
VecReg{2,Int64},
(VecReg{2,Int32}, Float32, VecReg{2,Int32}),
(VecElement(Int32(v1_1)), VecElement(Int32(v1_2))),
v2, (VecElement(Int32(v3_1)), VecElement(Int32(v3_2))))
# Lane-wise `v1 * v2 + v3`, as computed by `test_aa64_vec_1` in C.
expected = (VecElement(v1_1 * v2 + v3_1), VecElement(v1_2 * v2 + v3_2))
@test res === expected
end
# Aggregates of short vectors passed by value: `Struct_AA64_3` is both an
# argument and the return value, `Struct_AA64_4` is an argument only.
# Only the first two lanes of the 8-lane vectors vary; the rest are zero.
for v1_11 in 1:4, v1_12 in -2:2, v1_21 in 1:4, v1_22 in -2:2,
v2_11 in 1:4, v2_12 in -2:2, v2_21 in 1:4, v2_22 in -2:2
v1 = Struct_AA64_3((VecElement(Int8(v1_11)), VecElement(Int8(v1_12)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0))),
(VecElement(Float32(v1_21)),
VecElement(Float32(v1_22))))
v2 = Struct_AA64_4((VecElement(Float32(v2_21)),
VecElement(Float32(v2_22))),
(VecElement(Int16(v2_11)), VecElement(Int16(v2_12)),
VecElement(Int16(0)), VecElement(Int16(0)),
VecElement(Int16(0)), VecElement(Int16(0)),
VecElement(Int16(0)), VecElement(Int16(0))))
res = ccall((:test_aa64_vec_2, libccalltest),
Struct_AA64_3, (Struct_AA64_3, Struct_AA64_4), v1, v2)
# The C side adds the integer vectors and subtracts the float vectors.
expected = Struct_AA64_3((VecElement(Int8(v1_11 + v2_11)),
VecElement(Int8(v1_12 + v2_12)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0))),
(VecElement(Float32(v1_21 - v2_21)),
VecElement(Float32(v1_22 - v2_22))))
@test res === expected
end
end

0 comments on commit 44d4ece

Please sign in to comment.