Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mono] [Arm64] Added SIMD support for vector 2/3/4 methods #98761

Merged
merged 16 commits into from
Mar 15, 2024
Merged
292 changes: 199 additions & 93 deletions src/mono/mono/mini/simd-intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ emit_xconst_v128 (MonoCompile *cfg, MonoClass *klass, guint8 value[16])
ins->type = STACK_VTYPE;
ins->dreg = alloc_xreg (cfg);
ins->inst_p0 = mono_mem_manager_alloc (cfg->mem_manager, size);
ins->klass = klass;
MONO_ADD_INS (cfg->cbb, ins);

memcpy (ins->inst_p0, &value[0], size);
Expand Down Expand Up @@ -1390,6 +1391,76 @@ emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoType
}
#endif

/*
 * emit_dot:
 *
 *   Emit IR computing the dot product of the two SIMD vregs SREG1 and SREG2 whose
 * elements have type ARG0_TYPE. KLASS is the class attached to the emitted SIMD
 * instructions; VECTOR_TYPE is used for the element-type check and is forwarded to
 * emit_sum_vector ().
 *   Returns the instruction producing the result, or NULL if the target/backend/element
 * type combination is not handled here.
 */
static MonoInst*
emit_dot (MonoCompile *cfg, MonoClass *klass, MonoType *vector_type, MonoTypeEnum arg0_type, int sreg1, int sreg2)
{
	if (!is_element_type_primitive (vector_type))
		return NULL;

#if defined(TARGET_WASM)
	/* Without LLVM, MONO_TYPE_I8/U8 elements are rejected on wasm */
	if (!COMPILE_LLVM (cfg)) {
		if (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8)
			return NULL;
	}
#elif defined(TARGET_ARM64)
	/* Without LLVM, MONO_TYPE_I8/U8 and native-int elements are rejected on arm64 */
	if (!COMPILE_LLVM (cfg)) {
		switch (arg0_type) {
		case MONO_TYPE_I8:
		case MONO_TYPE_U8:
		case MONO_TYPE_I:
		case MONO_TYPE_U:
			return NULL;
		default:
			break;
		}
	}
#endif

#if defined(TARGET_ARM64) || defined(TARGET_WASM)
	/* Element-wise multiply, then sum up the resulting vector */
	MonoInst *products = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2);
	products->inst_c1 = arg0_type;
	products->inst_c0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL;
	return emit_sum_vector (cfg, vector_type, arg0_type, products);
#elif defined(TARGET_AMD64)
	if (!type_enum_is_float (arg0_type)) {
		// We don't support sum vector for byte, sbyte types yet
		if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1)
			return NULL;

		// FIXME:
		if (!COMPILE_LLVM (cfg))
			return NULL;
	} else if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) {
		/*
		 * Floating point with SSE4.1: use the dedicated dot product opcodes and
		 * read the result from the first element of the destination.
		 */
		int dp_opcode;
		int dp_mask;

		if (arg0_type == MONO_TYPE_R4) {
			dp_opcode = COMPILE_LLVM (cfg) ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM;
			dp_mask = 0xf1; // 0b11110001
		} else if (arg0_type == MONO_TYPE_R8) {
			dp_opcode = COMPILE_LLVM (cfg) ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM;
			dp_mask = 0x31; // 0b00110001
		} else {
			return NULL;
		}

		/* The LLVM variant takes the mask in a register, the other one as an immediate */
		int mask_reg = -1;
		if (COMPILE_LLVM (cfg)) {
			mask_reg = alloc_ireg (cfg);
			MONO_EMIT_NEW_ICONST (cfg, mask_reg, dp_mask);
		}

		MonoInst *dp = emit_simd_ins (cfg, klass, dp_opcode, sreg1, sreg2);
		if (COMPILE_LLVM (cfg))
			dp->sreg3 = mask_reg;
		else
			dp->inst_c0 = dp_mask;

		return extract_first_element (cfg, klass, arg0_type, dp->dreg);
	}

	/*
	 * Generic path (floats without SSE4.1, or integers under LLVM):
	 * element-wise multiply followed by a vector sum.
	 */
	MonoInst *products = emit_simd_ins (cfg, klass, OP_XBINOP, sreg1, sreg2);
	products->inst_c1 = arg0_type;
	products->inst_c0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL;

	return emit_sum_vector (cfg, vector_type, arg0_type, products);
#else
	return NULL;
#endif
}

