Skip to content

Commit 4d45b07

Browse files
committed
Auto merge of rust-lang#100571 - cjgillot:mir-cost-visit, r=compiler-errors
Check projection types before inlining MIR Fixes rust-lang#100550 I'm very unhappy with this solution, having to duplicate MIR validation code, but at least it removes the ICE. r? `@compiler-errors`
2 parents 76531be + 911cbae commit 4d45b07

File tree

3 files changed

+266
-108
lines changed

3 files changed

+266
-108
lines changed

compiler/rustc_const_eval/src/transform/validate.rs

+14-16
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,20 @@ pub fn equal_up_to_regions<'tcx>(
8989

9090
// Normalize lifetimes away on both sides, then compare.
9191
let normalize = |ty: Ty<'tcx>| {
92-
tcx.normalize_erasing_regions(
93-
param_env,
94-
ty.fold_with(&mut BottomUpFolder {
95-
tcx,
96-
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
97-
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
98-
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
99-
// since one may have an `impl SomeTrait for fn(&32)` and
100-
// `impl SomeTrait for fn(&'static u32)` at the same time which
101-
// specify distinct values for Assoc. (See also #56105)
102-
lt_op: |_| tcx.lifetimes.re_erased,
103-
// Leave consts and types unchanged.
104-
ct_op: |ct| ct,
105-
ty_op: |ty| ty,
106-
}),
107-
)
92+
let ty = ty.fold_with(&mut BottomUpFolder {
93+
tcx,
94+
// FIXME: We erase all late-bound lifetimes, but this is not fully correct.
95+
// If you have a type like `<for<'a> fn(&'a u32) as SomeTrait>::Assoc`,
96+
// this is not necessarily equivalent to `<fn(&'static u32) as SomeTrait>::Assoc`,
97+
// since one may have an `impl SomeTrait for fn(&32)` and
98+
// `impl SomeTrait for fn(&'static u32)` at the same time which
99+
// specify distinct values for Assoc. (See also #56105)
100+
lt_op: |_| tcx.lifetimes.re_erased,
101+
// Leave consts and types unchanged.
102+
ct_op: |ct| ct,
103+
ty_op: |ty| ty,
104+
});
105+
tcx.try_normalize_erasing_regions(param_env, ty).unwrap_or(ty)
108106
};
109107
tcx.infer_ctxt().enter(|infcx| infcx.can_eq(param_env, normalize(src), normalize(dest)).is_ok())
110108
}

compiler/rustc_mir_transform/src/inline.rs

+222-92
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyC
1212
use rustc_session::config::OptLevel;
1313
use rustc_span::def_id::DefId;
1414
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
15+
use rustc_target::abi::VariantIdx;
1516
use rustc_target::spec::abi::Abi;
1617

1718
use super::simplify::{remove_dead_blocks, CfgSimplifier};
@@ -414,118 +415,60 @@ impl<'tcx> Inliner<'tcx> {
414415
debug!(" final inline threshold = {}", threshold);
415416

416417
// FIXME: Give a bonus to functions with only a single caller
417-
let mut first_block = true;
418-
let mut cost = 0;
418+
let diverges = matches!(
419+
callee_body.basic_blocks()[START_BLOCK].terminator().kind,
420+
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
421+
);
422+
if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
423+
return Err("callee diverges unconditionally");
424+
}
425+
426+
let mut checker = CostChecker {
427+
tcx: self.tcx,
428+
param_env: self.param_env,
429+
instance: callsite.callee,
430+
callee_body,
431+
cost: 0,
432+
validation: Ok(()),
433+
};
419434

