Generating vmlal.s16 #2997

Closed

mohamedadaly opened this issue May 29, 2018 · 17 comments

@mohamedadaly
Contributor

mohamedadaly commented May 29, 2018

I have this in a Func that accumulates products of 16-bit integers into 32-bit integers:

f(x, y) += i32(a(x, r)) * i32(b(r, y));

The problem is that the generated arm32 code doesn't contain any vmlal.s16 instructions; instead it converts both a and b to 32-bit using vmovl.s16 and then accumulates with vmla.i32. I tried various combinations of casting only a, only b, etc., but I still don't get vmlal.

Is there a way to check or understand why this doesn't happen?

Thanks for your help.
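
For reference, a minimal standalone repro of the setup described above (a sketch only: sizes, names, the target string, and the use of Halide::ConciseCasts::i32 are illustrative, not the real pipeline):

    #include "Halide.h"
    using namespace Halide;
    using Halide::ConciseCasts::i32;

    int main() {
        ImageParam a(Int(16), 2, "a"), b(Int(16), 2, "b");
        Var x("x"), y("y");
        RDom r(0, 128, "r");

        Func f("f");
        f(x, y) = 0;
        f(x, y) += i32(a(x, r)) * i32(b(r, y));
        f.update(0).vectorize(x, 4);

        // Emit arm32 assembly so it can be grepped for vmlal.s16.
        f.compile_to_assembly("f.s", {a, b}, Target("arm-32-linux"));
        return 0;
    }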

@mohamedadaly changed the title from Generating vmlal.i16 to Generating vmlal.s16 on May 29, 2018
@zvookin
Member

zvookin commented May 29, 2018 via email

@mohamedadaly
Contributor Author

I found this

~/Halide/test/correctness/simd_op_check.cpp:970:            check(arm32 ? "vmlal.s16" : "smlal", 4*w, i32_1 + i32(i16_2)*i16_3);

and accordingly tried this:

f(x, y) += i32(a(x, r)) * b(r, y);

but I still don't get the desired result.

  • Is there a way to check the final result of the simplification of this expression, before code generation?
  • Could scheduling have any effect on the pattern being simplified/matched, e.g. b.compute_root() vs. inlining b, etc.?

Thanks again!
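
For reference, a sketch of one way to dump the lowered statement and the final assembly for inspection (assuming the AOT compile_to_lowered_stmt / compile_to_assembly entry points; filenames and the target string are illustrative):

    // Dump the statement after lowering/simplification, and the generated
    // assembly, so the expression that reaches codegen can be compared
    // against simd_op_check's pattern.
    Target t("arm-32-linux");
    f.compile_to_lowered_stmt("f.stmt", {a, b}, Text, t);
    f.compile_to_assembly("f.s", {a, b}, t);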

@zvookin
Member

zvookin commented May 29, 2018 via email

@mohamedadaly
Contributor Author

Yep, the function is vectorized. The scheduling is rather complicated, but here is the relevant part in the output .stmt file:

f[ramp(0, 1, 4)] = (f[ramp(0, 1, 4)] + (x4(int32(a[(t174 + (f.s1.r$x * t132))])) * int32x4(b[ramp((t175 + (f.s1.r$x * t134)), 1, 4)])))

and here is the relevant part in the arm32 soft-abi assembly, where it's expanding the int16 operands to int32 before doing the multiplication+accumulation in 32 bits, instead of using vmlal.s16 directly on the 16-bit operands:

	ldr	r1, [r4]
	subs	r10, r10, #1
	ldr	r0, [r4, #4]
	add	r4, r4, r11
	ldrsh	r12, [r3, #2]
	ldrsh	lr, [r3]
	ldrsh	r5, [r3, #-2]
	vdup.32	q13, r12
	str	r0, [sp, #436]
	str	r1, [sp, #432]
	vld1.16	{d24}, [r7:64]
	ldrsh	r0, [r3, #-4]
	add	r3, r3, r9
	vmovl.s16	q12, d24
	vdup.32	q14, lr
	vdup.32	q15, r5
	vdup.32	q0, r0
	vmla.i32	q8, q13, q12
	vmla.i32	q10, q14, q12
	vmla.i32	q9, q15, q12
	vmla.i32	q11, q0, q12
	bne	.LBB0_116

I ran test/correctness/simd_op_check.cpp and it does produce the correct assembly instruction for the pattern I have. I have also been looking at CodeGen_ARM.cpp and noticed there is some special handling of other instructions, e.g. vqrdmulh.

  • Is it possible to insert something there?
  • Or how do I trace why this expression doesn't end up using vmlal.s16? Is it the LLVM simplification/optimization passes?

@abadams
Member

abadams commented May 30, 2018 via email

@mohamedadaly
Contributor Author

Here is the relevant LLVM IR for vectorizing by a factor of 4. I don't know much LLVM IR, but it seems to load a vector of 4 i16 values from b and sign-extend it to i32, load a single i16 value from a, sign-extend it to a scalar i32 and broadcast it into a vector, and finally do the multiplication and the accumulation in 32 bits:

  %887 = mul nsw i32 %"f.s1.r$x", %t134
  %888 = add nsw i32 %t175, %887
  %889 = getelementptr inbounds i16, i16* %b, i32 %888
  %890 = bitcast i16* %889 to <4 x i16>*
  %891 = load <4 x i16>, <4 x i16>* %890, align 2, !tbaa !45
  %892 = sext <4 x i16> %891 to <4 x i32>
  %893 = mul nsw i32 %"f.s1.r$x", %t132
  %894 = add nsw i32 %t174, %893
  %895 = add nsw i32 %894, 3
  %896 = getelementptr inbounds i16, i16* %a, i32 %895
  %897 = load i16, i16* %896, align 2, !tbaa !49
  %898 = sext i16 %897 to i32
  %899 = insertelement <4 x i32> undef, i32 %898, i32 0
  %900 = shufflevector <4 x i32> %899, <4 x i32> undef, <4 x i32> zeroinitializer
  %901 = mul nsw <4 x i32> %900, %892
  %902 = getelementptr inbounds i32, i32* %f, i32 12
  %903 = bitcast i32* %902 to <4 x i32>*
  %904 = load <4 x i32>, <4 x i32>* %903, align 16, !tbaa !67
  %905 = add nsw <4 x i32> %904, %901
  %906 = getelementptr inbounds i32, i32* %conv, i32 12
  %907 = bitcast i32* %906 to <4 x i32>*
  store <4 x i32> %905, <4 x i32>* %907, align 16, !tbaa !67
  %908 = add nsw i32 %"f.s1.r$x", 1
  %909 = icmp ne i32 %908, %825
  br i1 %909, label %"for f.s1.r$x", label %"end for f.s1.r$x"

Is there any workaround for that? Would a custom llvm pass help? If so, how much effort would it be?
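
For reference, one way to dump the LLVM IR shown above (a sketch, assuming compile_to_llvm_assembly; the target string is illustrative):

    // Emit the LLVM IR Halide hands to the backend, so it can be compared
    // against the patterns the ARM backend is able to match to vmlal.
    f.compile_to_llvm_assembly("f.ll", {a, b}, Target("arm-32-linux"));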

@zvookin
Member

zvookin commented May 30, 2018 via email

@mohamedadaly
Contributor Author

I don't quite get your suggestion, but here is more information about what I am trying to do, in case it helps. I am trying to implement a GEMM for uint8 quantized inputs. a and b are the left-hand and right-hand sides; they are int16 (after adding offset values), and f is the product.

f is tiled into two levels of blocks, the smaller one being 4x4, and the output tile in f is updated by computing the "outer product" of a 4x1 vector from a with a 1x4 vector from b. This is done by scheduling the inner tile of f as:

f(x, y) += i32(a(r, y)) * b(x, r);

f.update(0).reorder(x, y, r).unroll(y).vectorize(x);

Conceptually the generated code should use vmlal.s16 instructions since I am multiplying two int16 values to produce an int32 value, but that doesn't happen :(
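
A self-contained sketch of that inner-tile update (names, the reduction size, and the tile factors are illustrative; the real schedule also has the outer level of tiling):

    using Halide::ConciseCasts::i32;

    ImageParam a(Int(16), 2, "a"), b(Int(16), 2, "b");
    RDom r(0, 256, "r");   // shared (reduction) dimension
    Var x("x"), y("y");
    Func f("f");
    f(x, y) = 0;
    f(x, y) += i32(a(r, y)) * b(x, r);
    // 4x4 accumulator tile: vectorize 4 lanes of x, unroll 4 rows of y.
    f.update(0).reorder(x, y, r).unroll(y, 4).vectorize(x, 4);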

@mohamedadaly
Contributor Author

@abadams Would it be feasible to add a pattern to the Mul op in src/CodeGen_ARM.cpp that will catch this case, and then an LLVM IR function in src/runtime/arm.ll that has inline arm32 assembly to call the right instruction?

@abadams
Member

abadams commented May 30, 2018

Yes, but that's not the first thing I'd try. Before dropping to inline assembly I'd check the llvm codegen tests folder for the IR patterns that are supposed to compile to vmlal and either try to see why Halide isn't generating that, or add those patterns as intrinsics to arm.ll. This is what we did for pavgb on x86 when llvm removed the intrinsic. See: https://github.com/halide/Halide/blob/master/src/runtime/x86.ll#L2

@mohamedadaly
Contributor Author

Thanks a lot! I will try that first, and let you know.

@zvookin
Member

zvookin commented May 30, 2018 via email

@mohamedadaly
Contributor Author

@zvookin Thanks!

I guess the problem with the IR matching is that what LLVM expects is this pattern (from here):

  %tmp4 = sext <4 x i16> %tmp2 to <4 x i32>
  %tmp5 = sext <4 x i16> %tmp3 to <4 x i32>
  %tmp6 = mul <4 x i32> %tmp4, %tmp5
  %tmp7 = add <4 x i32> %tmp1, %tmp6

but what Halide generates is this sequence, where the sext is done on a scalar rather than a vector, and an insertelement/shufflevector pair broadcasts the widened scalar into a vector (which lowers to vdup.32):

  %892 = sext <4 x i16> %891 to <4 x i32>
  %898 = sext i16 %897 to i32
  %899 = insertelement <4 x i32> undef, i32 %898, i32 0
  %900 = shufflevector <4 x i32> %899, <4 x i32> undef, <4 x i32> zeroinitializer
  %901 = mul nsw <4 x i32> %900, %892
  %905 = add nsw <4 x i32> %904, %901

I will try Andrew's suggestion and see how it goes, and will be happy to submit a PR if it works :-)

@abadams
Member

abadams commented May 31, 2018

Ah, so maybe the problem is just that it's x4(int32(foo)) instead of int32(x4(foo))!

If that's the case, then the broadcast visitor could detect widening casts in the arg and move them outwards.

@dsharletg
Contributor

I don't think we would want to do that in general though. I think maybe what we want here is to pattern match our old pattern, and when we match it, rewrite vectorbroadcast(widen(scalar)) to vectorwiden(broadcast(scalar)).

@mohamedadaly
Contributor Author

I implemented the suggestion from @abadams and @dsharletg to flip the order of broadcast and widen, and it works correctly. I couldn't get the existing tests in simd_op_check to fail, though, so a new test might be needed to catch this case.

Thanks for your help!

@mohamedadaly
Contributor Author

It would actually make more sense to limit the scope of the pattern matching to the case where one (or both) of the Mul's operands has the pattern broadcast(widen(scalar)), so I moved the changes to the Mul op instead of the Broadcast op.
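
Roughly, the check amounts to something like this (a simplified sketch against Halide's internal IR nodes, not the actual diff):

    // If a Mul operand is broadcast(widen(scalar)), rewrite it to
    // widen(broadcast(scalar)) so LLVM sees a sign-extension of a whole
    // vector feeding the multiply/accumulate and can select vmlal.
    Expr lift_widening_cast(const Expr &e) {
        if (const Broadcast *bc = e.as<Broadcast>()) {
            if (const Cast *widen = bc->value.as<Cast>()) {
                if (widen->type.bits() > widen->value.type().bits()) {
                    return Cast::make(widen->type.with_lanes(bc->lanes),
                                      Broadcast::make(widen->value, bc->lanes));
                }
            }
        }
        return e;
    }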
