Skip to content

Commit e73fbc6

Browse files
committed
rustc_typeck: unify expected return types with formal return types to propagate coercions through calls of generic functions.
1 parent 21ec0c8 commit e73fbc6

File tree

5 files changed

+180
-23
lines changed

5 files changed

+180
-23
lines changed

Diff for: src/librustc/middle/infer/mod.rs

+33
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,39 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
613613
self.commit_unconditionally(move || self.try(move |_| f()))
614614
}
615615

616+
/// Execute `f` and commit only the region bindings if successful.
617+
/// The function f must be very careful not to leak any non-region
618+
/// variables that get created.
619+
pub fn commit_regions_if_ok<T, E, F>(&self, f: F) -> Result<T, E> where
620+
F: FnOnce() -> Result<T, E>
621+
{
622+
debug!("commit_regions_if_ok()");
623+
let CombinedSnapshot { type_snapshot,
624+
int_snapshot,
625+
float_snapshot,
626+
region_vars_snapshot } = self.start_snapshot();
627+
628+
let r = self.try(move |_| f());
629+
630+
// Roll back any non-region bindings - they should be resolved
631+
// inside `f`, with, e.g. `resolve_type_vars_if_possible`.
632+
self.type_variables
633+
.borrow_mut()
634+
.rollback_to(type_snapshot);
635+
self.int_unification_table
636+
.borrow_mut()
637+
.rollback_to(int_snapshot);
638+
self.float_unification_table
639+
.borrow_mut()
640+
.rollback_to(float_snapshot);
641+
642+
// Commit region vars that may escape through resolved types.
643+
self.region_vars
644+
.commit(region_vars_snapshot);
645+
646+
r
647+
}
648+
616649
/// Execute `f`, unroll bindings on panic
617650
pub fn try<T, E, F>(&self, f: F) -> Result<T, E> where
618651
F: FnOnce(&CombinedSnapshot) -> Result<T, E>

Diff for: src/librustc_typeck/check/callee.rs

