
Compute comparison masks in narrower types if possible #7392

Merged
merged 9 commits into main on Mar 25, 2023

Conversation

@abadams (Member) commented Mar 3, 2023

In various circumstances (e.g. boundary conditions) we generate comparisons between ramps and broadcasts and use them either as a load/store predicate or as a select condition. The comparison is currently done in 32 bits. When the result is used to mux between narrow types, this generates multiple vectors of comparison masks, which must then be narrowed. With some care, it's possible to instead perform the comparison directly in the narrow type.

For example, if we're selecting between uint8s and we have the condition:
ramp(x, 1, 16) < broadcast(y)
where x is an Int(32), we can rewrite it to:
cast<int8_t>(ramp(0, 1, 16)) < broadcast(saturating_cast<int8_t>(saturating_sub(y, x)))
This is valid because the ramp lanes (0 through 15) can't take on the extreme values of an int8, so saturating the other side of the comparison can't flip the result.
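
To sanity-check the rewrite, here's a scalar model in C++ (a sketch only: saturating_sub_i32 and saturating_cast_i8 are stand-ins built from std::clamp over a widened type, and the checked window of x and y values is arbitrary; this is not Halide's implementation):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int32_t saturating_sub_i32(int32_t y, int32_t x) {
        return (int32_t)std::clamp<int64_t>((int64_t)y - x, INT32_MIN, INT32_MAX);
    }

    int8_t saturating_cast_i8(int64_t v) {
        return (int8_t)std::clamp<int64_t>(v, INT8_MIN, INT8_MAX);
    }

    int main() {
        // Check a window of ramp bases x and thresholds y around the int8 boundaries.
        for (int64_t x = -300; x <= 300; x++) {
            for (int64_t y = -300; y <= 300; y++) {
                int8_t rhs = saturating_cast_i8(saturating_sub_i32((int32_t)y, (int32_t)x));
                for (int i = 0; i < 16; i++) {
                    bool wide = (x + i) < y;        // lane i of the original wide compare
                    bool narrow = (int8_t)i < rhs;  // lane i of the narrowed 8-bit compare
                    assert(wide == narrow);
                }
            }
        }
        return 0;
    }

The assert never fires because the ramp lanes 0..15 lie strictly inside the int8 range, so saturating the broadcast side to 127 or -128 never changes the outcome of the comparison.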

For an example of the generated assembly, consider the uint8 expression select(x < 50, f(x), 17). On main, when vectorized 64-wide for AVX-512, this compiles to:

	leal	(%rbx,%rsi), %edi
	vpbroadcastd	%edi, %zmm6
	vpaddd	%zmm0, %zmm6, %zmm7
	vpaddd	%zmm1, %zmm6, %zmm8
	vpaddd	%zmm2, %zmm6, %zmm9
	vpaddd	%zmm3, %zmm6, %zmm6
	vpcmpgtd	%zmm6, %zmm4, %k0
	vpcmpgtd	%zmm9, %zmm4, %k1
	vpcmpgtd	%zmm8, %zmm4, %k2
	vpcmpgtd	%zmm7, %zmm4, %k3
	kunpckwd	%k0, %k1, %k0
	kunpckwd	%k2, %k3, %k1
	kunpckdq	%k0, %k1, %k1
	vpblendmb	(%rax,%rsi), %zmm5, %zmm6 {%k1}
	vmovdqu64	%zmm6, (%r15,%rsi)
	addq	$64, %rsi
	cmpq	%rsi, %rdx
	jne	.LBB0_21

In this branch it generates:

	cmpl	$127, %esi
	movl	$127, %r9d
	cmovll	%esi, %r9d
	cmpl	$-127, %r9d
	cmovll	%r8d, %r9d
	vpbroadcastb	%r9d, %zmm2
	vpcmpgtb	%zmm0, %zmm2, %k1
	vpblendmb	(%rax,%rdi), %zmm1, %zmm2 {%k1}
	vmovdqu64	%zmm2, (%r15,%rdi)
	addq	$64, %rdi
	addl	$-64, %esi
	cmpq	%rdi, %rdx
	jne	.LBB0_21

According to llvm-mca, the latter is 3x faster than the former.

@abadams abadams requested a review from rootjalex March 3, 2023 18:16
@steven-johnson (Contributor)

the latter is 3x faster than the former

!!!

(I'll definitely be pulling this into google3 for some testing once we get some buildbots green)

return !narrow_predicate(op->a, t);
}

const LT *lt = p.as<LT>();
Contributor

Presumably we've normalized away the GT and GE at this point?

Member Author

Yes, those should be long-gone

@abadams (Member, Author) commented Mar 3, 2023

I think I can avoid the issue of y - x overflowing if I use saturating_sub. Given that it's feeding into a saturating_narrow, I don't think it should change the result compared to an infinite-precision y - x. Does that sound right, @rootjalex?
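
That argument can be spot-checked at the int32 extremes with a similar scalar sketch (again using std::clamp stand-ins with invented names, not Halide's actual lowering): clamping to int32 only changes differences whose magnitude is at least 2^31, and anything that large narrows to the same int8 extreme whether or not it was clamped first.

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int8_t sat_narrow_i8(int64_t v) {
        return (int8_t)std::clamp<int64_t>(v, INT8_MIN, INT8_MAX);
    }

    int32_t sat_sub_i32(int32_t y, int32_t x) {
        return (int32_t)std::clamp<int64_t>((int64_t)y - x, INT32_MIN, INT32_MAX);
    }

    int main() {
        const int32_t extremes[] = {INT32_MIN, INT32_MIN + 1, -129, -128, -1, 0,
                                    1, 126, 127, INT32_MAX - 1, INT32_MAX};
        for (int32_t y : extremes) {
            for (int32_t x : extremes) {
                int64_t exact = (int64_t)y - x;  // infinite-precision difference
                // Saturating at int32 first gives the same narrowed result.
                assert(sat_narrow_i8(sat_sub_i32(y, x)) == sat_narrow_i8(exact));
            }
        }
        return 0;
    }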

@steven-johnson (Contributor)

(Looks like some of the buildbots happened to sync to bad LLVM17 revs last night -- I'm forcing rebuilds on those. Linux will be ready soon, but the armbots will take ~all day)

auto rewrite = IRMatcher::rewriter(p, Int(32, lanes));

// Construct predicates which state the ramp can't hit the extreme
// values of an int8 or an int16. This is an overconservative condition,
Member

At first I was going to suggest saturating_narrow on the ramp because it's faster (on x86 at least, and possibly others?) before I realized the rewritten ramps below are all constant-folded. If we instead used bounds inference to prove this for symbolic ramps, we should probably saturating_cast the ramp as well, right?

auto min_ramp_lane = min(c0, c0 * (lanes - 1));
auto max_ramp_lane = max(c0, c0 * (lanes - 1));
auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;
Member

Might want a comment here explaining why these inequalities are strict. It took me a minute to work through why.

Member Author

Is the comment immediately above not sufficient? I'm checking they can't hit the extreme values of the narrower type. I think in some cases it's fine, but it's quite hard to think about.

Member

I think an explanation of why this condition is necessary would be helpful. Also, I think the inequality only needs to be strict for a strict <; pretty sure it can be non-strict (is there a word for this?) for a <=.

Member Author

Comment updated. I found it very hard to think through the cases, and just ended up using the most conservative condition for all of them.
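
For a concrete instance of why the ramp must stay strictly inside the int8 range (an illustrative sketch, not code from this PR): if a lane could equal 127, the saturated broadcast can no longer distinguish "equal to 127" from "greater than 127".

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int8_t saturating_cast_i8(int64_t v) {
        return (int8_t)std::clamp<int64_t>(v, INT8_MIN, INT8_MAX);
    }

    int main() {
        int32_t lane = 127;  // a hypothetical ramp lane sitting exactly at int8's max
        int32_t rhs = 200;   // broadcast value after subtracting the ramp base
        bool wide = lane < rhs;                                // true in 32 bits
        bool narrow = (int8_t)lane < saturating_cast_i8(rhs);  // 127 < 127 -> false
        assert(wide && !narrow);  // the two disagree, so lanes hitting 127 must be excluded
        return 0;
    }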

@rootjalex (Member)

Seems like some legit failures producing signed integer overflow in constant folding

@abadams (Member, Author) commented Mar 20, 2023

The signed overflow was caused by our lowering of saturating_add and saturating_sub flirting with int32 overflow in the simplifier by introducing INT_MAX and INT_MIN constants. I rewrote the lowering to only do unsigned math (and to use substantially fewer ops). I brute-force checked the correctness of the new lowerings for int8 types.
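
For reference, here's a generic unsigned-only formulation of signed saturating add/sub in C++, with the kind of exhaustive int8 check described above (a standard bit-twiddling sketch; the helper names are invented and this is not necessarily the exact lowering the PR emits):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int8_t sat_add_i8(int8_t a, int8_t b) {
        uint8_t ua = (uint8_t)a, ub = (uint8_t)b;
        uint8_t sum = (uint8_t)(ua + ub);                 // wrapping sum
        uint8_t saturated = (uint8_t)((ua >> 7) + 0x7f);  // 0x7f if a >= 0, else 0x80
        // Signed overflow happened iff a and b share a sign and the sum's sign differs.
        if (((ua ^ sum) & (ub ^ sum)) & 0x80) {
            sum = saturated;
        }
        return (int8_t)sum;
    }

    int8_t sat_sub_i8(int8_t a, int8_t b) {
        uint8_t ua = (uint8_t)a, ub = (uint8_t)b;
        uint8_t diff = (uint8_t)(ua - ub);                // wrapping difference
        uint8_t saturated = (uint8_t)((ua >> 7) + 0x7f);  // 0x7f if a >= 0, else 0x80
        // Overflow iff a and b have different signs and the result's sign differs from a's.
        if (((ua ^ ub) & (ua ^ diff)) & 0x80) {
            diff = saturated;
        }
        return (int8_t)diff;
    }

    int main() {
        // Brute-force check against a widened reference over all int8 pairs.
        for (int a = -128; a <= 127; a++) {
            for (int b = -128; b <= 127; b++) {
                assert(sat_add_i8((int8_t)a, (int8_t)b) == std::clamp(a + b, -128, 127));
                assert(sat_sub_i8((int8_t)a, (int8_t)b) == std::clamp(a - b, -128, 127));
            }
        }
        return 0;
    }

The point of doing both the wrapping arithmetic and the overflow test in unsigned types is that no intermediate expression can overflow a signed type, which is the property the rewritten lowering is after.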

@abadams (Member, Author) commented Mar 24, 2023

PTAL (please take a look)

@rootjalex (Member)

I brute-force checked the correctness of the new lowerings for int8 types.

I'd be more comfortable with the changes if they were formally verified. I'd offer to do it myself, but I don't have time while traveling this week.

Comment on lines 1256 to 1257
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Member

Should these just be reinterprets?

Comment on lines 1283 to 1284
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Member

These should probably also just be reinterprets?

@abadams (Member, Author) commented Mar 24, 2023

wrt the new lowerings, I'm pretty confident that if they're checked correct for int8 then they're right for any signed integer, so I don't think those need formal verification.

Did you also want to see formal verification of the rewrite rules in FindIntrinsics.cpp? I might wait until you're not travelling and then get you to show me the easiest way to do that. I haven't needed to do it in a while.

@rootjalex (Member)

I was referring to the new lowerings because they are harder for me to convince myself about than the additions to FindIntrinsics.cpp, but now that you mention it, those could also be verified (by scalarizing the ramps).

I'm happy to show you how I'd do it after traveling, but I don't necessarily want to hold this PR back (I probably won't have time until the first week of April). I am reasonably convinced of the correctness of these rewrites, so I will approve this.

@abadams abadams merged commit ab5f042 into main Mar 25, 2023
ardier pushed a commit to ardier/Halide-mutation that referenced this pull request Mar 3, 2024
* Compute comparison masks in narrower types if possible

* Remove reliance on infinite precision int32s

* Further elaborate on comment

* Lower signed saturating_add and sub to unsigned math

The existing lowering was prone to overflow

* cast -> reinterpret