Skip to content

Check yield terminator's resume type in borrowck #119563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions compiler/rustc_borrowck/src/type_check/input_output.rs
Original file line number Diff line number Diff line change
@@ -94,31 +94,22 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
);
}

debug!(
"equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
body.yield_ty(),
universal_regions.yield_ty
);

// We will not have a universal_regions.yield_ty if we yield (by accident)
// outside of a coroutine and return an `impl Trait`, so emit a span_delayed_bug
// because we don't want to panic in an assert here if we've already got errors.
if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() {
self.tcx().dcx().span_delayed_bug(
body.span,
format!(
"Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})",
body.yield_ty(),
universal_regions.yield_ty,
),
if let Some(mir_yield_ty) = body.yield_ty() {
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
self.equate_normalized_input_or_output(
universal_regions.yield_ty.unwrap(),
mir_yield_ty,
yield_span,
);
}

if let (Some(mir_yield_ty), Some(ur_yield_ty)) =
(body.yield_ty(), universal_regions.yield_ty)
{
if let Some(mir_resume_ty) = body.resume_ty() {
let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span);
self.equate_normalized_input_or_output(
universal_regions.resume_ty.unwrap(),
mir_resume_ty,
yield_span,
);
}

// Return types are a bit more complex. They may contain opaque `impl Trait` types.
1 change: 1 addition & 0 deletions compiler/rustc_borrowck/src/type_check/liveness/mod.rs
Original file line number Diff line number Diff line change
@@ -183,6 +183,7 @@ impl<'cx, 'tcx> Visitor<'tcx> for LiveVariablesVisitor<'cx, 'tcx> {
match ty_context {
TyContext::ReturnTy(SourceInfo { span, .. })
| TyContext::YieldTy(SourceInfo { span, .. })
| TyContext::ResumeTy(SourceInfo { span, .. })
| TyContext::UserTy(span)
| TyContext::LocalDecl { source_info: SourceInfo { span, .. }, .. } => {
span_bug!(span, "should not be visiting outside of the CFG: {:?}", ty_context);
26 changes: 24 additions & 2 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
@@ -1450,13 +1450,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
}
TerminatorKind::Yield { value, .. } => {
TerminatorKind::Yield { value, resume_arg, .. } => {
self.check_operand(value, term_location);

let value_ty = value.ty(body, tcx);
match body.yield_ty() {
None => span_mirbug!(self, term, "yield in non-coroutine"),
Some(ty) => {
let value_ty = value.ty(body, tcx);
if let Err(terr) = self.sub_types(
value_ty,
ty,
@@ -1474,6 +1474,28 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
}

match body.resume_ty() {
None => span_mirbug!(self, term, "yield in non-coroutine"),
Some(ty) => {
let resume_ty = resume_arg.ty(body, tcx);
if let Err(terr) = self.sub_types(
ty,
resume_ty.ty,
term_location.to_locations(),
ConstraintCategory::Yield,
) {
span_mirbug!(
self,
term,
"type of resume place is {:?}, but the resume type is {:?}: {:?}",
resume_ty,
ty,
terr
);
}
}
}
}
}
}
12 changes: 9 additions & 3 deletions compiler/rustc_borrowck/src/universal_regions.rs
Original file line number Diff line number Diff line change
@@ -76,6 +76,8 @@ pub struct UniversalRegions<'tcx> {
pub unnormalized_input_tys: &'tcx [Ty<'tcx>],

pub yield_ty: Option<Ty<'tcx>>,

pub resume_ty: Option<Ty<'tcx>>,
}

/// The "defining type" for this MIR. The key feature of the "defining
@@ -525,9 +527,12 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
debug!("build: extern regions = {}..{}", first_extern_index, first_local_index);
debug!("build: local regions = {}..{}", first_local_index, num_universals);

let yield_ty = match defining_ty {
DefiningTy::Coroutine(_, args) => Some(args.as_coroutine().yield_ty()),
_ => None,
let (resume_ty, yield_ty) = match defining_ty {
DefiningTy::Coroutine(_, args) => {
let tys = args.as_coroutine();
(Some(tys.resume_ty()), Some(tys.yield_ty()))
}
_ => (None, None),
};

UniversalRegions {
@@ -541,6 +546,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
unnormalized_output_ty: *unnormalized_output_ty,
unnormalized_input_tys,
yield_ty,
resume_ty,
}
}

9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
@@ -250,6 +250,9 @@ pub struct CoroutineInfo<'tcx> {
/// The yield type of the function, if it is a coroutine.
pub yield_ty: Option<Ty<'tcx>>,

/// The resume type of the function, if it is a coroutine.
pub resume_ty: Option<Ty<'tcx>>,

/// Coroutine drop glue.
pub coroutine_drop: Option<Body<'tcx>>,

@@ -385,6 +388,7 @@ impl<'tcx> Body<'tcx> {
coroutine: coroutine_kind.map(|coroutine_kind| {
Box::new(CoroutineInfo {
yield_ty: None,
resume_ty: None,
coroutine_drop: None,
coroutine_layout: None,
coroutine_kind,
@@ -551,6 +555,11 @@ impl<'tcx> Body<'tcx> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.yield_ty)
}

#[inline]
pub fn resume_ty(&self) -> Option<Ty<'tcx>> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
}

#[inline]
pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
@@ -996,6 +996,12 @@ macro_rules! super_body {
TyContext::YieldTy(SourceInfo::outermost(span))
);
}
if let Some(resume_ty) = $(& $mutability)? gen.resume_ty {
$self.visit_ty(
resume_ty,
TyContext::ResumeTy(SourceInfo::outermost(span))
);
}
}

for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) {
@@ -1244,6 +1250,8 @@ pub enum TyContext {

YieldTy(SourceInfo),

ResumeTy(SourceInfo),

/// A type found at some location.
Location(Location),
}
27 changes: 17 additions & 10 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
@@ -488,17 +488,17 @@ fn construct_fn<'tcx>(

let arguments = &thir.params;

let (yield_ty, return_ty) = if coroutine_kind.is_some() {
let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
let coroutine_sig = match coroutine_ty.kind() {
ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
_ => {
span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
}
};
(Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
(Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
} else {
(None, fn_sig.output())
(None, None, fn_sig.output())
};

if let Some(custom_mir_attr) =
@@ -562,9 +562,12 @@ fn construct_fn<'tcx>(
} else {
None
};
if yield_ty.is_some() {

if coroutine_kind.is_some() {
body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
}

body
}

@@ -631,28 +634,29 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
let hir_id = tcx.local_def_id_to_hir_id(def_id);
let coroutine_kind = tcx.coroutine_kind(def_id);

let (inputs, output, yield_ty) = match tcx.def_kind(def_id) {
let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
DefKind::Const
| DefKind::AssocConst
| DefKind::AnonConst
| DefKind::InlineConst
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
| DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
let sig = tcx.liberate_late_bound_regions(
def_id.to_def_id(),
tcx.fn_sig(def_id).instantiate_identity(),
);
(sig.inputs().to_vec(), sig.output(), None)
(sig.inputs().to_vec(), sig.output(), None, None)
}
DefKind::Closure if coroutine_kind.is_some() => {
let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
let ty::Coroutine(_, args) = coroutine_ty.kind() else {
bug!("expected type of coroutine-like closure to be a coroutine")
};
let args = args.as_coroutine();
let resume_ty = args.resume_ty();
let yield_ty = args.yield_ty();
let return_ty = args.return_ty();
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(yield_ty))
(vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
}
DefKind::Closure => {
let closure_ty = tcx.type_of(def_id).instantiate_identity();
@@ -666,7 +670,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
ty::ClosureKind::FnOnce => closure_ty,
};
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None)
([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
}
dk => bug!("{:?} is not a body: {:?}", def_id, dk),
};
@@ -705,7 +709,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
Some(guar),
);

body.coroutine.as_mut().map(|gen| gen.yield_ty = yield_ty);
body.coroutine.as_mut().map(|gen| {
gen.yield_ty = yield_ty;
gen.resume_ty = resume_ty;
});

body
}
1 change: 1 addition & 0 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1733,6 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}

