@@ -6,11 +6,12 @@ use crate::animate::{AnimationFieldAttrs, AnimationInputAttrs, AnimationVariantA
6
6
use derive_common:: cg;
7
7
use proc_macro2:: TokenStream ;
8
8
use quote:: TokenStreamExt ;
9
- use syn:: { DeriveInput , Path } ;
9
+ use syn:: { DeriveInput , Path , WhereClause } ;
10
10
use synstructure;
11
11
12
12
pub fn derive ( mut input : DeriveInput ) -> TokenStream {
13
13
let animation_input_attrs = cg:: parse_input_attrs :: < AnimationInputAttrs > ( & input) ;
14
+ let input_attrs = cg:: parse_input_attrs :: < DistanceInputAttrs > ( & input) ;
14
15
let no_bound = animation_input_attrs. no_bound . unwrap_or_default ( ) ;
15
16
let mut where_clause = input. generics . where_clause . take ( ) ;
16
17
for param in input. generics . type_params ( ) {
@@ -22,76 +23,31 @@ pub fn derive(mut input: DeriveInput) -> TokenStream {
22
23
}
23
24
}
24
25
25
- let ( mut match_body, append_error_clause ) = {
26
+ let ( mut match_body, needs_catchall_branch ) = {
26
27
let s = synstructure:: Structure :: new ( & input) ;
27
- let mut append_error_clause = s. variants ( ) . len ( ) > 1 ;
28
+ let needs_catchall_branch = s. variants ( ) . len ( ) > 1 ;
28
29
29
30
let match_body = s. variants ( ) . iter ( ) . fold ( quote ! ( ) , |body, variant| {
30
- let attrs = cg:: parse_variant_attrs_from_ast :: < AnimationVariantAttrs > ( & variant. ast ( ) ) ;
31
- if attrs. error {
32
- append_error_clause = true ;
33
- return body;
34
- }
35
-
36
- let ( this_pattern, this_info) = cg:: ref_pattern ( & variant, "this" ) ;
37
- let ( other_pattern, other_info) = cg:: ref_pattern ( & variant, "other" ) ;
38
- let sum = if this_info. is_empty ( ) {
39
- quote ! { crate :: values:: distance:: SquaredDistance :: from_sqrt( 0. ) }
40
- } else {
41
- let mut sum = quote ! ( ) ;
42
- sum. append_separated ( this_info. iter ( ) . zip ( & other_info) . map ( |( this, other) | {
43
- let field_attrs = cg:: parse_field_attrs :: < DistanceFieldAttrs > ( & this. ast ( ) ) ;
44
- if field_attrs. field_bound {
45
- let ty = & this. ast ( ) . ty ;
46
- cg:: add_predicate (
47
- & mut where_clause,
48
- parse_quote ! ( #ty: crate :: values:: distance:: ComputeSquaredDistance ) ,
49
- ) ;
50
- }
51
-
52
- let animation_field_attrs =
53
- cg:: parse_field_attrs :: < AnimationFieldAttrs > ( & this. ast ( ) ) ;
54
-
55
- if animation_field_attrs. constant {
56
- quote ! {
57
- {
58
- if #this != #other {
59
- return Err ( ( ) ) ;
60
- }
61
- crate :: values:: distance:: SquaredDistance :: from_sqrt( 0. )
62
- }
63
- }
64
- } else {
65
- quote ! {
66
- crate :: values:: distance:: ComputeSquaredDistance :: compute_squared_distance( #this, #other) ?
67
- }
68
- }
69
- } ) , quote ! ( +) ) ;
70
- sum
71
- } ;
72
- quote ! {
73
- #body
74
- ( & #this_pattern, & #other_pattern) => {
75
- Ok ( #sum)
76
- }
77
- }
31
+ let arm = derive_variant_arm ( variant, & mut where_clause) ;
32
+ quote ! { #body #arm }
78
33
} ) ;
79
34
80
- ( match_body, append_error_clause )
35
+ ( match_body, needs_catchall_branch )
81
36
} ;
37
+
82
38
input. generics . where_clause = where_clause;
83
39
84
- if append_error_clause {
85
- let input_attrs = cg:: parse_input_attrs :: < DistanceInputAttrs > ( & input) ;
86
- if let Some ( fallback) = input_attrs. fallback {
87
- match_body. append_all ( quote ! {
88
- ( this, other) => #fallback( this, other)
89
- } ) ;
90
- } else {
91
- match_body. append_all ( quote ! { _ => Err ( ( ) ) } ) ;
92
- }
40
+ if needs_catchall_branch {
41
+ // This ideally shouldn't be needed, but see:
42
+ // https://github.com/rust-lang/rust/issues/68867
43
+ match_body. append_all ( quote ! { _ => unsafe { debug_unreachable!( ) } } ) ;
93
44
}
94
45
46
+ let fallback = match input_attrs. fallback {
47
+ Some ( fallback) => quote ! { #fallback( self , other) } ,
48
+ None => quote ! { Err ( ( ) ) } ,
49
+ } ;
50
+
95
51
let name = & input. ident ;
96
52
let ( impl_generics, ty_generics, where_clause) = input. generics . split_for_impl ( ) ;
97
53
@@ -103,6 +59,9 @@ pub fn derive(mut input: DeriveInput) -> TokenStream {
103
59
& self ,
104
60
other: & Self ,
105
61
) -> Result <crate :: values:: distance:: SquaredDistance , ( ) > {
62
+ if std:: mem:: discriminant( self ) != std:: mem:: discriminant( other) {
63
+ return #fallback;
64
+ }
106
65
match ( self , other) {
107
66
#match_body
108
67
}
@@ -111,6 +70,60 @@ pub fn derive(mut input: DeriveInput) -> TokenStream {
111
70
}
112
71
}
113
72
73
+ fn derive_variant_arm (
74
+ variant : & synstructure:: VariantInfo ,
75
+ mut where_clause : & mut Option < WhereClause > ,
76
+ ) -> TokenStream {
77
+ let variant_attrs = cg:: parse_variant_attrs_from_ast :: < AnimationVariantAttrs > ( & variant. ast ( ) ) ;
78
+ let ( this_pattern, this_info) = cg:: ref_pattern ( & variant, "this" ) ;
79
+ let ( other_pattern, other_info) = cg:: ref_pattern ( & variant, "other" ) ;
80
+
81
+ if variant_attrs. error {
82
+ return quote ! {
83
+ ( & #this_pattern, & #other_pattern) => Err ( ( ) ) ,
84
+ } ;
85
+ }
86
+
87
+ let sum = if this_info. is_empty ( ) {
88
+ quote ! { crate :: values:: distance:: SquaredDistance :: from_sqrt( 0. ) }
89
+ } else {
90
+ let mut sum = quote ! ( ) ;
91
+ sum. append_separated ( this_info. iter ( ) . zip ( & other_info) . map ( |( this, other) | {
92
+ let field_attrs = cg:: parse_field_attrs :: < DistanceFieldAttrs > ( & this. ast ( ) ) ;
93
+ if field_attrs. field_bound {
94
+ let ty = & this. ast ( ) . ty ;
95
+ cg:: add_predicate (
96
+ & mut where_clause,
97
+ parse_quote ! ( #ty: crate :: values:: distance:: ComputeSquaredDistance ) ,
98
+ ) ;
99
+ }
100
+
101
+ let animation_field_attrs =
102
+ cg:: parse_field_attrs :: < AnimationFieldAttrs > ( & this. ast ( ) ) ;
103
+
104
+ if animation_field_attrs. constant {
105
+ quote ! {
106
+ {
107
+ if #this != #other {
108
+ return Err ( ( ) ) ;
109
+ }
110
+ crate :: values:: distance:: SquaredDistance :: from_sqrt( 0. )
111
+ }
112
+ }
113
+ } else {
114
+ quote ! {
115
+ crate :: values:: distance:: ComputeSquaredDistance :: compute_squared_distance( #this, #other) ?
116
+ }
117
+ }
118
+ } ) , quote ! ( +) ) ;
119
+ sum
120
+ } ;
121
+
122
+ return quote ! {
123
+ ( & #this_pattern, & #other_pattern) => Ok ( #sum) ,
124
+ } ;
125
+ }
126
+
114
127
#[ darling( attributes( distance) , default ) ]
115
128
#[ derive( Default , FromDeriveInput ) ]
116
129
struct DistanceInputAttrs {
0 commit comments