Skip to content

Commit 6d9b03e

Browse files
committed
Auto merge of #123572 - Mark-Simulacrum:vtable-methods, r=oli-obk
Increase vtable layout size

This improves LLVM's codegen by allowing vtable loads to be hoisted out of loops (as just one example). The calculation here is an under-approximation but works for simple trait hierarchies (e.g., FnMut will be improved). We have a runtime assert that the approximation is accurate, so there's no risk of UB as a result of getting this wrong.

```rust
#[no_mangle]
pub fn foo(elements: &[u32], callback: &mut dyn Callback) {
    for element in elements.iter() {
        if *element != 0 {
            callback.call(*element);
        }
    }
}

pub trait Callback {
    fn call(&mut self, _: u32);
}
```

Simplifying a bit (e.g., numbering ends up different):

```diff
 ; Function Attrs: nonlazybind uwtable
-define void @foo(ptr noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(24) %callback.1) unnamed_addr #0 {
+define void @foo(ptr noalias noundef nonnull readonly align 4 %elements.0, i64 noundef %elements.1, ptr noundef nonnull align 1 %callback.0, ptr noalias nocapture noundef readonly align 8 dereferenceable(32) %callback.1) unnamed_addr #0 {
 start:
   %_15 = getelementptr inbounds i32, ptr %elements.0, i64 %elements.1
@@ -13,4 +13,5 @@
 bb4.lr.ph: ; preds = %start
   %1 = getelementptr inbounds i8, ptr %callback.1, i64 24
+  %2 = load ptr, ptr %1, align 8, !nonnull !3
   br label %bb4

 bb6: ; preds = %bb4
-  %4 = load ptr, ptr %1, align 8, !invariant.load !3, !nonnull !3
-  tail call void %4(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
+  tail call void %2(ptr noundef nonnull align 1 %callback.0, i32 noundef %_9)
   br label %bb7
 }
```
2 parents 39d2f2a + fbc3dff commit 6d9b03e

File tree

10 files changed

+141
-124
lines changed

10 files changed

+141
-124
lines changed

compiler/rustc_hir_analysis/src/coherence/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use rustc_middle::query::Providers;
1212
use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
1313
use rustc_session::parse::feature_err;
1414
use rustc_span::{sym, ErrorGuaranteed};
15-
use rustc_trait_selection::traits;
1615

