Skip to content

Commit 00413c5

Browse files
derive(SmartPointer): rewrite bounds in where and generic bounds
1 parent a5ee5cb commit 00413c5

File tree

4 files changed

+335
-11
lines changed

4 files changed

+335
-11
lines changed

compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs

+197-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
use std::mem::swap;
22

33
use ast::HasAttrs;
4+
use rustc_ast::mut_visit::MutVisitor;
5+
use rustc_ast::visit::BoundKind;
46
use rustc_ast::{
57
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
68
TraitBoundModifiers, VariantData,
79
};
810
use rustc_attr as attr;
11+
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
912
use rustc_expand::base::{Annotatable, ExtCtxt};
1013
use rustc_span::symbol::{sym, Ident};
11-
use rustc_span::Span;
14+
use rustc_span::{Span, Symbol};
1215
use smallvec::{smallvec, SmallVec};
1316
use thin_vec::{thin_vec, ThinVec};
1417

18+
type AstTy = ast::ptr::P<ast::Ty>;
19+
1520
macro_rules! path {
1621
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
1722
}
1823

24+
macro_rules! symbols {
25+
($($part:ident)::*) => { [$(sym::$part),*] }
26+
}
27+
1928
pub fn expand_deriving_smart_ptr(
2029
cx: &ExtCtxt<'_>,
2130
span: Span,
@@ -143,31 +152,208 @@ pub fn expand_deriving_smart_ptr(
143152

144153
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
145154
let mut impl_generics = generics.clone();
155+
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
156+
let mut self_bounds;
146157
{
147158
let p = &mut impl_generics.params[pointee_param_idx];
159+
self_bounds = p.bounds.clone();
148160
let arg = GenericArg::Type(s_ty.clone());
149161
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
150162
p.bounds.push(cx.trait_bound(unsize, false));
151163
let mut attrs = thin_vec![];
152164
swap(&mut p.attrs, &mut attrs);
153165
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
154166
}
167+
// We should not set default values to constant generic parameters
168+
// and write out bounds that indirectly involves `#[pointee]`.
169+
for (params, orig_params) in impl_generics.params[pointee_param_idx + 1..]
170+
.iter_mut()
171+
.zip(&generics.params[pointee_param_idx + 1..])
172+
{
173+
if let ast::GenericParamKind::Const { default, .. } = &mut params.kind {
174+
*default = None;
175+
}
176+
for bound in &orig_params.bounds {
177+
let mut bound = bound.clone();
178+
let mut substitution = TypeSubstitution {
179+
from_name: pointee_ty_ident.name,
180+
to_ty: &s_ty,
181+
rewritten: false,
182+
};
183+
substitution.visit_param_bound(&mut bound, BoundKind::Bound);
184+
if substitution.rewritten {
185+
params.bounds.push(bound);
186+
}
187+
}
188+
}
155189

156190
// Add the `__S: ?Sized` extra parameter to the impl block.
191+
// We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
157192
let sized = cx.path_global(span, path!(span, core::marker::Sized));
158-
let bound = GenericBound::Trait(
159-
cx.poly_trait_ref(span, sized),
160-
TraitBoundModifiers {
161-
polarity: ast::BoundPolarity::Maybe(span),
162-
constness: ast::BoundConstness::Never,
163-
asyncness: ast::BoundAsyncness::Normal,
164-
},
165-
);
166-
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
167-
impl_generics.params.push(extra_param);
193+
if self_bounds.iter().all(|bound| {
194+
if let GenericBound::Trait(
195+
trait_ref,
196+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
197+
) = bound
198+
{
199+
!is_sized_marker(&trait_ref.trait_ref.path)
200+
} else {
201+
false
202+
}
203+
}) {
204+
self_bounds.push(GenericBound::Trait(
205+
cx.poly_trait_ref(span, sized),
206+
TraitBoundModifiers {
207+
polarity: ast::BoundPolarity::Maybe(span),
208+
constness: ast::BoundConstness::Never,
209+
asyncness: ast::BoundAsyncness::Normal,
210+
},
211+
));
212+
}
213+
{
214+
let mut substitution =
215+
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
216+
for bound in &mut self_bounds {
217+
substitution.visit_param_bound(bound, BoundKind::Bound);
218+
}
219+
}
220+
221+
// We should also commute the where bounds from `#[pointee]` to `__S`
222+
// as well as any bound that indirectly involves the `#[pointee]` type.
223+
for bound in &generics.where_clause.predicates {
224+
if let ast::WherePredicate::BoundPredicate(bound) = bound {
225+
let bound_on_pointee = bound
226+
.bounded_ty
227+
.kind
228+
.is_simple_path()
229+
.map_or(false, |name| name == pointee_ty_ident.name);
230+
231+
let bounds: Vec<_> = bound
232+
.bounds
233+
.iter()
234+
.filter(|bound| {
235+
if let GenericBound::Trait(
236+
trait_ref,
237+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
238+
) = bound
239+
{
240+
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
241+
} else {
242+
true
243+
}
244+
})
245+
.cloned()
246+
.collect();
247+
let mut substitution = TypeSubstitution {
248+
from_name: pointee_ty_ident.name,
249+
to_ty: &s_ty,
250+
rewritten: bounds.len() != bound.bounds.len(),
251+
};
252+
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
253+
span: bound.span,
254+
bound_generic_params: bound.bound_generic_params.clone(),
255+
bounded_ty: bound.bounded_ty.clone(),
256+
bounds,
257+
});
258+
substitution.visit_where_predicate(&mut predicate);
259+
if substitution.rewritten {
260+
impl_generics.where_clause.predicates.push(predicate);
261+
}
262+
}
263+
}
264+
265+
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
266+
impl_generics.params.insert(pointee_param_idx + 1, extra_param);
168267

