Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement coherence checks for negative trait impls #85764

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl Elaborator<'tcx> {

let bound_predicate = obligation.predicate.kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Trait(data, _) => {
ty::PredicateKind::Trait(data, _, _) => {
// Get predicates declared on the trait.
let predicates = tcx.super_predicates_of(data.def_id());

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl<'tcx> LateLintPass<'tcx> for DropTraitConstraints {
let predicates = cx.tcx.explicit_predicates_of(item.def_id);
for &(predicate, span) in predicates.predicates {
let trait_predicate = match predicate.kind().skip_binder() {
Trait(trait_predicate, _constness) => trait_predicate,
Trait(trait_predicate, _constness, _polarity) => trait_predicate,
_ => continue,
};
let def_id = trait_predicate.trait_ref.def_id;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/unused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
let mut has_emitted = false;
for &(predicate, _) in cx.tcx.explicit_item_bounds(def) {
// We only look at the `DefId`, so it is safe to skip the binder here.
if let ty::PredicateKind::Trait(ref poly_trait_predicate, _) =
if let ty::PredicateKind::Trait(ref poly_trait_predicate, _, _) =
predicate.kind().skip_binder()
{
let def_id = poly_trait_predicate.trait_ref.def_id;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2127,7 +2127,7 @@ impl<'tcx> TyCtxt<'tcx> {
let generic_predicates = self.super_predicates_of(trait_did);

for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateKind::Trait(data, _) = predicate.kind().skip_binder() {
if let ty::PredicateKind::Trait(data, _, _) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl FlagComputation {

fn add_predicate_atom(&mut self, atom: ty::PredicateKind<'_>) {
match atom {
ty::PredicateKind::Trait(trait_pred, _constness) => {
ty::PredicateKind::Trait(trait_pred, _constness, _polarity) => {
self.add_substs(trait_pred.trait_ref.substs);
}
ty::PredicateKind::RegionOutlives(ty::OutlivesPredicate(a, b)) => {
Expand Down
63 changes: 53 additions & 10 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,18 @@ pub struct ImplHeader<'tcx> {
pub predicates: Vec<Predicate<'tcx>>,
}

#[derive(Copy, Clone, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
TyEncodable,
TyDecodable,
HashStable,
TypeFoldable,
Debug
)]
pub enum ImplPolarity {
/// `impl Trait for Type`
Positive,
Expand Down Expand Up @@ -371,7 +382,7 @@ impl ty::EarlyBoundRegion {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
crate struct PredicateInner<'tcx> {
kind: Binder<'tcx, PredicateKind<'tcx>>,
flags: TypeFlags,
Expand Down Expand Up @@ -408,6 +419,26 @@ impl<'tcx> Predicate<'tcx> {
pub fn kind(self) -> Binder<'tcx, PredicateKind<'tcx>> {
self.inner.kind
}

/// hi oli
pub fn negate_trait(mut self, tcx: TyCtxt<'tcx>) -> Option<Self> {
let mut inner = self.inner.clone();

if !matches!(inner.kind.skip_binder(), PredicateKind::Trait(..)) {
return None;
}

inner.kind = inner.kind.map_bound(|kind| match kind {
PredicateKind::Trait(trait_pred, constness, ImplPolarity::Positive) => {
PredicateKind::Trait(trait_pred, constness, ImplPolarity::Negative)
}
_ => bug!(),
});

self.inner = tcx.arena.alloc(inner);

Some(self)
}
}

impl<'a, 'tcx> HashStable<StableHashingContext<'a>> for Predicate<'tcx> {
Expand Down Expand Up @@ -435,7 +466,7 @@ pub enum PredicateKind<'tcx> {
/// A trait predicate will have `Constness::Const` if it originates
/// from a bound on a `const fn` without the `?const` opt-out (e.g.,
/// `const fn foobar<Foo: Bar>() {}`).
Trait(TraitPredicate<'tcx>, Constness),
Trait(TraitPredicate<'tcx>, Constness, ImplPolarity),

/// `where 'a: 'b`
RegionOutlives(RegionOutlivesPredicate<'tcx>),
Expand Down Expand Up @@ -591,6 +622,7 @@ impl<'tcx> Predicate<'tcx> {
#[derive(HashStable, TypeFoldable)]
pub struct TraitPredicate<'tcx> {
pub trait_ref: TraitRef<'tcx>,
pub polarity: ImplPolarity,
}

pub type PolyTraitPredicate<'tcx> = ty::Binder<'tcx, TraitPredicate<'tcx>>;
Expand Down Expand Up @@ -724,24 +756,34 @@ impl ToPredicate<'tcx> for PredicateKind<'tcx> {

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<TraitRef<'tcx>> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
PredicateKind::Trait(ty::TraitPredicate { trait_ref: self.value }, self.constness)
.to_predicate(tcx)
PredicateKind::Trait(
ty::TraitPredicate { trait_ref: self.value, polarity: self.polarity },
self.constness,
self.polarity,
)
.to_predicate(tcx)
}
}

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<PolyTraitRef<'tcx>> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
self.value
.map_bound(|trait_ref| {
PredicateKind::Trait(ty::TraitPredicate { trait_ref }, self.constness)
PredicateKind::Trait(
ty::TraitPredicate { trait_ref, polarity: self.polarity },
self.constness,
self.polarity,
)
})
.to_predicate(tcx)
}
}

impl<'tcx> ToPredicate<'tcx> for ConstnessAnd<PolyTraitPredicate<'tcx>> {
fn to_predicate(self, tcx: TyCtxt<'tcx>) -> Predicate<'tcx> {
self.value.map_bound(|value| PredicateKind::Trait(value, self.constness)).to_predicate(tcx)
self.value
.map_bound(|value| PredicateKind::Trait(value, self.constness, self.polarity))
.to_predicate(tcx)
}
}

Expand All @@ -767,8 +809,8 @@ impl<'tcx> Predicate<'tcx> {
pub fn to_opt_poly_trait_ref(self) -> Option<ConstnessAnd<PolyTraitRef<'tcx>>> {
let predicate = self.kind();
match predicate.skip_binder() {
PredicateKind::Trait(t, constness) => {
Some(ConstnessAnd { constness, value: predicate.rebind(t.trait_ref) })
PredicateKind::Trait(t, constness, polarity) => {
Some(ConstnessAnd { constness, polarity, value: predicate.rebind(t.trait_ref) })
}
PredicateKind::Projection(..)
| PredicateKind::Subtype(..)
Expand Down Expand Up @@ -1234,6 +1276,7 @@ impl<'tcx> ParamEnv<'tcx> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, TypeFoldable)]
pub struct ConstnessAnd<T> {
pub constness: Constness,
pub polarity: ImplPolarity,
pub value: T,
}

Expand All @@ -1242,7 +1285,7 @@ pub struct ConstnessAnd<T> {
pub trait WithConstness: Sized {
#[inline]
fn with_constness(self, constness: Constness) -> ConstnessAnd<Self> {
ConstnessAnd { constness, value: self }
ConstnessAnd { constness, polarity: ImplPolarity::Positive, value: self }
}

#[inline]
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,8 @@ pub trait PrettyPrinter<'tcx>:
for (predicate, _) in bounds {
let predicate = predicate.subst(self.tcx(), substs);
let bound_predicate = predicate.kind();
if let ty::PredicateKind::Trait(pred, _) = bound_predicate.skip_binder() {
if let ty::PredicateKind::Trait(pred, _, _) = bound_predicate.skip_binder()
{
let trait_ref = bound_predicate.rebind(pred.trait_ref);
// Don't print +Sized, but rather +?Sized if absent.
if Some(trait_ref.def_id()) == self.tcx().lang_items().sized_trait() {
Expand Down Expand Up @@ -2191,7 +2192,7 @@ define_print_and_forward_display! {

ty::PredicateKind<'tcx> {
match *self {
ty::PredicateKind::Trait(ref data, constness) => {
ty::PredicateKind::Trait(ref data, constness, _polarity) => {
if let hir::Constness::Const = constness {
p!("const ");
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,8 @@ impl<'tcx> Relate<'tcx> for ty::TraitPredicate<'tcx> {
a: ty::TraitPredicate<'tcx>,
b: ty::TraitPredicate<'tcx>,
) -> RelateResult<'tcx, ty::TraitPredicate<'tcx>> {
Ok(ty::TraitPredicate { trait_ref: relation.relate(a.trait_ref, b.trait_ref)? })
// TODO(yaahc): double check this is fine
Ok(ty::TraitPredicate { trait_ref: relation.relate(a.trait_ref, b.trait_ref)?, polarity: ty::ImplPolarity::Positive })
}
}

Expand Down
9 changes: 5 additions & 4 deletions compiler/rustc_middle/src/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl fmt::Debug for ty::Predicate<'tcx> {
impl fmt::Debug for ty::PredicateKind<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
ty::PredicateKind::Trait(ref a, constness) => {
ty::PredicateKind::Trait(ref a, constness, _polarity) => {
if let hir::Constness::Const = constness {
write!(f, "const ")?;
}
Expand Down Expand Up @@ -366,7 +366,8 @@ impl<'a, 'tcx> Lift<'tcx> for ty::ExistentialPredicate<'a> {
impl<'a, 'tcx> Lift<'tcx> for ty::TraitPredicate<'a> {
type Lifted = ty::TraitPredicate<'tcx>;
fn lift_to_tcx(self, tcx: TyCtxt<'tcx>) -> Option<ty::TraitPredicate<'tcx>> {
tcx.lift(self.trait_ref).map(|trait_ref| ty::TraitPredicate { trait_ref })
// TODO(yaahc): double check this is fine
tcx.lift(self.trait_ref).map(|trait_ref| ty::TraitPredicate { trait_ref, polarity: ty::ImplPolarity::Positive })
}
}

Expand Down Expand Up @@ -419,8 +420,8 @@ impl<'a, 'tcx> Lift<'tcx> for ty::PredicateKind<'a> {
type Lifted = ty::PredicateKind<'tcx>;
fn lift_to_tcx(self, tcx: TyCtxt<'tcx>) -> Option<Self::Lifted> {
match self {
ty::PredicateKind::Trait(data, constness) => {
tcx.lift(data).map(|data| ty::PredicateKind::Trait(data, constness))
ty::PredicateKind::Trait(data, constness, polarity) => {
tcx.lift(data).map(|data| ty::PredicateKind::Trait(data, constness, polarity))
}
ty::PredicateKind::Subtype(data) => tcx.lift(data).map(ty::PredicateKind::Subtype),
ty::PredicateKind::RegionOutlives(data) => {
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,8 @@ impl<'tcx> PolyTraitRef<'tcx> {
}

pub fn to_poly_trait_predicate(&self) -> ty::PolyTraitPredicate<'tcx> {
self.map_bound(|trait_ref| ty::TraitPredicate { trait_ref })
// TODO(yaahc): double check this is fine
self.map_bound(|trait_ref| ty::TraitPredicate { trait_ref, polarity: ty::ImplPolarity::Positive })
}
}

Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_mir/src/borrow_check/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use rustc_middle::ty::cast::CastTy;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::{GenericArgKind, Subst, SubstsRef, UserSubsts};
use rustc_middle::ty::{
self, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, RegionVid, ToPredicate, Ty,
TyCtxt, UserType, UserTypeAnnotationIndex, WithConstness,
self, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, ImplPolarity, RegionVid,
ToPredicate, Ty, TyCtxt, UserType, UserTypeAnnotationIndex, WithConstness,
};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::VariantIdx;
Expand Down Expand Up @@ -2718,8 +2718,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
) {
self.prove_predicates(
Some(ty::PredicateKind::Trait(
ty::TraitPredicate { trait_ref },
ty::TraitPredicate { trait_ref, polarity: ImplPolarity::Positive },
hir::Constness::NotConst,
ImplPolarity::Positive,
)),
locations,
category,
Expand Down
4 changes: 3 additions & 1 deletion compiler/rustc_mir/src/transform/check_consts/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl Validator<'mir, 'tcx> {
ty::PredicateKind::Subtype(_) => {
bug!("subtype predicate on function: {:#?}", predicate)
}
ty::PredicateKind::Trait(pred, _constness) => {
ty::PredicateKind::Trait(pred, _constness, _polarity) => {
if Some(pred.def_id()) == tcx.lang_items().sized_trait() {
continue;
}
Expand Down Expand Up @@ -826,12 +826,14 @@ impl Visitor<'tcx> for Validator<'mir, 'tcx> {
}

let trait_ref = TraitRef::from_method(tcx, trait_id, substs);
// TODO(yaahc): double check
let obligation = Obligation::new(
ObligationCause::dummy(),
param_env,
Binder::bind(
TraitPredicate {
trait_ref: TraitRef::from_method(tcx, trait_id, substs),
polarity: ty::ImplPolarity::Positive,
},
tcx,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl<'a, 'tcx> FunctionItemRefChecker<'a, 'tcx> {

/// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type.
fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> {
if let ty::PredicateKind::Trait(predicate, _) = bound {
if let ty::PredicateKind::Trait(predicate, _, _) = bound {
if self.tcx.is_diagnostic_item(sym::pointer_trait, predicate.def_id()) {
Some(predicate.trait_ref.self_ty())
} else {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_privacy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ where

fn visit_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ControlFlow<V::BreakTy> {
match predicate.kind().skip_binder() {
ty::PredicateKind::Trait(ty::TraitPredicate { trait_ref }, _) => {
ty::PredicateKind::Trait(ty::TraitPredicate { trait_ref, .. }, _, _) => {
self.visit_trait(trait_ref)
}
ty::PredicateKind::Projection(ty::ProjectionPredicate { projection_ty, ty }) => {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_resolve/src/late/lifetimes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2666,7 +2666,7 @@ impl<'a, 'tcx> LifetimeContext<'a, 'tcx> {
let obligations = predicates.predicates.iter().filter_map(|&(pred, _)| {
let bound_predicate = pred.kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Trait(data, _) => {
ty::PredicateKind::Trait(data, _, _) => {
// The order here needs to match what we would get from `subst_supertrait`
let pred_bound_vars = bound_predicate.bound_vars();
let mut all_bound_vars = bound_vars.clone();
Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_trait_selection/src/traits/auto_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,13 @@ impl AutoTraitFinder<'tcx> {

let mut already_visited = FxHashSet::default();
let mut predicates = VecDeque::new();
// TODO(yaahc): double check this is fine
predicates.push_back(ty::Binder::dummy(ty::TraitPredicate {
trait_ref: ty::TraitRef {
def_id: trait_did,
substs: infcx.tcx.mk_substs_trait(ty, &[]),
},
polarity: ty::ImplPolarity::Positive,
}));

let computed_preds = param_env.caller_bounds().iter();
Expand Down Expand Up @@ -415,8 +417,8 @@ impl AutoTraitFinder<'tcx> {
let mut should_add_new = true;
user_computed_preds.retain(|&old_pred| {
if let (
ty::PredicateKind::Trait(new_trait, _),
ty::PredicateKind::Trait(old_trait, _),
ty::PredicateKind::Trait(new_trait, _, _),
ty::PredicateKind::Trait(old_trait, _, _),
) = (new_pred.kind().skip_binder(), old_pred.kind().skip_binder())
{
if new_trait.def_id() == old_trait.def_id() {
Expand Down Expand Up @@ -638,7 +640,7 @@ impl AutoTraitFinder<'tcx> {

let bound_predicate = predicate.kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Trait(p, _) => {
ty::PredicateKind::Trait(p, _, _) => {
// Add this to `predicates` so that we end up calling `select`
// with it. If this predicate ends up being unimplemented,
// then `evaluate_predicates` will handle adding it the `ParamEnv`
Expand Down
Loading