1716
mod builtin;
1817
mod inherent_impls;
@@ -199,7 +198,7 @@ fn check_object_overlap<'tcx>(
199198
// With the feature enabled, the trait is not implemented automatically,
200199
// so this is valid.
201200
} else {
202-
let mut supertrait_def_ids = traits::supertrait_def_ids(tcx, component_def_id);
201+
let mut supertrait_def_ids = tcx.supertrait_def_ids(component_def_id);
203202
if supertrait_def_ids.any(|d| d == trait_def_id) {
204203
let span = tcx.def_span(impl_def_id);
205204
return Err(struct_span_code_err!(

compiler/rustc_middle/src/ty/layout.rs

+7-18
Original file line numberDiff line numberDiff line change
@@ -826,25 +826,14 @@ where
826826
});
827827
}
828828

829-
let mk_dyn_vtable = || {
829+
let mk_dyn_vtable = |principal: Option<ty::PolyExistentialTraitRef<'tcx>>| {
830+
let min_count = ty::vtable_min_entries(tcx, principal);
830831
Ty::new_imm_ref(
831832
tcx,
832833
tcx.lifetimes.re_static,
833-
Ty::new_array(tcx, tcx.types.usize, 3),
834+
// FIXME: properly type (e.g. usize and fn pointers) the fields.
835+
Ty::new_array(tcx, tcx.types.usize, min_count.try_into().unwrap()),
834836
)
835-
/* FIXME: use actual fn pointers
836-
Warning: naively computing the number of entries in the
837-
vtable by counting the methods on the trait + methods on
838-
all parent traits does not work, because some methods can
839-
be not object safe and thus excluded from the vtable.
840-
Increase this counter if you tried to implement this but
841-
failed to do it without duplicating a lot of code from
842-
other places in the compiler: 2
843-
Ty::new_tup(tcx,&[
844-
Ty::new_array(tcx,tcx.types.usize, 3),
845-
Ty::new_array(tcx,Option<fn()>),
846-
])
847-
*/
848837
};
849838

850839
let metadata = if let Some(metadata_def_id) = tcx.lang_items().metadata_type()
@@ -863,16 +852,16 @@ where
863852
// `std::mem::uninitialized::<&dyn Trait>()`, for example.
864853
if let ty::Adt(def, args) = metadata.kind()
865854
&& Some(def.did()) == tcx.lang_items().dyn_metadata()
866-
&& args.type_at(0).is_trait()
855+
&& let ty::Dynamic(data, _, ty::Dyn) = args.type_at(0).kind()
867856
{
868-
mk_dyn_vtable()
857+
mk_dyn_vtable(data.principal())
869858
} else {
870859
metadata
871860
}
872861
} else {
873862
match tcx.struct_tail_erasing_lifetimes(pointee, cx.param_env()).kind() {
874863
ty::Slice(_) | ty::Str => tcx.types.usize,
875-
ty::Dynamic(_, _, ty::Dyn) => mk_dyn_vtable(),
864+
ty::Dynamic(data, _, ty::Dyn) => mk_dyn_vtable(data.principal()),
876865
_ => bug!("TyAndLayout::field({:?}): not applicable", this),
877866
}
878867
};

compiler/rustc_middle/src/ty/vtable.rs

+62
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::fmt;
33
use crate::mir::interpret::{alloc_range, AllocId, Allocation, Pointer, Scalar};
44
use crate::ty::{self, Instance, PolyTraitRef, Ty, TyCtxt};
55
use rustc_ast::Mutability;
6+
use rustc_data_structures::fx::FxHashSet;
7+
use rustc_hir::def_id::DefId;
68
use rustc_macros::HashStable;
79

810
#[derive(Clone, Copy, PartialEq, HashStable)]
@@ -40,12 +42,69 @@ impl<'tcx> fmt::Debug for VtblEntry<'tcx> {
4042
impl<'tcx> TyCtxt<'tcx> {
4143
pub const COMMON_VTABLE_ENTRIES: &'tcx [VtblEntry<'tcx>] =
4244
&[VtblEntry::MetadataDropInPlace, VtblEntry::MetadataSize, VtblEntry::MetadataAlign];
45+
46+
pub fn supertrait_def_ids(self, trait_def_id: DefId) -> SupertraitDefIds<'tcx> {
47+
SupertraitDefIds {
48+
tcx: self,
49+
stack: vec![trait_def_id],
50+
visited: Some(trait_def_id).into_iter().collect(),
51+
}
52+
}
4353
}
4454

4555
pub const COMMON_VTABLE_ENTRIES_DROPINPLACE: usize = 0;
4656
pub const COMMON_VTABLE_ENTRIES_SIZE: usize = 1;
4757
pub const COMMON_VTABLE_ENTRIES_ALIGN: usize = 2;
4858