169268
// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
170269
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
171270
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
172271
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
173272
}
273+
274+
fn is_sized_marker(path: &ast::Path) -> bool {
275+
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
276+
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
277+
if path.segments.len() == 3 {
278+
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
279+
|| path
280+
.segments
281+
.iter()
282+
.zip(STD_UNSIZE)
283+
.all(|(segment, symbol)| segment.ident.name == symbol)
284+
} else {
285+
*path == sym::Sized
286+
}
287+
}
288+
289+
struct TypeSubstitution<'a> {
290+
from_name: Symbol,
291+
to_ty: &'a AstTy,
292+
rewritten: bool,
293+
}
294+
295+
impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
296+
fn visit_ty(&mut self, ty: &mut AstTy) {
297+
if let Some(name) = ty.kind.is_simple_path()
298+
&& name == self.from_name
299+
{
300+
*ty = self.to_ty.clone();
301+
self.rewritten = true;
302+
return;
303+
}
304+
match &mut ty.kind {
305+
ast::TyKind::Slice(_)
306+
| ast::TyKind::Array(_, _)
307+
| ast::TyKind::Ptr(_)
308+
| ast::TyKind::Ref(_, _)
309+
| ast::TyKind::BareFn(_)
310+
| ast::TyKind::Never
311+
| ast::TyKind::Tup(_)
312+
| ast::TyKind::AnonStruct(_, _)
313+
| ast::TyKind::AnonUnion(_, _)
314+
| ast::TyKind::Path(_, _)
315+
| ast::TyKind::TraitObject(_, _)
316+
| ast::TyKind::ImplTrait(_, _)
317+
| ast::TyKind::Paren(_)
318+
| ast::TyKind::Typeof(_)
319+
| ast::TyKind::Infer
320+
| ast::TyKind::MacCall(_)
321+
| ast::TyKind::Pat(_, _) => ast::mut_visit::walk_ty(self, ty),
322+
ast::TyKind::ImplicitSelf
323+
| ast::TyKind::CVarArgs
324+
| ast::TyKind::Dummy
325+
| ast::TyKind::Err(_) => {}
326+
}
327+
}
328+
329+
fn visit_param_bound(&mut self, bound: &mut GenericBound, _ctxt: BoundKind) {
330+
match bound {
331+
GenericBound::Trait(trait_ref, _) => {
332+
self.visit_poly_trait_ref(trait_ref);
333+
}
334+
335+
GenericBound::Use(args, _span) => {
336+
for arg in args {
337+
self.visit_precise_capturing_arg(arg);
338+
}
339+
}
340+
GenericBound::Outlives(_) => {}
341+
}
342+
}
343+
344+
fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
345+
match where_predicate {
346+
rustc_ast::WherePredicate::BoundPredicate(bound) => {
347+
bound
348+
.bound_generic_params
349+
.flat_map_in_place(|param| self.flat_map_generic_param(param));
350+
self.visit_ty(&mut bound.bounded_ty);
351+
for bound in &mut bound.bounds {
352+
self.visit_param_bound(bound, BoundKind::Bound)
353+
}
354+
}
355+
rustc_ast::WherePredicate::RegionPredicate(_)
356+
| rustc_ast::WherePredicate::EqPredicate(_) => {}
357+
}
358+
}
359+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ check-pass
2+
//@ compile-flags: -Zunpretty=expanded
3+
#![feature(derive_smart_pointer)]
4+
use std::marker::SmartPointer;
5+
6+
pub trait MyTrait<T: ?Sized> {}
7+
8+
#[derive(SmartPointer)]
9+
#[repr(transparent)]
10+
struct MyPointer<'a, #[pointee] T: ?Sized> {
11+
ptr: &'a T,
12+
}
13+
14+
#[derive(core::marker::SmartPointer)]
15+
#[repr(transparent)]
16+
pub struct MyPointer2<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
17+
data: &'a mut T,
18+
x: core::marker::PhantomData<X>,
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#![feature(prelude_import)]
2+
#![no_std]
3+
//@ check-pass
4+
//@ compile-flags: -Zunpretty=expanded
5+
#![feature(derive_smart_pointer)]
6+
#[prelude_import]
7+
use ::std::prelude::rust_2015::*;
8+
#[macro_use]
9+
extern crate std;
10+
use std::marker::SmartPointer;
11+
12+
pub trait MyTrait<T: ?Sized> {}
13+
14+
#[repr(transparent)]
15+
struct MyPointer<'a, #[pointee] T: ?Sized> {
16+
ptr: &'a T,
17+
}
18+
#[automatically_derived]
19+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
20+
::core::ops::DispatchFromDyn<MyPointer<'a, __S>> for MyPointer<'a, T> {
21+
}
22+
#[automatically_derived]
23+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
24+
::core::ops::CoerceUnsized<MyPointer<'a, __S>> for MyPointer<'a, T> {
25+
}
26+
27+
#[repr(transparent)]
28+
pub struct MyPointer2<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
29+
data: &'a mut T,
30+
x: core::marker::PhantomData<X>,
31+
}
32+
#[automatically_derived]
33+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized, X: MyTrait<T> +
34+
MyTrait<__S>> ::core::ops::DispatchFromDyn<MyPointer2<'a, __S, X>> for
35+
MyPointer2<'a, T, X> {
36+
}
37+
#[automatically_derived]
38+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized, X: MyTrait<T> +
39+
MyTrait<__S>> ::core::ops::CoerceUnsized<MyPointer2<'a, __S, X>> for
40+
MyPointer2<'a, T, X> {
41+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//@ check-pass
2+
3+
#![feature(derive_smart_pointer)]
4+
5+
#[derive(core::marker::SmartPointer)]
6+
#[repr(transparent)]
7+
pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
8+
data: &'a mut T,
9+
x: core::marker::PhantomData<X>,
10+
}
11+
12+
pub trait OnDrop {
13+
fn on_drop(&mut self);
14+
}
15+
16+
#[derive(core::marker::SmartPointer)]
17+
#[repr(transparent)]
18+
pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
19+
where
20+
T: OnDrop,
21+
{
22+
data: &'a mut T,
23+
x: core::marker::PhantomData<X>,
24+
}
25+
26+
pub trait MyTrait<T: ?Sized> {}
27+
28+
#[derive(core::marker::SmartPointer)]
29+
#[repr(transparent)]
30+
pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
31+
where
32+
T: MyTrait<T>,
33+
{
34+
data: &'a mut T,
35+
x: core::marker::PhantomData<X>,
36+
}
37+
38+
#[derive(core::marker::SmartPointer)]
39+
#[repr(transparent)]
40+
pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
41+
data: &'a mut T,
42+
x: core::marker::PhantomData<X>,
43+
}
44+
45+
#[derive(core::marker::SmartPointer)]
46+
#[repr(transparent)]
47+
pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
48+
where
49+
Ptr5Companion<T>: MyTrait<T>,
50+
Ptr5Companion2: MyTrait<T>,
51+
{
52+
data: &'a mut T,
53+
x: core::marker::PhantomData<X>,
54+
}
55+
56+
pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
57+
pub struct Ptr5Companion2;
58+
59+
#[derive(core::marker::SmartPointer)]
60+
#[repr(transparent)]
61+
pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
62+
data: &'a mut T,
63+
x: core::marker::PhantomData<X>,
64+
}
65+
66+
// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/
67+
#[repr(transparent)]
68+
#[derive(core::marker::SmartPointer)]
69+
pub struct ListArc<#[pointee] T, const ID: u64 = 0>
70+
where
71+
T: ListArcSafe<ID> + ?Sized,
72+
{
73+
arc: *const T,
74+
}
75+
76+
pub trait ListArcSafe<const ID: u64> {}
77+
78+
fn main() {}

0 commit comments

Comments
 (0)