+19-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use super::check_argument_types;
1414
use super::check_expr;
1515
use super::check_method_argument_types;
1616
use super::err_args;
17+
use super::Expectation;
18+
use super::expected_types_for_fn_args;
1719
use super::FnCtxt;
1820
use super::LvaluePreference;
1921
use super::method;
@@ -65,7 +67,8 @@ pub fn check_legal_trait_for_method_call(ccx: &CrateCtxt, span: Span, trait_id:
6567
pub fn check_call<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
6668
call_expr: &ast::Expr,
6769
callee_expr: &ast::Expr,
68-
arg_exprs: &[P<ast::Expr>])
70+
arg_exprs: &[P<ast::Expr>],
71+
expected: Expectation<'tcx>)
6972
{
7073
check_expr(fcx, callee_expr);
7174
let original_callee_ty = fcx.expr_ty(callee_expr);
@@ -84,15 +87,15 @@ pub fn check_call<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
8487
match result {
8588
None => {
8689
// this will report an error since original_callee_ty is not a fn
87-
confirm_builtin_call(fcx, call_expr, original_callee_ty, arg_exprs);
90+
confirm_builtin_call(fcx, call_expr, original_callee_ty, arg_exprs, expected);
8891
}
8992

9093
Some(CallStep::Builtin) => {
91-
confirm_builtin_call(fcx, call_expr, callee_ty, arg_exprs);
94+
confirm_builtin_call(fcx, call_expr, callee_ty, arg_exprs, expected);
9295
}
9396

9497
Some(CallStep::Overloaded(method_callee)) => {
95-
confirm_overloaded_call(fcx, call_expr, arg_exprs, method_callee);
98+
confirm_overloaded_call(fcx, call_expr, arg_exprs, method_callee, expected);
9699
}
97100
}
98101
}
@@ -153,7 +156,8 @@ fn try_overloaded_call_step<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
153156
fn confirm_builtin_call<'a,'tcx>(fcx: &FnCtxt<'a,'tcx>,
154157
call_expr: &ast::Expr,
155158
callee_ty: Ty<'tcx>,
156-
arg_exprs: &[P<ast::Expr>])
159+
arg_exprs: &[P<ast::Expr>],
160+
expected: Expectation<'tcx>)
157161
{
158162
let error_fn_sig;
159163

@@ -192,9 +196,15 @@ fn confirm_builtin_call<'a,'tcx>(fcx: &FnCtxt<'a,'tcx>,
192196
fcx.normalize_associated_types_in(call_expr.span, &fn_sig);
193197

194198
// Call the generic checker.
199+
let expected_arg_tys = expected_types_for_fn_args(fcx,
200+
call_expr.span,
201+
expected,
202+
fn_sig.output,
203+
fn_sig.inputs.as_slice());
195204
check_argument_types(fcx,
196205
call_expr.span,
197206
fn_sig.inputs.as_slice(),
207+
&expected_arg_tys[],
198208
arg_exprs,
199209
AutorefArgs::No,
200210
fn_sig.variadic,
@@ -206,15 +216,17 @@ fn confirm_builtin_call<'a,'tcx>(fcx: &FnCtxt<'a,'tcx>,
206216
fn confirm_overloaded_call<'a,'tcx>(fcx: &FnCtxt<'a, 'tcx>,
207217
call_expr: &ast::Expr,
208218
arg_exprs: &[P<ast::Expr>],
209-
method_callee: ty::MethodCallee<'tcx>)
219+
method_callee: ty::MethodCallee<'tcx>,
220+
expected: Expectation<'tcx>)
210221
{
211222
let output_type = check_method_argument_types(fcx,
212223
call_expr.span,
213224
method_callee.ty,
214225
call_expr,
215226
arg_exprs,
216227
AutorefArgs::No,
217-
TupleArgumentsFlag::TupleArguments);
228+
TupleArgumentsFlag::TupleArguments,
229+
expected);
218230
let method_call = ty::MethodCall::expr(call_expr.id);
219231
fcx.inh.method_map.borrow_mut().insert(method_call, method_callee);
220232
write_call(fcx, call_expr, output_type);

Diff for: src/librustc_typeck/check/mod.rs

+99-16
Original file line numberDiff line numberDiff line change
@@ -2559,7 +2559,8 @@ fn lookup_method_for_for_loop<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
25592559
iterator_expr,
25602560
&[],
25612561
AutorefArgs::No,
2562-
DontTupleArguments);
2562+
DontTupleArguments,
2563+
NoExpectation);
25632564

25642565
match method {
25652566
Some(method) => {
@@ -2601,7 +2602,8 @@ fn check_method_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26012602
callee_expr: &ast::Expr,
26022603
args_no_rcvr: &[P<ast::Expr>],
26032604
autoref_args: AutorefArgs,
2604-
tuple_arguments: TupleArgumentsFlag)
2605+
tuple_arguments: TupleArgumentsFlag,
2606+
expected: Expectation<'tcx>)
26052607
-> ty::FnOutput<'tcx> {
26062608
if ty::type_is_error(method_fn_ty) {
26072609
let err_inputs = err_args(fcx.tcx(), args_no_rcvr.len());
@@ -2614,6 +2616,7 @@ fn check_method_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26142616
check_argument_types(fcx,
26152617
sp,
26162618
&err_inputs[],
2619+
&[],
26172620
args_no_rcvr,
26182621
autoref_args,
26192622
false,
@@ -2623,9 +2626,15 @@ fn check_method_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26232626
match method_fn_ty.sty {
26242627
ty::ty_bare_fn(_, ref fty) => {
26252628
// HACK(eddyb) ignore self in the definition (see above).
2629+
let expected_arg_tys = expected_types_for_fn_args(fcx,
2630+
sp,
2631+
expected,
2632+
fty.sig.0.output,
2633+
&fty.sig.0.inputs[1..]);
26262634
check_argument_types(fcx,
26272635
sp,
26282636
&fty.sig.0.inputs[1..],
2637+
&expected_arg_tys[],
26292638
args_no_rcvr,
26302639
autoref_args,
26312640
fty.sig.0.variadic,
@@ -2645,6 +2654,7 @@ fn check_method_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26452654
fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26462655
sp: Span,
26472656
fn_inputs: &[Ty<'tcx>],
2657+
expected_arg_tys: &[Ty<'tcx>],
26482658
args: &[P<ast::Expr>],
26492659
autoref_args: AutorefArgs,
26502660
variadic: bool,
@@ -2659,6 +2669,7 @@ fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26592669
1
26602670
};
26612671

2672+
let mut expected_arg_tys = expected_arg_tys;
26622673
let expected_arg_count = fn_inputs.len();
26632674
let formal_tys = if tuple_arguments == TupleArguments {
26642675
let tuple_type = structurally_resolved_type(fcx, sp, fn_inputs[0]);
@@ -2671,23 +2682,32 @@ fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26712682
if arg_types.len() == 1 {""} else {"s"},
26722683
args.len(),
26732684
if args.len() == 1 {" was"} else {"s were"});
2685+
expected_arg_tys = &[][];
26742686
err_args(fcx.tcx(), args.len())
26752687
} else {
2688+
expected_arg_tys = match expected_arg_tys.get(0) {
2689+
Some(&ty) => match ty.sty {
2690+
ty::ty_tup(ref tys) => &**tys,
2691+
_ => &[]
2692+
},
2693+
None => &[]
2694+
};
26762695
(*arg_types).clone()
26772696
}
26782697
}
26792698
_ => {
26802699
span_err!(tcx.sess, sp, E0059,
26812700
"cannot use call notation; the first type parameter \
26822701
for the function trait is neither a tuple nor unit");
2702+
expected_arg_tys = &[][];
26832703
err_args(fcx.tcx(), args.len())
26842704
}
26852705
}
26862706
} else if expected_arg_count == supplied_arg_count {
2687-
fn_inputs.iter().map(|a| *a).collect()
2707+
fn_inputs.to_vec()
26882708
} else if variadic {
26892709
if supplied_arg_count >= expected_arg_count {
2690-
fn_inputs.iter().map(|a| *a).collect()
2710+
fn_inputs.to_vec()
26912711
} else {
26922712
span_err!(tcx.sess, sp, E0060,
26932713
"this function takes at least {} parameter{} \
@@ -2696,6 +2716,7 @@ fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
26962716
if expected_arg_count == 1 {""} else {"s"},
26972717
supplied_arg_count,
26982718
if supplied_arg_count == 1 {" was"} else {"s were"});
2719+
expected_arg_tys = &[][];
26992720
err_args(fcx.tcx(), supplied_arg_count)
27002721
}
27012722
} else {
@@ -2705,6 +2726,7 @@ fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
27052726
if expected_arg_count == 1 {""} else {"s"},
27062727
supplied_arg_count,
27072728
if supplied_arg_count == 1 {" was"} else {"s were"});
2729+
expected_arg_tys = &[][];
27082730
err_args(fcx.tcx(), supplied_arg_count)
27092731
};
27102732

@@ -2768,7 +2790,25 @@ fn check_argument_types<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
27682790
AutorefArgs::No => {}
27692791
}
27702792

2771-
check_expr_coercable_to_type(fcx, &**arg, formal_ty);
2793+
// The special-cased logic below has three functions:
2794+
// 1. Provide as good of an expected type as possible.
2795+
let expected = expected_arg_tys.get(i).map(|&ty| {
2796+
Expectation::rvalue_hint(ty)
2797+
});
2798+
2799+
check_expr_with_unifier(fcx, &**arg,
2800+
expected.unwrap_or(ExpectHasType(formal_ty)),
2801+
NoPreference, || {
2802+
// 2. Coerce to the most detailed type that could be coerced
2803+
// to, which is `expected_ty` if `rvalue_hint` returns an
2804+
// `ExprHasType(expected_ty)`, or the `formal_ty` otherwise.
2805+
let coerce_ty = expected.and_then(|e| e.only_has_type(fcx));
2806+
demand::coerce(fcx, arg.span, coerce_ty.unwrap_or(formal_ty), &**arg);
2807+
2808+
// 3. Relate the expected type and the formal one,
2809+
// if the expected type was used for the coercion.
2810+
coerce_ty.map(|ty| demand::suptype(fcx, arg.span, formal_ty, ty));
2811+
});
27722812
}
27732813
}
27742814
}
@@ -3008,6 +3048,45 @@ enum TupleArgumentsFlag {
30083048
TupleArguments,
30093049
}
30103050

3051+
/// Unifies the return type with the expected type early, for more coercions
3052+
/// and forward type information on the argument expressions.
3053+
fn expected_types_for_fn_args<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
3054+
call_span: Span,
3055+
expected_ret: Expectation<'tcx>,
3056+
formal_ret: ty::FnOutput<'tcx>,
3057+
formal_args: &[Ty<'tcx>])
3058+
-> Vec<Ty<'tcx>> {
3059+
let expected_args = expected_ret.only_has_type(fcx).and_then(|ret_ty| {
3060+
if let ty::FnConverging(formal_ret_ty) = formal_ret {
3061+
fcx.infcx().commit_regions_if_ok(|| {
3062+
// Attempt to apply a subtyping relationship between the formal
3063+
// return type (likely containing type variables if the function
3064+
// is polymorphic) and the expected return type.
3065+
// No argument expectations are produced if unification fails.
3066+
let origin = infer::Misc(call_span);
3067+
let ures = fcx.infcx().sub_types(false, origin, formal_ret_ty, ret_ty);
3068+
// FIXME(#15760) can't use try! here, FromError doesn't default
3069+
// to identity so the resulting type is not constrained.
3070+
if let Err(e) = ures {
3071+
return Err(e);
3072+
}
3073+
3074+
// Record all the argument types, with the substitutions
3075+
// produced from the above subtyping unification.
3076+
Ok(formal_args.iter().map(|ty| {
3077+
fcx.infcx().resolve_type_vars_if_possible(ty)
3078+
}).collect())
3079+
}).ok()
3080+
} else {
3081+
None
3082+
}
3083+
}).unwrap_or(vec![]);
3084+
debug!("expected_types_for_fn_args(formal={} -> {}, expected={} -> {})",
3085+
formal_args.repr(fcx.tcx()), formal_ret.repr(fcx.tcx()),
3086+
expected_args.repr(fcx.tcx()), expected_ret.repr(fcx.tcx()));
3087+
expected_args
3088+
}
3089+
30113090
/// Invariant:
30123091
/// If an expression has any sub-expressions that result in a type error,
30133092
/// inspecting that expression's type with `ty::type_is_error` will return
@@ -3029,12 +3108,13 @@ fn check_expr_with_unifier<'a, 'tcx, F>(fcx: &FnCtxt<'a, 'tcx>,
30293108
expr.repr(fcx.tcx()), expected.repr(fcx.tcx()));
30303109

30313110
// Checks a method call.
3032-
fn check_method_call(fcx: &FnCtxt,
3033-
expr: &ast::Expr,
3034-
method_name: ast::SpannedIdent,
3035-
args: &[P<ast::Expr>],
3036-
tps: &[P<ast::Ty>],
3037-
lvalue_pref: LvaluePreference) {
3111+
fn check_method_call<'a, 'tcx>(fcx: &FnCtxt<'a, 'tcx>,
3112+
expr: &ast::Expr,
3113+
method_name: ast::SpannedIdent,
3114+
args: &[P<ast::Expr>],
3115+
tps: &[P<ast::Ty>],
3116+
expected: Expectation<'tcx>,
3117+
lvalue_pref: LvaluePreference) {
30383118
let rcvr = &*args[0];
30393119
check_expr_with_lvalue_pref(fcx, &*rcvr, lvalue_pref);
30403120

@@ -3071,7 +3151,8 @@ fn check_expr_with_unifier<'a, 'tcx, F>(fcx: &FnCtxt<'a, 'tcx>,
30713151
expr,
30723152
&args[1..],
30733153
AutorefArgs::No,
3074-
DontTupleArguments);
3154+
DontTupleArguments,
3155+
expected);
30753156

30763157
write_call(fcx, expr, ret_ty);
30773158
}
@@ -3182,7 +3263,8 @@ fn check_expr_with_unifier<'a, 'tcx, F>(fcx: &FnCtxt<'a, 'tcx>,
31823263
op_ex,
31833264
args,
31843265
autoref_args,
3185-
DontTupleArguments) {
3266+
DontTupleArguments,
3267+
NoExpectation) {
31863268
ty::FnConverging(result_type) => result_type,
31873269
ty::FnDiverging => fcx.tcx().types.err
31883270
}
@@ -3198,7 +3280,8 @@ fn check_expr_with_unifier<'a, 'tcx, F>(fcx: &FnCtxt<'a, 'tcx>,
31983280
op_ex,
31993281
args,
32003282
autoref_args,
3201-
DontTupleArguments);
3283+
DontTupleArguments,
3284+
NoExpectation);
32023285
fcx.tcx().types.err
32033286
}
32043287
}
@@ -4045,10 +4128,10 @@ fn check_expr_with_unifier<'a, 'tcx, F>(fcx: &FnCtxt<'a, 'tcx>,
40454128
fcx.write_ty(id, fcx.node_ty(b.id));
40464129
}
40474130
ast::ExprCall(ref callee, ref args) => {
4048-
callee::check_call(fcx, expr, &**callee, &args[]);
4131+
callee::check_call(fcx, expr, &**callee, &args[], expected);
40494132
}
40504133
ast::ExprMethodCall(ident, ref tps, ref args) => {
4051-
check_method_call(fcx, expr, ident, &args[], &tps[], lvalue_pref);
4134+
check_method_call(fcx, expr, ident, &args[], &tps[], expected, lvalue_pref);
40524135
let arg_tys = args.iter().map(|a| fcx.expr_ty(&**a));
40534136
let args_err = arg_tys.fold(false,
40544137
|rest_err, a| {

Diff for: src/test/run-pass/coerce-expect-unsized.rs

+3
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@ pub fn main() {
3030
let _: &Fn(int) -> _ = &{ |x| (x as u8) };
3131
let _: &Show = &if true { false } else { true };
3232
let _: &Show = &match true { true => 'a', false => 'b' };
33+
34+
let _: Box<[int]> = Box::new([1, 2, 3]);
35+
let _: Box<Fn(int) -> _> = Box::new(|x| (x as u8));
3336
}

0 commit comments

Comments
 (0)