59+
pub struct SupertraitDefIds<'tcx> {
60+
tcx: TyCtxt<'tcx>,
61+
stack: Vec<DefId>,
62+
visited: FxHashSet<DefId>,
63+
}
64+
65+
impl Iterator for SupertraitDefIds<'_> {
66+
type Item = DefId;
67+
68+
fn next(&mut self) -> Option<DefId> {
69+
let def_id = self.stack.pop()?;
70+
let predicates = self.tcx.super_predicates_of(def_id);
71+
let visited = &mut self.visited;
72+
self.stack.extend(
73+
predicates
74+
.predicates
75+
.iter()
76+
.filter_map(|(pred, _)| pred.as_trait_clause())
77+
.map(|trait_ref| trait_ref.def_id())
78+
.filter(|&super_def_id| visited.insert(super_def_id)),
79+
);
80+
Some(def_id)
81+
}
82+
}
83+
84+
// Note that we don't have access to a self type here, this has to be purely based on the trait (and
85+
// supertrait) definitions. That means we can't call into the same vtable_entries code since that
86+
// returns a specific instantiation (e.g., with Vacant slots when bounds aren't satisfied). The goal
87+
// here is to do a best-effort approximation without duplicating a lot of code.
88+
//
89+
// This function is used in layout computation for e.g. &dyn Trait, so it's critical that this
90+
// function is an accurate approximation. We verify this when actually computing the vtable below.
91+
pub(crate) fn vtable_min_entries<'tcx>(
92+
tcx: TyCtxt<'tcx>,
93+
trait_ref: Option<ty::PolyExistentialTraitRef<'tcx>>,
94+
) -> usize {
95+
let mut count = TyCtxt::COMMON_VTABLE_ENTRIES.len();
96+
let Some(trait_ref) = trait_ref else {
97+
return count;
98+
};
99+
100+
// This includes self in supertraits.
101+
for def_id in tcx.supertrait_def_ids(trait_ref.def_id()) {
102+
count += tcx.own_existential_vtable_entries(def_id).len();
103+
}
104+
105+
count
106+
}
107+
49108
/// Retrieves an allocation that represents the contents of a vtable.
50109
/// Since this is a query, allocations are cached and not duplicated.
51110
pub(super) fn vtable_allocation_provider<'tcx>(
@@ -63,6 +122,9 @@ pub(super) fn vtable_allocation_provider<'tcx>(
63122
TyCtxt::COMMON_VTABLE_ENTRIES
64123
};
65124

125+
// This confirms that the layout computation for &dyn Trait has an accurate sizing.
126+
assert!(vtable_entries.len() >= vtable_min_entries(tcx, poly_trait_ref));
127+
66128
let layout = tcx
67129
.layout_of(ty::ParamEnv::reveal_all().and(ty))
68130
.expect("failed to build vtable representation");

compiler/rustc_trait_selection/src/solve/trait_goals.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
//! Dealing with trait goals, i.e. `T: Trait<'a, U>`.
22
3-
use crate::traits::supertrait_def_ids;
4-
53
use super::assembly::structural_traits::AsyncCallableRelevantTypes;
64
use super::assembly::{self, structural_traits, Candidate};
75
use super::{EvalCtxt, GoalSource, SolverMode};
@@ -837,7 +835,8 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
837835
let a_auto_traits: FxIndexSet<DefId> = a_data
838836
.auto_traits()
839837
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
840-
supertrait_def_ids(self.tcx(), principal_def_id)
838+
self.tcx()
839+
.supertrait_def_ids(principal_def_id)
841840
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
842841
}))
843842
.collect();

compiler/rustc_trait_selection/src/traits/mod.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ pub use self::structural_normalize::StructurallyNormalizeExt;
6565
pub use self::util::elaborate;
6666
pub use self::util::{expand_trait_aliases, TraitAliasExpander, TraitAliasExpansionInfo};
6767
pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
68-
pub use self::util::{
69-
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item,
70-
SupertraitDefIds,
71-
};
68+
pub use self::util::{supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item};
7269
pub use self::util::{with_replaced_escaping_bound_vars, BoundVarReplacer, PlaceholderReplacer};
7370

7471
pub use rustc_infer::traits::*;

compiler/rustc_trait_selection/src/traits/object_safety.rs

+63-55
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use rustc_middle::ty::{TypeVisitableExt, Upcast};
2626
use rustc_session::lint::builtin::WHERE_CLAUSES_OBJECT_SAFETY;
2727
use rustc_span::symbol::Symbol;
2828
use rustc_span::Span;
29+
use rustc_target::abi::Abi;
2930
use smallvec::SmallVec;
3031

