|
1 | 1 | use std::mem::swap;
|
2 | 2 |
|
| 3 | +use ast::ptr::P; |
3 | 4 | use ast::HasAttrs;
|
| 5 | +use rustc_ast::mut_visit::MutVisitor; |
| 6 | +use rustc_ast::visit::BoundKind; |
4 | 7 | use rustc_ast::{
|
5 | 8 | self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
|
6 |
| - TraitBoundModifiers, VariantData, |
| 9 | + TraitBoundModifiers, VariantData, WherePredicate, |
7 | 10 | };
|
8 | 11 | use rustc_attr as attr;
|
| 12 | +use rustc_data_structures::flat_map_in_place::FlatMapInPlace; |
9 | 13 | use rustc_expand::base::{Annotatable, ExtCtxt};
|
10 | 14 | use rustc_span::symbol::{sym, Ident};
|
11 |
| -use rustc_span::Span; |
| 15 | +use rustc_span::{Span, Symbol}; |
12 | 16 | use smallvec::{smallvec, SmallVec};
|
13 | 17 | use thin_vec::{thin_vec, ThinVec};
|
14 | 18 |
|
@@ -141,33 +145,239 @@ pub fn expand_deriving_smart_ptr(
|
141 | 145 | alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone());
|
142 | 146 | let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params));
|
143 | 147 |
|
| 148 | + // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location |
| 149 | + // |
144 | 150 | // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
|
145 | 151 | let mut impl_generics = generics.clone();
|
| 152 | + let pointee_ty_ident = generics.params[pointee_param_idx].ident; |
| 153 | + let mut self_bounds; |
146 | 154 | {
|
147 |
| - let p = &mut impl_generics.params[pointee_param_idx]; |
| 155 | + let pointee = &mut impl_generics.params[pointee_param_idx]; |
| 156 | + self_bounds = pointee.bounds.clone(); |
148 | 157 | let arg = GenericArg::Type(s_ty.clone());
|
149 | 158 | let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
|
150 |
| - p.bounds.push(cx.trait_bound(unsize, false)); |
| 159 | + pointee.bounds.push(cx.trait_bound(unsize, false)); |
151 | 160 | let mut attrs = thin_vec![];
|
152 |
| - swap(&mut p.attrs, &mut attrs); |
153 |
| - p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); |
| 161 | + swap(&mut pointee.attrs, &mut attrs); |
| 162 | + // Drop `#[pointee]` attribute since it should not be recognized outside `derive(SmartPointer)` |
| 163 | + pointee.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); |
154 | 164 | }
|
155 | 165 |
|
156 |
| - // Add the `__S: ?Sized` extra parameter to the impl block. |
| 166 | + // # Rewrite generic parameter bounds |
| 167 | + // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]` |
| 168 | + // Example: |
| 169 | + // ``` |
| 170 | + // struct< |
| 171 | + // U: Trait<T>, |
| 172 | + // #[pointee] T: Trait<T>, |
| 173 | + // V: Trait<T>> ... |
| 174 | + // ``` |
| 175 | + // ... generates this `impl` generic parameters |
| 176 | + // ``` |
| 177 | + // impl< |
| 178 | + // U: Trait<T> + Trait<__S>, |
| 179 | + // T: Trait<T> + Unsize<__S>, // (**) |
| 180 | + // __S: Trait<__S> + ?Sized, // (*) |
| 181 | + // V: Trait<T> + Trait<__S>> ... |
| 182 | + // ``` |
| 183 | + // The new bound marked with (*) has to be done separately. |
| 184 | + // See next section |
| 185 | + for (idx, (params, orig_params)) in |
| 186 | + impl_generics.params.iter_mut().zip(&generics.params).enumerate() |
| 187 | + { |
| 188 | + // Default type parameters are rejected for `impl` block. |
| 189 | + // We should drop them now. |
| 190 | + match &mut params.kind { |
| 191 | + ast::GenericParamKind::Const { default, .. } => *default = None, |
| 192 | + ast::GenericParamKind::Type { default } => *default = None, |
| 193 | + ast::GenericParamKind::Lifetime => {} |
| 194 | + } |
| 195 | + // We CANNOT rewrite `#[pointee]` type parameter bounds. |
| 196 | + // This has been set in stone. (**) |
| 197 | + // So we skip over it. |
| 198 | + // Otherwise, we push extra bounds involving `__S`. |
| 199 | + if idx != pointee_param_idx { |
| 200 | + for bound in &orig_params.bounds { |
| 201 | + let mut bound = bound.clone(); |
| 202 | + let mut substitution = TypeSubstitution { |
| 203 | + from_name: pointee_ty_ident.name, |
| 204 | + to_ty: &s_ty, |
| 205 | + rewritten: false, |
| 206 | + }; |
| 207 | + substitution.visit_param_bound(&mut bound, BoundKind::Bound); |
| 208 | + if substitution.rewritten { |
| 209 | + // We found use of `#[pointee]` somewhere, |
| 210 | + // so we make a new bound using `__S` in place of `#[pointee]` |
| 211 | + params.bounds.push(bound); |
| 212 | + } |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | + |
| 217 | + // # Insert `__S` type parameter |
| 218 | + // |
| 219 | + // We now insert `__S` with the missing bounds marked with (*) above. |
| 220 | + // We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`. |
157 | 221 | let sized = cx.path_global(span, path!(span, core::marker::Sized));
|
158 |
| - let bound = GenericBound::Trait( |
159 |
| - cx.poly_trait_ref(span, sized), |
160 |
| - TraitBoundModifiers { |
161 |
| - polarity: ast::BoundPolarity::Maybe(span), |
162 |
| - constness: ast::BoundConstness::Never, |
163 |
| - asyncness: ast::BoundAsyncness::Normal, |
164 |
| - }, |
165 |
| - ); |
166 |
| - let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None); |
167 |
| - impl_generics.params.push(extra_param); |
| 222 | + // For some reason, we are not allowed to write `?Sized` bound twice like `__S: ?Sized + ?Sized`. |
| 223 | + if !contains_maybe_sized_bound(&self_bounds) |
| 224 | + && !contains_maybe_sized_bound_on_pointee( |
| 225 | + &generics.where_clause.predicates, |
| 226 | + pointee_ty_ident.name, |
| 227 | + ) |
| 228 | + { |
| 229 | + self_bounds.push(GenericBound::Trait( |
| 230 | + cx.poly_trait_ref(span, sized), |
| 231 | + TraitBoundModifiers { |
| 232 | + polarity: ast::BoundPolarity::Maybe(span), |
| 233 | + constness: ast::BoundConstness::Never, |
| 234 | + asyncness: ast::BoundAsyncness::Normal, |
| 235 | + }, |
| 236 | + )); |
| 237 | + } |
| 238 | + { |
| 239 | + let mut substitution = |
| 240 | + TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false }; |
| 241 | + for bound in &mut self_bounds { |
| 242 | + substitution.visit_param_bound(bound, BoundKind::Bound); |
| 243 | + } |
| 244 | + } |
| 245 | + |
| 246 | + // # Rewrite `where` clauses |
| 247 | + // |
| 248 | + // Move on to `where` clauses. |
| 249 | + // Example: |
| 250 | + // ``` |
| 251 | + // struct MyPointer<#[pointee] T, ..> |
| 252 | + // where |
| 253 | + // U: Trait<V> + Trait<T>, |
| 254 | + // Companion<T>: Trait<T>, |
| 255 | + // T: Trait<T>, |
| 256 | + // { .. } |
| 257 | + // ``` |
| 258 | + // ... will have a impl prelude like so |
| 259 | + // ``` |
| 260 | + // impl<..> .. |
| 261 | + // where |
| 262 | + // U: Trait<V> + Trait<T>, |
| 263 | + // U: Trait<__S>, |
| 264 | + // Companion<T>: Trait<T>, |
| 265 | + // Companion<__S>: Trait<__S>, |
| 266 | + // T: Trait<T>, |
| 267 | + // __S: Trait<__S>, |
| 268 | + // ``` |
| 269 | + // |
| 270 | + // We should also write a few new `where` bounds from `#[pointee] T` to `__S` |
| 271 | + // as well as any bound that indirectly involves the `#[pointee] T` type. |
| 272 | + for bound in &generics.where_clause.predicates { |
| 273 | + if let ast::WherePredicate::BoundPredicate(bound) = bound { |
| 274 | + let mut substitution = TypeSubstitution { |
| 275 | + from_name: pointee_ty_ident.name, |
| 276 | + to_ty: &s_ty, |
| 277 | + rewritten: false, |
| 278 | + }; |
| 279 | + let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate { |
| 280 | + span: bound.span, |
| 281 | + bound_generic_params: bound.bound_generic_params.clone(), |
| 282 | + bounded_ty: bound.bounded_ty.clone(), |
| 283 | + bounds: bound.bounds.clone(), |
| 284 | + }); |
| 285 | + substitution.visit_where_predicate(&mut predicate); |
| 286 | + if substitution.rewritten { |
| 287 | + impl_generics.where_clause.predicates.push(predicate); |
| 288 | + } |
| 289 | + } |
| 290 | + } |
| 291 | + |
| 292 | + let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None); |
| 293 | + impl_generics.params.insert(pointee_param_idx + 1, extra_param); |
168 | 294 |
|
169 | 295 | // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
|
170 | 296 | let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
|
171 | 297 | add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
|
172 | 298 | add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
|
173 | 299 | }
|
| 300 | + |
| 301 | +fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool { |
| 302 | + for bound in predicates { |
| 303 | + if let ast::WherePredicate::BoundPredicate(bound) = bound |
| 304 | + && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee) |
| 305 | + { |
| 306 | + for bound in &bound.bounds { |
| 307 | + if is_maybe_sized_bound(bound) { |
| 308 | + return true; |
| 309 | + } |
| 310 | + } |
| 311 | + } |
| 312 | + } |
| 313 | + false |
| 314 | +} |
| 315 | + |
| 316 | +fn is_maybe_sized_bound(bound: &GenericBound) -> bool { |
| 317 | + if let GenericBound::Trait( |
| 318 | + trait_ref, |
| 319 | + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, |
| 320 | + ) = bound |
| 321 | + { |
| 322 | + is_sized_marker(&trait_ref.trait_ref.path) |
| 323 | + } else { |
| 324 | + false |
| 325 | + } |
| 326 | +} |
| 327 | + |
| 328 | +fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool { |
| 329 | + bounds.iter().any(is_maybe_sized_bound) |
| 330 | +} |
| 331 | + |
| 332 | +fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool { |
| 333 | + path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol) |
| 334 | +} |
| 335 | + |
| 336 | +fn is_sized_marker(path: &ast::Path) -> bool { |
| 337 | + const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized]; |
| 338 | + const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized]; |
| 339 | + if path.segments.len() == 4 && path.is_global() { |
| 340 | + path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE) |
| 341 | + || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE) |
| 342 | + } else if path.segments.len() == 3 { |
| 343 | + path_segment_is_exact_match(&path.segments, &CORE_UNSIZE) |
| 344 | + || path_segment_is_exact_match(&path.segments, &STD_UNSIZE) |
| 345 | + } else { |
| 346 | + *path == sym::Sized |
| 347 | + } |
| 348 | +} |
| 349 | + |
| 350 | +struct TypeSubstitution<'a> { |
| 351 | + from_name: Symbol, |
| 352 | + to_ty: &'a ast::Ty, |
| 353 | + rewritten: bool, |
| 354 | +} |
| 355 | + |
| 356 | +impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> { |
| 357 | + fn visit_ty(&mut self, ty: &mut P<ast::Ty>) { |
| 358 | + if let Some(name) = ty.kind.is_simple_path() |
| 359 | + && name == self.from_name |
| 360 | + { |
| 361 | + **ty = self.to_ty.clone(); |
| 362 | + self.rewritten = true; |
| 363 | + } else { |
| 364 | + ast::mut_visit::walk_ty(self, ty); |
| 365 | + } |
| 366 | + } |
| 367 | + |
| 368 | + fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) { |
| 369 | + match where_predicate { |
| 370 | + rustc_ast::WherePredicate::BoundPredicate(bound) => { |
| 371 | + bound |
| 372 | + .bound_generic_params |
| 373 | + .flat_map_in_place(|param| self.flat_map_generic_param(param)); |
| 374 | + self.visit_ty(&mut bound.bounded_ty); |
| 375 | + for bound in &mut bound.bounds { |
| 376 | + self.visit_param_bound(bound, BoundKind::Bound) |
| 377 | + } |
| 378 | + } |
| 379 | + rustc_ast::WherePredicate::RegionPredicate(_) |
| 380 | + | rustc_ast::WherePredicate::EqPredicate(_) => {} |
| 381 | + } |
| 382 | + } |
| 383 | +} |
0 commit comments