420-
// Traverse the MIR manually so we can account for the effects of
421-
// inlining on the CFG.
435+
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
422436
let mut work_list = vec![START_BLOCK];
423437
let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
424438
while let Some(bb) = work_list.pop() {
425439
if !visited.insert(bb.index()) {
426440
continue;
427441
}
442+
428443
let blk = &callee_body.basic_blocks()[bb];
444+
checker.visit_basic_block_data(bb, blk);
429445

430-
for stmt in &blk.statements {
431-
// Don't count StorageLive/StorageDead in the inlining cost.
432-
match stmt.kind {
433-
StatementKind::StorageLive(_)
434-
| StatementKind::StorageDead(_)
435-
| StatementKind::Deinit(_)
436-
| StatementKind::Nop => {}
437-
_ => cost += INSTR_COST,
438-
}
439-
}
440446
let term = blk.terminator();
441-
let mut is_drop = false;
442-
match term.kind {
443-
TerminatorKind::Drop { ref place, target, unwind }
444-
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
445-
is_drop = true;
446-
work_list.push(target);
447-
// If the place doesn't actually need dropping, treat it like
448-
// a regular goto.
449-
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
450-
if ty.needs_drop(tcx, self.param_env) {
451-
cost += CALL_PENALTY;
452-
if let Some(unwind) = unwind {
453-
cost += LANDINGPAD_PENALTY;
454-
work_list.push(unwind);
455-
}
456-
} else {
457-
cost += INSTR_COST;
458-
}
459-
}
460-
461-
TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
462-
if first_block =>
463-
{
464-
// If the function always diverges, don't inline
465-
// unless the cost is zero
466-
threshold = 0;
467-
}
468-
469-
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
470-
if let ty::FnDef(def_id, _) =
471-
*callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
472-
{
473-
// Don't give intrinsics the extra penalty for calls
474-
if tcx.is_intrinsic(def_id) {
475-
cost += INSTR_COST;
476-
} else {
477-
cost += CALL_PENALTY;
478-
}
479-
} else {
480-
cost += CALL_PENALTY;
481-
}
482-
if cleanup.is_some() {
483-
cost += LANDINGPAD_PENALTY;
484-
}
485-
}
486-
TerminatorKind::Assert { cleanup, .. } => {
487-
cost += CALL_PENALTY;
488-
489-
if cleanup.is_some() {
490-
cost += LANDINGPAD_PENALTY;
491-
}
492-
}
493-
TerminatorKind::Resume => cost += RESUME_PENALTY,
494-
TerminatorKind::InlineAsm { cleanup, .. } => {
495-
cost += INSTR_COST;
447+
if let TerminatorKind::Drop { ref place, target, unwind }
448+
| TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
449+
{
450+
work_list.push(target);
496451

497-
if cleanup.is_some() {
498-
cost += LANDINGPAD_PENALTY;
452+
// If the place doesn't actually need dropping, treat it like a regular goto.
453+
let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
454+
if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
455+
work_list.push(unwind);
499456
}
500-
}
501-
_ => cost += INSTR_COST,
502-
}
503-
504-
if !is_drop {
505-
for succ in term.successors() {
506-
work_list.push(succ);
507-
}
457+
} else {
458+
work_list.extend(term.successors())
508459
}
509-
510-
first_block = false;
511460
}
512461

513462
// Count up the cost of local variables and temps, if we know the size
514463
// use that, otherwise we use a moderately-large dummy cost.
515-
516-
let ptr_size = tcx.data_layout.pointer_size.bytes();
517-
518464
for v in callee_body.vars_and_temps_iter() {
519-
let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
520-
// Cost of the var is the size in machine-words, if we know
521-
// it.
522-
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
523-
cost += ((size + ptr_size - 1) / ptr_size) as usize;
524-
} else {
525-
cost += UNKNOWN_SIZE_COST;
526-
}
465+
checker.visit_local_decl(v, &callee_body.local_decls[v]);
527466
}
528467

468+
// Abort if type validation found anything fishy.
469+
checker.validation?;
470+
471+
let cost = checker.cost;
529472
if let InlineAttr::Always = callee_attrs.inline {
530473
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
531474
Ok(())
@@ -799,6 +742,193 @@ fn type_size_of<'tcx>(
799742
tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
800743
}
801744