3132
use std::iter;
@@ -44,7 +45,8 @@ pub fn hir_ty_lowering_object_safety_violations(
4445
trait_def_id: DefId,
4546
) -> Vec<ObjectSafetyViolation> {
4647
debug_assert!(tcx.generics_of(trait_def_id).has_self);
47-
let violations = traits::supertrait_def_ids(tcx, trait_def_id)
48+
let violations = tcx
49+
.supertrait_def_ids(trait_def_id)
4850
.map(|def_id| predicates_reference_self(tcx, def_id, true))
4951
.filter(|spans| !spans.is_empty())
5052
.map(ObjectSafetyViolation::SupertraitSelf)
@@ -58,7 +60,7 @@ fn object_safety_violations(tcx: TyCtxt<'_>, trait_def_id: DefId) -> &'_ [Object
5860
debug!("object_safety_violations: {:?}", trait_def_id);
5961

6062
tcx.arena.alloc_from_iter(
61-
traits::supertrait_def_ids(tcx, trait_def_id)
63+
tcx.supertrait_def_ids(trait_def_id)
6264
.flat_map(|def_id| object_safety_violations_for_trait(tcx, def_id)),
6365
)
6466
}
@@ -145,6 +147,14 @@ fn object_safety_violations_for_trait(
145147
violations.push(ObjectSafetyViolation::SupertraitNonLifetimeBinder(spans));
146148
}
147149

150+
if violations.is_empty() {
151+
for item in tcx.associated_items(trait_def_id).in_definition_order() {
152+
if let ty::AssocKind::Fn = item.kind {
153+
check_receiver_correct(tcx, trait_def_id, *item);
154+
}
155+
}
156+
}
157+
148158
debug!(
149159
"object_safety_violations_for_trait(trait_def_id={:?}) = {:?}",
150160
trait_def_id, violations
@@ -498,59 +508,8 @@ fn virtual_call_violations_for_method<'tcx>(
498508
};
499509
errors.push(MethodViolationCode::UndispatchableReceiver(span));
500510
} else {
501-
// Do sanity check to make sure the receiver actually has the layout of a pointer.
502-
503-
use rustc_target::abi::Abi;
504-
505-
let param_env = tcx.param_env(method.def_id);
506-
507-
let abi_of_ty = |ty: Ty<'tcx>| -> Option<Abi> {
508-
match tcx.layout_of(param_env.and(ty)) {
509-
Ok(layout) => Some(layout.abi),
510-
Err(err) => {
511-
// #78372
512-
tcx.dcx().span_delayed_bug(
513-
tcx.def_span(method.def_id),
514-
format!("error: {err}\n while computing layout for type {ty:?}"),
515-
);
516-
None
517-
}
518-
}
519-
};
520-
521-
// e.g., `Rc<()>`
522-
let unit_receiver_ty =
523-
receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method.def_id);
524-
525-
match abi_of_ty(unit_receiver_ty) {
526-
Some(Abi::Scalar(..)) => (),
527-
abi => {
528-
tcx.dcx().span_delayed_bug(
529-
tcx.def_span(method.def_id),
530-
format!(
531-
"receiver when `Self = ()` should have a Scalar ABI; found {abi:?}"
532-
),
533-
);
534-
}
535-
}
536-
537-
let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);
538-
539-
// e.g., `Rc<dyn Trait>`
540-
let trait_object_receiver =
541-
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method.def_id);
542-
543-
match abi_of_ty(trait_object_receiver) {
544-
Some(Abi::ScalarPair(..)) => (),
545-
abi => {
546-
tcx.dcx().span_delayed_bug(
547-
tcx.def_span(method.def_id),
548-
format!(
549-
"receiver when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
550-
),
551-
);
552-
}
553-
}
511+
// We confirm that the `receiver_is_dispatchable` is accurate later,
512+
// see `check_receiver_correct`. It should be kept in sync with this code.
554513
}
555514
}
556515