/*
* Emit intrinsics in System.Numerics.Vector and System.Runtime.Intrinsics.Vector64/128/256/512.
* If the intrinsic is not supported for some reason, return NULL and fall back to the C#
Expand Down Expand Up @@ -1768,70 +1839,7 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
}
}
case SN_Dot: {
if (!is_element_type_primitive (fsig->params [0]))
return NULL;
#if defined(TARGET_WASM)
if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8))
return NULL;
#elif defined(TARGET_ARM64)
if (!COMPILE_LLVM (cfg) && (arg0_type == MONO_TYPE_I8 || arg0_type == MONO_TYPE_U8 || arg0_type == MONO_TYPE_I || arg0_type == MONO_TYPE_U))
return NULL;
#endif

#if defined(TARGET_ARM64) || defined(TARGET_WASM)
int instc0 = type_enum_is_float (arg0_type) ? OP_FMUL : OP_IMUL;
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc0, arg0_type, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#elif defined(TARGET_AMD64)
int instc =-1;
if (type_enum_is_float (arg0_type)) {
if (is_SIMD_feature_supported (cfg, MONO_CPU_X86_SSE41)) {
int mask_val = -1;
switch (arg0_type) {
case MONO_TYPE_R4:
instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPS : OP_SSE41_DPPS_IMM;
mask_val = 0xf1; // 0xf1 ... 0b11110001
break;
case MONO_TYPE_R8:
instc = COMPILE_LLVM (cfg) ? OP_SSE41_DPPD : OP_SSE41_DPPD_IMM;
mask_val = 0x31; // 0x31 ... 0b00110001
break;
default:
return NULL;
}

MonoInst *dot;
if (COMPILE_LLVM (cfg)) {
int mask_reg = alloc_ireg (cfg);
MONO_EMIT_NEW_ICONST (cfg, mask_reg, mask_val);

dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg);
dot->sreg3 = mask_reg;
} else {
dot = emit_simd_ins (cfg, klass, instc, args [0]->dreg, args [1]->dreg);
dot->inst_c0 = mask_val;
}

return extract_first_element (cfg, klass, arg0_type, dot->dreg);
} else {
instc = OP_FMUL;
}
} else {
if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1)
return NULL; // We don't support sum vector for byte, sbyte types yet

// FIXME:
if (!COMPILE_LLVM (cfg))
return NULL;

instc = OP_IMUL;
}
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, instc, arg0_type, fsig, args);

return emit_sum_vector (cfg, fsig->params [0], arg0_type, pairwise_multiply);
#else
return NULL;
#endif
return emit_dot (cfg, klass, fsig->params [0], arg0_type, args [0]->dreg, args [1]->dreg);
}
case SN_Equals:
case SN_EqualsAll:
Expand Down Expand Up @@ -2910,6 +2918,8 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
value [1] = 1.0f;
value [2] = 1.0f;
value [3] = 1.0f;
if (len == 3)
value [3] = 0.0f;
return emit_xconst_v128 (cfg, klass, (guint8*)value);
}
case SN_set_Item: {
Expand Down Expand Up @@ -2988,28 +2998,7 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
return emit_simd_ins_for_binary_op (cfg, klass, fsig, args, MONO_TYPE_R4, id);
}
case SN_Dot: {
#if defined(TARGET_ARM64) || defined(TARGET_WASM)
MonoInst *pairwise_multiply = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FMUL, MONO_TYPE_R4, fsig, args);
return emit_sum_vector (cfg, fsig->params [0], MONO_TYPE_R4, pairwise_multiply);
#elif defined(TARGET_AMD64)
if (!(mini_get_cpu_features (cfg) & MONO_CPU_X86_SSE41))
return NULL;

int mask_reg = alloc_ireg (cfg);
MONO_EMIT_NEW_ICONST (cfg, mask_reg, 0xf1);
MonoInst *dot = emit_simd_ins (cfg, klass, OP_SSE41_DPPS, args [0]->dreg, args [1]->dreg);
dot->sreg3 = mask_reg;

MONO_INST_NEW (cfg, ins, OP_EXTRACT_R4);
ins->dreg = alloc_freg (cfg);
ins->sreg1 = dot->dreg;
ins->inst_c0 = 0;
ins->inst_c1 = MONO_TYPE_R4;
MONO_ADD_INS (cfg->cbb, ins);
return ins;
#else
return NULL;
#endif
return emit_dot (cfg, klass, fsig->params [0], MONO_TYPE_R4, args [0]->dreg, args [1]->dreg);
}
case SN_Negate:
case SN_op_UnaryNegation: {
Expand Down Expand Up @@ -3061,7 +3050,6 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f
#endif
}
case SN_CopyTo:
// FIXME: https://github.com/dotnet/runtime/issues/91394
return NULL;
case SN_Clamp: {
if (!(!fsig->hasthis && fsig->param_count == 3 && mono_metadata_type_equal (fsig->ret, type) && mono_metadata_type_equal (fsig->params [0], type) && mono_metadata_type_equal (fsig->params [1], type) && mono_metadata_type_equal (fsig->params [2], type)))
Expand All @@ -3077,15 +3065,133 @@ emit_vector_2_3_4 (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *f

return min;
}
case SN_Conjugate:
case SN_Distance:
case SN_DistanceSquared:
case SN_Distance:
case SN_DistanceSquared: {
#if defined(TARGET_ARM64)
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
MonoInst *diffs = emit_simd_ins_for_sig (cfg, klass, OP_XBINOP, OP_FSUB, MONO_TYPE_R4, fsig, args);
MonoInst *dot = emit_dot(cfg, klass, fsig->params [0], MONO_TYPE_R4, diffs->dreg, diffs->dreg);

switch (id) {
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
case SN_Distance: {
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt->inst_c1 = MONO_TYPE_R4;

MonoInst *distance = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1);
distance->inst_c0 = 0;
distance->inst_c1 = MONO_TYPE_R4;
return distance;
}
case SN_DistanceSquared:
return dot;
default:
g_assert_not_reached ();
}
#else
return NULL;
#endif
}
case SN_Length:
case SN_LengthSquared:
case SN_Lerp:
case SN_LengthSquared: {
#if defined (TARGET_ARM64)
int src1 = load_simd_vreg (cfg, cmethod, args [0], NULL);
MonoInst *dot = emit_dot(cfg, klass, type, MONO_TYPE_R4, src1, src1);

switch (id) {
jkurdek marked this conversation as resolved.
Show resolved Hide resolved
case SN_Length: {
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt->inst_c1 = MONO_TYPE_R4;

MonoInst *length = emit_simd_ins (cfg, klass, OP_EXTRACT_R4, sqrt->dreg, -1);
length->inst_c0 = 0;
length->inst_c1 = MONO_TYPE_R4;
return length;
}
case SN_LengthSquared:
return dot;
default:
g_assert_not_reached ();
}
#else
return NULL;
#endif
}
case SN_Lerp: {
#if defined (TARGET_ARM64)
MonoInst* v1 = args [1];
if (!strcmp ("Quaternion", m_class_get_name (klass)))
return NULL;


MonoInst *diffs = emit_simd_ins (cfg, klass, OP_XBINOP, v1->dreg, args [0]->dreg);
diffs->inst_c0 = OP_FSUB;
diffs->inst_c1 = MONO_TYPE_R4;

MonoInst *scaled_diffs = handle_mul_div_by_scalar (cfg, klass, MONO_TYPE_R4, args [2]->dreg, diffs->dreg, OP_FMUL);

MonoInst *result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, scaled_diffs->dreg);
result->inst_c0 = OP_FADD;
result->inst_c1 = MONO_TYPE_R4;

return result;
#else
return NULL;
#endif
}
case SN_Normalize: {
// FIXME: https://github.com/dotnet/runtime/issues/91394
#if defined (TARGET_ARM64)
MonoInst* vec = args[0];
const char *class_name = m_class_get_name (klass);
if (!strcmp ("Plane", class_name)) {
static float r4_0 = 0;
MonoInst *zero;
int zero_dreg = alloc_freg (cfg);
MONO_INST_NEW (cfg, zero, OP_R4CONST);
zero->inst_p0 = (void*)&r4_0;
zero->dreg = zero_dreg;
MONO_ADD_INS (cfg->cbb, zero);
vec = emit_vector_insert_element (cfg, klass, vec, MONO_TYPE_R4, zero, 3, FALSE);
}

MonoInst *dot = emit_dot(cfg, klass, type, MONO_TYPE_R4, vec->dreg, vec->dreg);
dot = emit_simd_ins (cfg, klass, OP_EXPAND_R4, dot->dreg, -1);
dot->inst_c1 = MONO_TYPE_R4;

MonoInst *sqrt_vec = emit_simd_ins (cfg, klass, OP_XOP_OVR_X_X, dot->dreg, -1);
sqrt_vec->inst_c0 = INTRINS_AARCH64_ADV_SIMD_FSQRT;
sqrt_vec->inst_c1 = MONO_TYPE_R4;

MonoInst *normalized_vec = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, sqrt_vec->dreg);
normalized_vec->inst_c0 = OP_FDIV;
normalized_vec->inst_c1 = MONO_TYPE_R4;

return normalized_vec;
#else
return NULL;
#endif
}
case SN_Conjugate: {
#if defined (TARGET_ARM64)
float value[4];
value [0] = -1.0f;
value [1] = -1.0f;
value [2] = -1.0f;
value [3] = 1.0f;
MonoInst* r = emit_xconst_v128 (cfg, klass, (guint8*)value);
MonoInst* result = emit_simd_ins (cfg, klass, OP_XBINOP, args [0]->dreg, r->dreg);
result->inst_c0 = OP_FMUL;
result->inst_c1 = MONO_TYPE_R4;
return result;
#else
return NULL;
#endif
}
default:
g_assert_not_reached ();
Expand Down