body.coroutine.as_mut().unwrap().yield_ty = None;
body.coroutine.as_mut().unwrap().resume_ty = None;
body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);

// Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
35 changes: 35 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes-2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#![feature(coroutine_trait)]
#![feature(coroutines)]

use std::ops::Coroutine;

struct Contravariant<'a>(fn(&'a ()));
struct Covariant<'a>(fn() -> &'a ());

fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
|_: Covariant<'short>| {
let a: Covariant<'long> = yield ();
//~^ ERROR lifetime may not live long enough
}
}

fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
|_: Contravariant<'long>| {
let a: Contravariant<'short> = yield ();
//~^ ERROR lifetime may not live long enough
}
}

fn good1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'long>> {
|_: Covariant<'long>| {
let a: Covariant<'short> = yield ();
}
}

fn good2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'short>> {
|_: Contravariant<'short>| {
let a: Contravariant<'long> = yield ();
}
}

fn main() {}
36 changes: 36 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes-2.rs:11:16
|
LL | fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
| ------ ----- lifetime `'long` defined here
| |
| lifetime `'short` defined here
LL | |_: Covariant<'short>| {
LL | let a: Covariant<'long> = yield ();
| ^^^^^^^^^^^^^^^^ type annotation requires that `'short` must outlive `'long`
|
= help: consider adding the following bound: `'short: 'long`
help: consider adding 'move' keyword before the nested closure
|
LL | move |_: Covariant<'short>| {
| ++++

error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes-2.rs:18:40
|
LL | fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
| ------ ----- lifetime `'long` defined here
| |
| lifetime `'short` defined here
LL | |_: Contravariant<'long>| {
LL | let a: Contravariant<'short> = yield ();
| ^^^^^^^^ yielding this value requires that `'short` must outlive `'long`
|
= help: consider adding the following bound: `'short: 'long`
help: consider adding 'move' keyword before the nested closure
|
LL | move |_: Contravariant<'long>| {
| ++++

error: aborting due to 2 previous errors

27 changes: 27 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#![feature(coroutine_trait)]
#![feature(coroutines)]
#![allow(unused)]

use std::ops::Coroutine;
use std::ops::CoroutineState;
use std::pin::pin;

fn mk_static(s: &str) -> &'static str {
let mut storage: Option<&'static str> = None;

let mut coroutine = pin!(|_: &str| {
let x: &'static str = yield ();
//~^ ERROR lifetime may not live long enough
storage = Some(x);
});

coroutine.as_mut().resume(s);
coroutine.as_mut().resume(s);

storage.unwrap()
}

fn main() {
let s = mk_static(&String::from("hello, world"));
println!("{s}");
}
11 changes: 11 additions & 0 deletions tests/ui/coroutine/check-resume-ty-lifetimes.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error: lifetime may not live long enough
--> $DIR/check-resume-ty-lifetimes.rs:13:16
|
LL | fn mk_static(s: &str) -> &'static str {
| - let's call the lifetime of this reference `'1`
...
LL | let x: &'static str = yield ();
| ^^^^^^^^^^^^ type annotation requires that `'1` must outlive `'static`

error: aborting due to 1 previous error