745+
/// Verify that the callee body is compatible with the caller.
746+
///
747+
/// This visitor mostly computes the inlining cost,
748+
/// but also needs to verify that types match because of normalization failure.
749+
struct CostChecker<'b, 'tcx> {
750+
tcx: TyCtxt<'tcx>,
751+
param_env: ParamEnv<'tcx>,
752+
cost: usize,
753+
callee_body: &'b Body<'tcx>,
754+
instance: ty::Instance<'tcx>,
755+
validation: Result<(), &'static str>,
756+
}
757+
758+
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
759+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
760+
// Don't count StorageLive/StorageDead in the inlining cost.
761+
match statement.kind {
762+
StatementKind::StorageLive(_)
763+
| StatementKind::StorageDead(_)
764+
| StatementKind::Deinit(_)
765+
| StatementKind::Nop => {}
766+
_ => self.cost += INSTR_COST,
767+
}
768+
769+
self.super_statement(statement, location);
770+
}
771+
772+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
773+
let tcx = self.tcx;
774+
match terminator.kind {
775+
TerminatorKind::Drop { ref place, unwind, .. }
776+
| TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
777+
// If the place doesn't actually need dropping, treat it like a regular goto.
778+
let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
779+
if ty.needs_drop(tcx, self.param_env) {
780+
self.cost += CALL_PENALTY;
781+
if unwind.is_some() {
782+
self.cost += LANDINGPAD_PENALTY;
783+
}
784+
} else {
785+
self.cost += INSTR_COST;
786+
}
787+
}
788+
TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
789+
let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
790+
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
791+
// Don't give intrinsics the extra penalty for calls
792+
INSTR_COST
793+
} else {
794+
CALL_PENALTY
795+
};
796+
if cleanup.is_some() {
797+
self.cost += LANDINGPAD_PENALTY;
798+
}
799+
}
800+
TerminatorKind::Assert { cleanup, .. } => {
801+
self.cost += CALL_PENALTY;
802+
if cleanup.is_some() {
803+
self.cost += LANDINGPAD_PENALTY;
804+
}
805+
}
806+
TerminatorKind::Resume => self.cost += RESUME_PENALTY,
807+
TerminatorKind::InlineAsm { cleanup, .. } => {
808+
self.cost += INSTR_COST;
809+
if cleanup.is_some() {
810+
self.cost += LANDINGPAD_PENALTY;
811+
}
812+
}
813+
_ => self.cost += INSTR_COST,
814+
}
815+
816+
self.super_terminator(terminator, location);
817+
}
818+
819+
/// Count up the cost of local variables and temps, if we know the size
820+
/// use that, otherwise we use a moderately-large dummy cost.
821+
fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
822+
let tcx = self.tcx;
823+
let ptr_size = tcx.data_layout.pointer_size.bytes();
824+
825+
let ty = self.instance.subst_mir(tcx, &local_decl.ty);
826+
// Cost of the var is the size in machine-words, if we know
827+
// it.
828+
if let Some(size) = type_size_of(tcx, self.param_env, ty) {
829+
self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
830+
} else {
831+
self.cost += UNKNOWN_SIZE_COST;
832+
}
833+
834+
self.super_local_decl(local, local_decl)
835+
}
836+
837+
/// This method duplicates code from MIR validation in an attempt to detect type mismatches due
838+
/// to normalization failure.
839+
fn visit_projection_elem(
840+
&mut self,
841+
local: Local,
842+
proj_base: &[PlaceElem<'tcx>],
843+
elem: PlaceElem<'tcx>,
844+
context: PlaceContext,
845+
location: Location,
846+
) {
847+
if let ProjectionElem::Field(f, ty) = elem {
848+
let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
849+
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
850+
let check_equal = |this: &mut Self, f_ty| {
851+
if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
852+
trace!(?ty, ?f_ty);
853+
this.validation = Err("failed to normalize projection type");
854+
return;
855+
}
856+
};
857+
858+
let kind = match parent_ty.ty.kind() {
859+
&ty::Opaque(def_id, substs) => {
860+
self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
861+
}
862+
kind => kind,
863+
};
864+
865+
match kind {
866+
ty::Tuple(fields) => {
867+
let Some(f_ty) = fields.get(f.as_usize()) else {
868+
self.validation = Err("malformed MIR");
869+
return;
870+
};
871+
check_equal(self, *f_ty);
872+
}
873+
ty::Adt(adt_def, substs) => {
874+
let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
875+
let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
876+
self.validation = Err("malformed MIR");
877+
return;
878+
};
879+
check_equal(self, field.ty(self.tcx, substs));
880+
}
881+
ty::Closure(_, substs) => {
882+
let substs = substs.as_closure();
883+
let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
884+
self.validation = Err("malformed MIR");
885+
return;
886+
};
887+
check_equal(self, f_ty);
888+
}
889+
&ty::Generator(def_id, substs, _) => {
890+
let f_ty = if let Some(var) = parent_ty.variant_index {
891+
let gen_body = if def_id == self.callee_body.source.def_id() {
892+
self.callee_body
893+
} else {
894+
self.tcx.optimized_mir(def_id)
895+
};
896+
897+
let Some(layout) = gen_body.generator_layout() else {
898+
self.validation = Err("malformed MIR");
899+
return;
900+
};
901+
902+
let Some(&local) = layout.variant_fields[var].get(f) else {
903+
self.validation = Err("malformed MIR");
904+
return;
905+
};
906+
907+
let Some(&f_ty) = layout.field_tys.get(local) else {
908+
self.validation = Err("malformed MIR");
909+
return;
910+
};
911+
912+
f_ty
913+
} else {
914+
let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
915+
self.validation = Err("malformed MIR");
916+
return;
917+
};
918+
919+
f_ty
920+
};
921+
922+
check_equal(self, f_ty);
923+
}
924+
_ => self.validation = Err("malformed MIR"),
925+
}
926+
}
927+
928+
self.super_projection_elem(local, proj_base, elem, context, location);
929+
}
930+
}
931+
802932
/**
803933
* Integrator.
804934
*

0 commit comments

Comments
 (0)