22
22
23
23
using llvm::cast;
24
24
using llvm::dyn_cast;
25
+ using llvm::isa;
25
26
26
27
namespace Carbon {
27
28
29
+ TypeChecker::ReturnTypeContext::ReturnTypeContext (
30
+ Nonnull<const Value*> orig_return_type, bool is_omitted)
31
+ : is_auto_(isa<AutoType>(orig_return_type)),
32
+ deduced_return_type_ (is_auto_ ? std::nullopt
33
+ : std::optional(orig_return_type)),
34
+ is_omitted_(is_omitted) {}
35
+
28
36
void PrintTypeEnv (TypeEnv types, llvm::raw_ostream& out) {
29
37
llvm::ListSeparator sep;
30
38
for (const auto & [name, type] : types) {
@@ -622,18 +630,17 @@ auto TypeChecker::TypeCheckPattern(
622
630
auto TypeChecker::TypeCheckCase (Nonnull<const Value*> expected,
623
631
Nonnull<Pattern*> pat, Nonnull<Statement*> body,
624
632
TypeEnv types, Env values,
625
- Nonnull<const Value*>& ret_type,
626
- bool is_omitted_ret_type)
633
+ Nonnull<ReturnTypeContext*> return_type_context)
627
634
-> std::pair<Nonnull<Pattern*>, Nonnull<Statement*>> {
628
635
auto pat_res = TypeCheckPattern (pat, types, values, expected);
629
- auto res =
630
- TypeCheckStmt (body, pat_res.types , values, ret_type, is_omitted_ret_type);
636
+ auto res = TypeCheckStmt (body, pat_res.types , values, return_type_context);
631
637
return std::make_pair (pat, res.stmt );
632
638
}
633
639
634
640
auto TypeChecker::TypeCheckStmt (Nonnull<Statement*> s, TypeEnv types,
635
- Env values, Nonnull<const Value*>& ret_type,
636
- bool is_omitted_ret_type) -> TCStatement {
641
+ Env values,
642
+ Nonnull<ReturnTypeContext*> return_type_context)
643
+ -> TCStatement {
637
644
switch (s->Tag ()) {
638
645
case Statement::Kind::Match: {
639
646
auto & match = cast<Match>(*s);
@@ -644,7 +651,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
644
651
for (auto & clause : match.Clauses ()) {
645
652
new_clauses.push_back (TypeCheckCase (res_type, clause.first ,
646
653
clause.second , types, values,
647
- ret_type, is_omitted_ret_type ));
654
+ return_type_context ));
648
655
}
649
656
auto new_s = arena->New <Match>(s->SourceLoc (), res.exp , new_clauses);
650
657
return TCStatement (new_s, types);
@@ -654,8 +661,8 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
654
661
auto cnd_res = TypeCheckExp (while_stmt.Cond (), types, values);
655
662
ExpectType (s->SourceLoc (), " condition of `while`" , arena->New <BoolType>(),
656
663
cnd_res.type );
657
- auto body_res = TypeCheckStmt (while_stmt. Body (), types, values, ret_type,
658
- is_omitted_ret_type );
664
+ auto body_res =
665
+ TypeCheckStmt (while_stmt. Body (), types, values, return_type_context );
659
666
auto new_s =
660
667
arena->New <While>(s->SourceLoc (), cnd_res.exp , body_res.stmt );
661
668
return TCStatement (new_s, types);
@@ -666,8 +673,8 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
666
673
case Statement::Kind::Block: {
667
674
auto & block = cast<Block>(*s);
668
675
if (block.Stmt ()) {
669
- auto stmt_res = TypeCheckStmt (*block. Stmt (), types, values, ret_type,
670
- is_omitted_ret_type );
676
+ auto stmt_res =
677
+ TypeCheckStmt (*block. Stmt (), types, values, return_type_context );
671
678
return TCStatement (arena->New <Block>(s->SourceLoc (), stmt_res.stmt ),
672
679
types);
673
680
} else {
@@ -685,13 +692,13 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
685
692
}
686
693
case Statement::Kind::Sequence: {
687
694
auto & seq = cast<Sequence>(*s);
688
- auto stmt_res = TypeCheckStmt (seq. Stmt (), types, values, ret_type,
689
- is_omitted_ret_type );
695
+ auto stmt_res =
696
+ TypeCheckStmt (seq. Stmt (), types, values, return_type_context );
690
697
auto checked_types = stmt_res.types ;
691
698
std::optional<Nonnull<Statement*>> next_stmt;
692
699
if (seq.Next ()) {
693
700
auto next_res = TypeCheckStmt (*seq.Next (), checked_types, values,
694
- ret_type, is_omitted_ret_type );
701
+ return_type_context );
695
702
next_stmt = next_res.stmt ;
696
703
checked_types = next_res.types ;
697
704
}
@@ -720,12 +727,12 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
720
727
auto cnd_res = TypeCheckExp (if_stmt.Cond (), types, values);
721
728
ExpectType (s->SourceLoc (), " condition of `if`" , arena->New <BoolType>(),
722
729
cnd_res.type );
723
- auto then_res = TypeCheckStmt (if_stmt. ThenStmt (), types, values, ret_type,
724
- is_omitted_ret_type );
730
+ auto then_res =
731
+ TypeCheckStmt (if_stmt. ThenStmt (), types, values, return_type_context );
725
732
std::optional<Nonnull<Statement*>> else_stmt;
726
733
if (if_stmt.ElseStmt ()) {
727
734
auto else_res = TypeCheckStmt (*if_stmt.ElseStmt (), types, values,
728
- ret_type, is_omitted_ret_type );
735
+ return_type_context );
729
736
else_stmt = else_res.stmt ;
730
737
}
731
738
auto new_s =
@@ -735,17 +742,24 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
735
742
case Statement::Kind::Return: {
736
743
auto & ret = cast<Return>(*s);
737
744
auto res = TypeCheckExp (ret.Exp (), types, values);
738
- if (ret_type->Tag () == Value::Kind::AutoType) {
739
- // The following infers the return type from the first 'return'
740
- // statement. This will get more difficult with subtyping, when we
741
- // should infer the least-upper bound of all the 'return' statements.
742
- ret_type = res.type ;
745
+ if (return_type_context->is_auto ()) {
746
+ if (return_type_context->deduced_return_type ()) {
747
+ // Only one return is allowed when the return type is `auto`.
748
+ FATAL_COMPILATION_ERROR (s->SourceLoc ())
749
+ << " Only one return is allowed in a function with an `auto` "
750
+ " return type." ;
751
+ } else {
752
+ // Infer the auto return from the first `return` statement.
753
+ return_type_context->set_deduced_return_type (res.type );
754
+ }
743
755
} else {
744
- ExpectType (s->SourceLoc (), " return" , ret_type, res.type );
756
+ ExpectType (s->SourceLoc (), " return" ,
757
+ *return_type_context->deduced_return_type (), res.type );
745
758
}
746
- if (ret.IsOmittedExp () != is_omitted_ret_type ) {
759
+ if (ret.IsOmittedExp () != return_type_context-> is_omitted () ) {
747
760
FATAL_COMPILATION_ERROR (s->SourceLoc ())
748
- << *s << " should" << (is_omitted_ret_type ? " not" : " " )
761
+ << *s << " should"
762
+ << (return_type_context->is_omitted () ? " not" : " " )
749
763
<< " provide a return value, to match the function's signature." ;
750
764
}
751
765
return TCStatement (
@@ -754,8 +768,8 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
754
768
}
755
769
case Statement::Kind::Continuation: {
756
770
auto & cont = cast<Continuation>(*s);
757
- TCStatement body_result = TypeCheckStmt (cont. Body (), types, values,
758
- ret_type, is_omitted_ret_type );
771
+ TCStatement body_result =
772
+ TypeCheckStmt (cont. Body (), types, values, return_type_context );
759
773
auto new_continuation = arena->New <Continuation>(
760
774
s->SourceLoc (), cont.ContinuationVariable (), body_result.stmt );
761
775
types.Set (cont.ContinuationVariable (), arena->New <ContinuationType>());
@@ -875,9 +889,13 @@ auto TypeChecker::TypeCheckFunDef(FunctionDefinition* f, TypeEnv types,
875
889
}
876
890
std::optional<Nonnull<Statement*>> body_stmt;
877
891
if (f->body ()) {
878
- auto res = TypeCheckStmt (*f->body (), param_res.types , values, return_type,
879
- f->is_omitted_return_type ());
892
+ ReturnTypeContext return_type_context (return_type,
893
+ f->is_omitted_return_type ());
894
+ auto res = TypeCheckStmt (*f->body (), param_res.types , values,
895
+ &return_type_context);
880
896
body_stmt = res.stmt ;
897
+ // Save the return type in case it changed.
898
+ return_type = *return_type_context.deduced_return_type ();
881
899
}
882
900
auto body = CheckOrEnsureReturn (body_stmt, f->is_omitted_return_type (),
883
901
f->source_loc ());
0 commit comments