Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate aggregate constants in DataflowConstProp. #115796

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 172 additions & 18 deletions compiler/rustc_mir_transform/src/dataflow_const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
//!
//! Currently, this pass only propagates scalar values.

use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
};
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
use rustc_target::abi::{FieldIdx, VariantIdx};
use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};

use crate::const_prop::throw_machine_stop_str;
use crate::MirPass;

// These constants are somewhat random guesses and have not been optimized.
Expand Down Expand Up @@ -553,16 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {

fn try_make_constant(
&self,
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
place: Place<'tcx>,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> Option<Const<'tcx>> {
let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
return None;
};
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
Some(Const::Val(ConstValue::Scalar(value.into()), ty))
let layout = ecx.layout_of(ty).ok()?;

if layout.is_zst() {
return Some(Const::zero_sized(ty));
}

if layout.is_unsized() {
return None;
}

let place = map.find(place.as_ref())?;
if layout.abi.is_scalar()
&& let Some(value) = propagatable_scalar(place, state, map)
{
return Some(Const::Val(ConstValue::Scalar(value), ty));
}

if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
let alloc_id = ecx
.intern_with_temp_alloc(layout, |ecx, dest| {
try_write_constant(ecx, dest, place, ty, state, map)
})
.ok()?;
return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
}

None
}
}

fn propagatable_scalar(
place: PlaceIndex,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> Option<Scalar> {
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
Some(value)
} else {
None
}
}

#[instrument(level = "trace", skip(ecx, state, map))]
fn try_write_constant<'tcx>(
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
dest: &PlaceTy<'tcx>,
place: PlaceIndex,
ty: Ty<'tcx>,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> InterpResult<'tcx> {
let layout = ecx.layout_of(ty)?;

// Fast path for ZSTs.
if layout.is_zst() {
return Ok(());
}

// Fast path for scalars.
if layout.abi.is_scalar()
&& let Some(value) = propagatable_scalar(place, state, map)
{
return ecx.write_immediate(Immediate::Scalar(value), dest);
}

match ty.kind() {
// ZSTs. Nothing to do.
ty::FnDef(..) => {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ty::FnDef(..) => {}
ty::FnDef(..) => bug!(),


// Those are scalars, must be handled above.
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => bug!("primitive type with provenance"),

Is this reachable? Our constants check that this doesn't happen. Or can dataflow_const_prop operations lead to such a situation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is possible in case of a pointer-to-int transmute. We don't strip provenance when transmuting, so the state is a Scalar::Ptr, which is rejected by propagatable_scalar.


ty::Tuple(elem_tys) => {
for (i, elem) in elem_tys.iter().enumerate() {
let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
throw_machine_stop_str!("missing field in tuple")
};
let field_dest = ecx.project_field(dest, i)?;
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
}
}

ty::Adt(def, args) => {
if def.is_union() {
throw_machine_stop_str!("cannot propagate unions")
}

let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
throw_machine_stop_str!("missing discriminant for enum")
};
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
throw_machine_stop_str!("discriminant with provenance")
};
let discr_bits = discr.assert_bits(discr.size());
let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
throw_machine_stop_str!("illegal discriminant for enum")
};
let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
throw_machine_stop_str!("missing variant for enum")
};
let variant_dest = ecx.project_downcast(dest, variant)?;
(variant, def.variant(variant), variant_place, variant_dest)
} else {
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
};

for (i, field) in variant_def.fields.iter_enumerated() {
let ty = field.ty(*ecx.tcx, args);
let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
throw_machine_stop_str!("missing field in ADT")
};
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
}
ecx.write_discriminant(variant_idx, dest)?;
}

// Unsupported for now.
ty::Array(_, _)

// Do not attempt to support indirection in constants.
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)

| ty::Never
| ty::Foreign(..)
| ty::Alias(..)
| ty::Param(_)
| ty::Bound(..)
| ty::Placeholder(..)
| ty::Closure(..)
| ty::Coroutine(..)
| ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),

ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
}

Ok(())
}

impl<'mir, 'tcx>
Expand All @@ -580,8 +716,13 @@ impl<'mir, 'tcx>
) {
match &statement.kind {
StatementKind::Assign(box (_, rvalue)) => {
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
.visit_rvalue(rvalue, location);
OperandCollector {
state,
visitor: self,
ecx: &mut results.analysis.0.ecx,
map: &results.analysis.0.map,
}
.visit_rvalue(rvalue, location);
}
_ => (),
}
Expand All @@ -599,7 +740,12 @@ impl<'mir, 'tcx>
// Don't overwrite the assignment if it already uses a constant (to keep the span).
}
StatementKind::Assign(box (place, _)) => {
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
if let Some(value) = self.try_make_constant(
&mut results.analysis.0.ecx,
place,
state,
&results.analysis.0.map,
) {
self.patch.assignments.insert(location, value);
}
}
Expand All @@ -614,8 +760,13 @@ impl<'mir, 'tcx>
terminator: &'mir Terminator<'tcx>,
location: Location,
) {
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
.visit_terminator(terminator, location);
OperandCollector {
state,
visitor: self,
ecx: &mut results.analysis.0.ecx,
map: &results.analysis.0.map,
}
.visit_terminator(terminator, location);
}
}

