@@ -190,7 +190,12 @@ pub trait InferCtxtExt<'tcx> {
190
190
trait_ref : & ty:: PolyTraitRef < ' tcx > ,
191
191
) ;
192
192
193
- fn suggest_derive ( & self , err : & mut Diagnostic , trait_pred : ty:: PolyTraitPredicate < ' tcx > ) ;
193
+ fn suggest_derive (
194
+ & self ,
195
+ obligation : & PredicateObligation < ' tcx > ,
196
+ err : & mut Diagnostic ,
197
+ trait_pred : ty:: PolyTraitPredicate < ' tcx > ,
198
+ ) ;
194
199
}
195
200
196
201
fn predicate_constraint ( generics : & hir:: Generics < ' _ > , pred : String ) -> ( Span , String ) {
@@ -2592,33 +2597,60 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
2592
2597
}
2593
2598
}
2594
2599
2595
- fn suggest_derive ( & self , err : & mut Diagnostic , trait_pred : ty:: PolyTraitPredicate < ' tcx > ) {
2600
+ fn suggest_derive (
2601
+ & self ,
2602
+ obligation : & PredicateObligation < ' tcx > ,
2603
+ err : & mut Diagnostic ,
2604
+ trait_pred : ty:: PolyTraitPredicate < ' tcx > ,
2605
+ ) {
2596
2606
let Some ( diagnostic_name) = self . tcx . get_diagnostic_name ( trait_pred. def_id ( ) ) else {
2597
2607
return ;
2598
2608
} ;
2599
- let Some ( self_ty) = trait_pred. self_ty ( ) . no_bound_vars ( ) else {
2600
- return ;
2601
- } ;
2602
-
2603
- let adt = match self_ty. ty_adt_def ( ) {
2604
- Some ( adt) if adt. did ( ) . is_local ( ) => adt,
2609
+ let ( adt, substs) = match trait_pred. skip_binder ( ) . self_ty ( ) . kind ( ) {
2610
+ ty:: Adt ( adt, substs) if adt. did ( ) . is_local ( ) => ( adt, substs) ,
2605
2611
_ => return ,
2606
2612
} ;
2607
- let can_derive = match diagnostic_name {
2608
- sym:: Default => !adt. is_enum ( ) ,
2609
- sym:: PartialEq | sym:: PartialOrd => {
2610
- let rhs_ty = trait_pred. skip_binder ( ) . trait_ref . substs . type_at ( 1 ) ;
2611
- self_ty == rhs_ty
2612
- }
2613
- sym:: Eq | sym:: Ord | sym:: Clone | sym:: Copy | sym:: Hash | sym:: Debug => true ,
2614
- _ => false ,
2613
+ let can_derive = {
2614
+ let is_derivable_trait = match diagnostic_name {
2615
+ sym:: Default => !adt. is_enum ( ) ,
2616
+ sym:: PartialEq | sym:: PartialOrd => {
2617
+ let rhs_ty = trait_pred. skip_binder ( ) . trait_ref . substs . type_at ( 1 ) ;
2618
+ trait_pred. skip_binder ( ) . self_ty ( ) == rhs_ty
2619
+ }
2620
+ sym:: Eq | sym:: Ord | sym:: Clone | sym:: Copy | sym:: Hash | sym:: Debug => true ,
2621
+ _ => false ,
2622
+ } ;
2623
+ is_derivable_trait &&
2624
+ // Ensure all fields impl the trait.
2625
+ adt. all_fields ( ) . all ( |field| {
2626
+ let field_ty = field. ty ( self . tcx , substs) ;
2627
+ let trait_substs = match diagnostic_name {
2628
+ sym:: PartialEq | sym:: PartialOrd => {
2629
+ self . tcx . mk_substs_trait ( field_ty, & [ field_ty. into ( ) ] )
2630
+ }
2631
+ _ => self . tcx . mk_substs_trait ( field_ty, & [ ] ) ,
2632
+ } ;
2633
+ let trait_pred = trait_pred. map_bound_ref ( |tr| ty:: TraitPredicate {
2634
+ trait_ref : ty:: TraitRef {
2635
+ substs : trait_substs,
2636
+ ..trait_pred. skip_binder ( ) . trait_ref
2637
+ } ,
2638
+ ..* tr
2639
+ } ) ;
2640
+ let field_obl = Obligation :: new (
2641
+ obligation. cause . clone ( ) ,
2642
+ obligation. param_env ,
2643
+ trait_pred. to_predicate ( self . tcx ) ,
2644
+ ) ;
2645
+ self . predicate_must_hold_modulo_regions ( & field_obl)
2646
+ } )
2615
2647
} ;
2616
2648
if can_derive {
2617
2649
err. span_suggestion_verbose (
2618
2650
self . tcx . def_span ( adt. did ( ) ) . shrink_to_lo ( ) ,
2619
2651
& format ! (
2620
2652
"consider annotating `{}` with `#[derive({})]`" ,
2621
- trait_pred. skip_binder( ) . self_ty( ) . to_string ( ) ,
2653
+ trait_pred. skip_binder( ) . self_ty( ) ,
2622
2654
diagnostic_name. to_string( ) ,
2623
2655
) ,
2624
2656
format ! ( "#[derive({})]\n " , diagnostic_name. to_string( ) ) ,
0 commit comments