@@ -611,6 +570,55 @@ fn virtual_call_violations_for_method<'tcx>(
611570
errors
612571
}
613572

573+
/// This code checks that `receiver_is_dispatchable` is correctly implemented.
574+
///
575+
/// This check is outlined from the object safety check to avoid cycles with
576+
/// layout computation, which relies on knowing whether methods are object safe.
577+
pub fn check_receiver_correct<'tcx>(tcx: TyCtxt<'tcx>, trait_def_id: DefId, method: ty::AssocItem) {
578+
if !is_vtable_safe_method(tcx, trait_def_id, method) {
579+
return;
580+
}
581+
582+
let method_def_id = method.def_id;
583+
let sig = tcx.fn_sig(method_def_id).instantiate_identity();
584+
let param_env = tcx.param_env(method_def_id);
585+
let receiver_ty = tcx.liberate_late_bound_regions(method_def_id, sig.input(0));
586+
587+
if receiver_ty == tcx.types.self_param {
588+
// Assumed OK, may change later if unsized_locals permits `self: Self` as dispatchable.
589+
return;
590+
}
591+
592+
// e.g., `Rc<()>`
593+
let unit_receiver_ty = receiver_for_self_ty(tcx, receiver_ty, tcx.types.unit, method_def_id);
594+
match tcx.layout_of(param_env.and(unit_receiver_ty)).map(|l| l.abi) {
595+
Ok(Abi::Scalar(..)) => (),
596+
abi => {
597+
tcx.dcx().span_delayed_bug(
598+
tcx.def_span(method_def_id),
599+
format!("receiver {unit_receiver_ty:?} when `Self = ()` should have a Scalar ABI; found {abi:?}"),
600+
);
601+
}
602+
}
603+
604+
let trait_object_ty = object_ty_for_trait(tcx, trait_def_id, tcx.lifetimes.re_static);
605+
606+
// e.g., `Rc<dyn Trait>`
607+
let trait_object_receiver =
608+
receiver_for_self_ty(tcx, receiver_ty, trait_object_ty, method_def_id);
609+
match tcx.layout_of(param_env.and(trait_object_receiver)).map(|l| l.abi) {
610+
Ok(Abi::ScalarPair(..)) => (),
611+
abi => {
612+
tcx.dcx().span_delayed_bug(
613+
tcx.def_span(method_def_id),
614+
format!(
615+
"receiver {trait_object_receiver:?} when `Self = {trait_object_ty}` should have a ScalarPair ABI; found {abi:?}"
616+
),
617+
);
618+
}
619+
}
620+
}
621+
614622
/// Performs a type instantiation to produce the version of `receiver_ty` when `Self = self_ty`.
615623
/// For example, for `receiver_ty = Rc<Self>` and `self_ty = Foo`, returns `Rc<Foo>`.
616624
fn receiver_for_self_ty<'tcx>(

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
10041004
let a_auto_traits: FxIndexSet<DefId> = a_data
10051005
.auto_traits()
10061006
.chain(principal_def_id_a.into_iter().flat_map(|principal_def_id| {
1007-
util::supertrait_def_ids(self.tcx(), principal_def_id)
1007+
self.tcx()
1008+
.supertrait_def_ids(principal_def_id)
10081009
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
10091010
}))
10101011
.collect();

compiler/rustc_trait_selection/src/traits/select/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -2571,8 +2571,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
25712571
let a_auto_traits: FxIndexSet<DefId> = a_data
25722572
.auto_traits()
25732573
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
2574-
util::supertrait_def_ids(tcx, principal_def_id)
2575-
.filter(|def_id| tcx.trait_is_auto(*def_id))
2574+
tcx.supertrait_def_ids(principal_def_id).filter(|def_id| tcx.trait_is_auto(*def_id))
25762575
}))
25772576
.collect();
25782577

0 commit comments

Comments
 (0)