Expand Down Expand Up @@ -670,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
state: &'a State<FlatSet<Scalar>>,
visitor: &'a mut Collector<'tcx, 'locals>,
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
map: &'map Map,
}

Expand All @@ -682,15 +834,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
location: Location,
) {
if let PlaceElem::Index(local) = elem
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
{
self.visitor.patch.before_effect.insert((location, local.into()), value);
}
}

fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
if let Some(place) = operand.place() {
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
if let Some(value) =
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
{
self.visitor.patch.before_effect.insert((location, place), value);
} else if !place.projection.is_empty() {
// Try to propagate into `Index` projections.
Expand All @@ -713,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
}

fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
unimplemented!()
false
}

fn before_access_global(
Expand All @@ -725,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
is_write: bool,
) -> InterpResult<'tcx> {
if is_write {
crate::const_prop::throw_machine_stop_str!("can't write to global");
throw_machine_stop_str!("can't write to global");
}

// If the static allocation is mutable, then we can't const prop it as its content
// might be different at runtime.
if alloc.inner().mutability.is_mut() {
crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
throw_machine_stop_str!("can't access mutable globals in ConstProp");
}

Ok(())
Expand Down Expand Up @@ -781,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
throw_machine_stop_str!("can't do pointer arithmetic");
}

fn expose_ptr(
Expand Down
9 changes: 7 additions & 2 deletions tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
+ debug ((f: (bool, bool, u32)).2: u32) => const 123_u32;
let _10: std::option::Option<u16>;
scope 7 {
debug o => _10;
- debug o => _10;
+ debug o => const Option::<u16>::Some(99_u16);
let _17: u32;
let _18: u32;
scope 8 {
Expand Down Expand Up @@ -81,7 +82,7 @@
_15 = const false;
_16 = const 123_u32;
StorageLive(_10);
_10 = Option::<u16>::Some(const 99_u16);
_10 = const Option::<u16>::Some(99_u16);
_17 = const 32_u32;
_18 = const 32_u32;
StorageLive(_11);
Expand All @@ -97,3 +98,7 @@
}
}

ALLOC0 (size: 4, align: 2) {
01 00 63 00 │ ..c.
}

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
- _6 = CheckedAdd(_4, _5);
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind unreachable];
+ _5 = const 2_i32;
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
+ _6 = const (3_i32, false);
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind unreachable];
}

Expand All @@ -60,7 +60,7 @@
- _10 = CheckedAdd(_9, const 1_i32);
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind unreachable];
+ _9 = const i32::MAX;
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
+ _10 = const (i32::MIN, true);
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind unreachable];
}

Expand All @@ -76,5 +76,13 @@
StorageDead(_1);
return;
}
+ }
+
+ ALLOC0 (size: 8, align: 4) {
+ 00 00 00 80 01 __ __ __ │ .....░░░
+ }
+
+ ALLOC1 (size: 8, align: 4) {
+ 03 00 00 00 00 __ __ __ │ .....░░░
}

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
- _6 = CheckedAdd(_4, _5);
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind continue];
+ _5 = const 2_i32;
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
+ _6 = const (3_i32, false);
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind continue];
}

Expand All @@ -60,7 +60,7 @@
- _10 = CheckedAdd(_9, const 1_i32);
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind continue];
+ _9 = const i32::MAX;
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
+ _10 = const (i32::MIN, true);
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind continue];
}

Expand All @@ -76,5 +76,13 @@
StorageDead(_1);
return;
}
+ }
+
+ ALLOC0 (size: 8, align: 4) {
+ 00 00 00 80 01 __ __ __ │ .....░░░
+ }
+
+ ALLOC1 (size: 8, align: 4) {
+ 03 00 00 00 00 __ __ __ │ .....░░░
}

2 changes: 1 addition & 1 deletion tests/mir-opt/dataflow-const-prop/checked.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// skip-filecheck
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY
// unit-test: DataflowConstProp
// compile-flags: -Coverflow-checks=on
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY

// EMIT_MIR checked.main.DataflowConstProp.diff
#[allow(arithmetic_overflow)]
Expand Down
Loading
Loading