Skip to content

Commit 31c0645

Browse files
author
Michael Benfield
committed
rustc_codegen_ssa: Fix for codegen_get_discr
When doing the optimized implementation of getting the discriminant, the arithmetic needs to be done in the tag type so wrapping behavior works correctly. Fixes #104519
1 parent fd3bfb3 commit 31c0645

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

compiler/rustc_codegen_ssa/src/mir/place.rs

+18-9
Original file line numberDiff line numberDiff line change
@@ -309,14 +309,14 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
309309
// In the algorithm above, we can change
310310
// cast(relative_tag) + niche_variants.start()
311311
// into
312-
// cast(tag) + (niche_variants.start() - niche_start)
312+
// cast(tag + (niche_variants.start() - niche_start))
313313
// if either the casted type is no larger than the original
314314
// type, or if the niche values are contiguous (in either the
315315
// signed or unsigned sense).
316-
let can_incr_after_cast = cast_smaller || niches_ule || niches_sle;
316+
let can_incr = cast_smaller || niches_ule || niches_sle;
317317

318318
let data_for_boundary_niche = || -> Option<(IntPredicate, u128)> {
319-
if !can_incr_after_cast {
319+
if !can_incr {
320320
None
321321
} else if niche_start == low_unsigned {
322322
Some((IntPredicate::IntULE, niche_end))
@@ -353,24 +353,33 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> {
353353
// The algorithm is now this:
354354
// is_niche = tag <= niche_end
355355
// discr = if is_niche {
356-
// cast(tag) + (niche_variants.start() - niche_start)
356+
// cast(tag + (niche_variants.start() - niche_start))
357357
// } else {
358358
// untagged_variant
359359
// }
360360
// (the first line may instead be tag >= niche_start,
361361
// and may be a signed or unsigned comparison)
362+
// The arithmetic must be done before the cast, so we can
363+
// have the correct wrapping behavior. See issue #104519 for
364+
// the consequences of getting this wrong.
362365
let is_niche =
363366
bx.icmp(predicate, tag, bx.cx().const_uint_big(tag_llty, constant));
367+
let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
368+
let incr_tag = if delta == 0 {
369+
tag
370+
} else {
371+
bx.add(tag, bx.cx().const_uint_big(tag_llty, delta))
372+
};
373+
364374
let cast_tag = if cast_smaller {
365-
bx.intcast(tag, cast_to, false)
375+
bx.intcast(incr_tag, cast_to, false)
366376
} else if niches_ule {
367-
bx.zext(tag, cast_to)
377+
bx.zext(incr_tag, cast_to)
368378
} else {
369-
bx.sext(tag, cast_to)
379+
bx.sext(incr_tag, cast_to)
370380
};
371381

372-
let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
373-
(is_niche, cast_tag, delta)
382+
(is_niche, cast_tag, 0)
374383
} else {
375384
// The special cases don't apply, so we'll have to go with
376385
// the general algorithm.

src/test/codegen/enum-match.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,8 @@ pub enum Enum1 {
3434

3535
// CHECK: define i8 @match1{{.*}}
3636
// CHECK-NEXT: start:
37-
// CHECK-NEXT: %1 = icmp ugt i8 %0, 1
38-
// CHECK-NEXT: %2 = zext i8 %0 to i64
39-
// CHECK-NEXT: %3 = add nsw i64 %2, -1
40-
// CHECK-NEXT: %_2 = select i1 %1, i64 %3, i64 0
41-
// CHECK-NEXT: switch i64 %_2, label {{.*}} [
37+
// CHECK-NEXT: %1 = {{.*}}call i8 @llvm.usub.sat.i8(i8 %0, i8 1)
38+
// CHECK-NEXT: switch i8 %1, label {{.*}} [
4239
#[no_mangle]
4340
pub fn match1(e: Enum1) -> u8 {
4441
use Enum1::*;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// run-pass
2+
#![allow(dead_code)]
3+
4+
enum OpenResult {
5+
Ok(()),
6+
Err(()),
7+
TransportErr(TransportErr),
8+
}
9+
10+
#[repr(i32)]
11+
enum TransportErr {
12+
UnknownMethod = -2,
13+
}
14+
15+
#[inline(never)]
16+
fn some_match(result: OpenResult) -> u8 {
17+
match result {
18+
OpenResult::Ok(()) => 0,
19+
_ => 1,
20+
}
21+
}
22+
23+
fn main() {
24+
let result = OpenResult::Ok(());
25+
assert_eq!(some_match(result), 0);
26+
27+
let result = OpenResult::Ok(());
28+
match result {
29+
OpenResult::Ok(()) => (),
30+
_ => unreachable!("message a"),
31+
}
32+
match result {
33+
OpenResult::Ok(()) => (),
34+
_ => unreachable!("message b"),
35+
}
36+
}

0 commit comments

Comments
 (0)