Skip to content

Commit 0f5e564

Browse files
committed
handle conformance requirement on extension in distributed funcs
1 parent 436ecb2 commit 0f5e564

File tree

4 files changed

+70
-96
lines changed

4 files changed

+70
-96
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor);
5050
/// Similar to `getDistributedSerializationRequirementType`, however, from the
5151
/// perspective of a concrete function. This way we're able to get the
5252
/// serialization requirement for specific members, also in protocols.
53-
Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member);
53+
Type getSerializationRequirementTypesForMember(
54+
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);
5455

5556
/// Get specific 'SerializationRequirement' as defined in 'nominal'
5657
/// type, which must conform to the passed 'protocol' which is expected
@@ -114,17 +115,6 @@ getDistributedSerializationRequirements(
114115
ProtocolDecl *protocol,
115116
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
116117

117-
/// Given any set of generic requirements, locate those which are about the
118-
/// `SerializationRequirement`. Those need to be applied in the parameter and
119-
/// return type checking of distributed targets.
120-
void
121-
extractDistributedSerializationRequirements(
122-
ASTContext &C,
123-
ArrayRef<Requirement> allRequirements,
124-
llvm::SmallPtrSet<ProtocolDecl *, 2> &into);
125-
126118
}
127119

128-
// ==== ------------------------------------------------------------------------
129-
130120
#endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */

lib/AST/DistributedDecl.cpp

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
9595
llvm_unreachable("Unable to fetch ActorSystem type!");
9696
}
9797

98-
Type swift::getConcreteReplacementForMemberSerializationRequirement(
99-
ValueDecl *member) {
98+
Type swift::getSerializationRequirementTypesForMember(
99+
ValueDecl *member,
100+
llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements) {
100101
auto &C = member->getASTContext();
101102
auto *DC = member->getDeclContext();
102103
auto DA = C.getDistributedActorDecl();
@@ -117,6 +118,18 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
117118
signature = DC->getGenericSignatureOfContext();
118119
}
119120

121+
// Also store all `SerializationRequirement : SomeProtocol` requirements
122+
for (auto requirement: signature.getRequirements()) {
123+
if (requirement.getFirstType()->isEqual(SerReqAssocType) &&
124+
requirement.getKind() == RequirementKind::Conformance) {
125+
if (auto nominal = requirement.getSecondType()->getAnyNominal()) {
126+
if (auto protocol = dyn_cast<ProtocolDecl>(nominal)) {
127+
serializationRequirements.insert(protocol);
128+
}
129+
}
130+
}
131+
}
132+
120133
// Note that this may be null, e.g. if we're a distributed func inside
121134
// a protocol that did not declare a specific actor system requirement.
122135
return signature->getConcreteType(SerReqAssocType);
@@ -1222,46 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12221235
return true;
12231236
}
12241237

1225-
void
1226-
swift::extractDistributedSerializationRequirements(
1227-
ASTContext &C,
1228-
ArrayRef<Requirement> allRequirements,
1229-
llvm::SmallPtrSet<ProtocolDecl *, 2> &into) {
1230-
auto DA = C.getDistributedActorDecl();
1231-
auto daSerializationReqAssocType =
1232-
DA->getAssociatedType(C.Id_SerializationRequirement);
1233-
1234-
for (auto req : allRequirements) {
1235-
// FIXME: Seems unprincipled
1236-
if (req.getKind() != RequirementKind::SameType &&
1237-
req.getKind() != RequirementKind::Conformance)
1238-
continue;
1239-
1240-
if (auto dependentMemberType =
1241-
req.getFirstType()->getAs<DependentMemberType>()) {
1242-
if (dependentMemberType->getAssocType() == daSerializationReqAssocType) {
1243-
// auto requirementProto = req.getSecondType();
1244-
// if (auto proto = dyn_cast_or_null<ProtocolDecl>(
1245-
// requirementProto->getAnyNominal())) {
1246-
// into.insert(proto);
1247-
// } else {
1248-
// auto serialReqType = requirementProto->castTo<ExistentialType>()
1249-
// ->getConstraintType();
1250-
// auto flattenedRequirements =
1251-
// flattenDistributedSerializationTypeToRequiredProtocols(
1252-
// serialReqType.getPointer());
1253-
// for (auto p : flattenedRequirements) {
1254-
// into.insert(p);
1255-
// }
1256-
auto layout = req.getSecondType()->getExistentialLayout();
1257-
for (auto p : layout.getProtocols()) {
1258-
into.insert(p);
1259-
}
1260-
}
1261-
}
1262-
}
1263-
}
1264-
12651238
/******************************************************************************/
12661239
/********************** Distributed Functions *********************************/
12671240
/******************************************************************************/

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
377377
static bool checkDistributedTargetResultType(
378378
ModuleDecl *module, ValueDecl *valueDecl,
379379
Type serializationRequirement,
380+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements,
380381
bool diagnose) {
381382
auto &C = valueDecl->getASTContext();
382383

383-
if (!serializationRequirement || serializationRequirement->hasError())
384+
if (serializationRequirement && serializationRequirement->hasError()) {
385+
return false;
386+
}
387+
if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) {
384388
return false; // error of the type would be diagnosed elsewhere
389+
}
385390

386391
Type resultType;
387392
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
@@ -395,37 +400,43 @@ static bool checkDistributedTargetResultType(
395400
if (resultType->isVoid())
396401
return false;
397402

403+
404+
// Collect extra "SerializationRequirement: SomeProtocol" requirements
405+
if (serializationRequirement && !serializationRequirement->hasError()) {
406+
auto srl = serializationRequirement->getExistentialLayout();
407+
for (auto s: srl.getProtocols()) {
408+
serializationRequirements.insert(s);
409+
}
410+
}
411+
398412
auto isCodableRequirement =
399413
checkDistributedSerializationRequirementIsExactlyCodable(
400414
C, serializationRequirement);
401415

402-
if (serializationRequirement && !serializationRequirement->hasError()) {
403-
auto srl = serializationRequirement->getExistentialLayout();
404-
for (auto serializationReq: srl.getProtocols()) {
405-
auto conformance =
406-
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
407-
if (conformance.isInvalid()) {
408-
if (diagnose) {
409-
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
410-
"Codable" : // Codable is a typealias, easier to diagnose like that
411-
serializationReq->getNameStr();
412-
413-
auto diag = valueDecl->diagnose(
414-
diag::distributed_actor_target_result_not_codable,
415-
resultType,
416-
valueDecl,
417-
conformanceToSuggest
418-
);
419-
420-
if (isCodableRequirement) {
421-
if (auto resultNominalType = resultType->getAnyNominal()) {
422-
addCodableFixIt(resultNominalType, diag);
423-
}
416+
for (auto serializationReq: serializationRequirements) {
417+
auto conformance =
418+
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
419+
if (conformance.isInvalid()) {
420+
if (diagnose) {
421+
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
422+
"Codable" : // Codable is a typealias, easier to diagnose like that
423+
serializationReq->getNameStr();
424+
425+
auto diag = valueDecl->diagnose(
426+
diag::distributed_actor_target_result_not_codable,
427+
resultType,
428+
valueDecl,
429+
conformanceToSuggest
430+
);
431+
432+
if (isCodableRequirement) {
433+
if (auto resultNominalType = resultType->getAnyNominal()) {
434+
addCodableFixIt(resultNominalType, diag);
424435
}
425-
} // end if: diagnose
436+
}
437+
} // end if: diagnose
426438

427-
return true;
428-
}
439+
return true;
429440
}
430441
}
431442

@@ -493,16 +504,16 @@ bool CheckDistributedFunctionRequest::evaluate(
493504
}
494505

495506
auto &C = func->getASTContext();
496-
auto DC = func->getDeclContext();
497507
auto module = func->getParentModule();
498508

499509
/// If no distributed module is available, then no reason to even try checks.
500510
if (!C.getLoadedModule(C.Id_Distributed))
501511
return true;
502512

503-
Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func);
504-
for (auto param: *func->getParameters()) {
513+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
514+
Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements);
505515

516+
for (auto param: *func->getParameters()) {
506517
// --- Check the parameter conforming to serialization requirements
507518
if (serializationReqType && !serializationReqType->hasError()) {
508519
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -565,12 +576,11 @@ bool CheckDistributedFunctionRequest::evaluate(
565576
}
566577
}
567578

568-
if (serializationReqType && !serializationReqType->hasError()) {
569-
// --- Result type must be either void or a codable type
570-
if (checkDistributedTargetResultType(module, func, serializationReqType,
571-
/*diagnose=*/true)) {
572-
return true;
573-
}
579+
// --- Result type must be either void or a serialization requirement conforming type
580+
if (checkDistributedTargetResultType(
581+
module, func, serializationReqType, serializationRequirements,
582+
/*diagnose=*/true)) {
583+
return true;
574584
}
575585

576586
return false;
@@ -618,15 +628,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
618628
DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty();
619629
auto systemDecl = systemVar->getInterfaceType()->getAnyNominal();
620630

621-
// auto serializationRequirements =
622-
// getDistributedSerializationRequirementProtocols(
623-
// systemDecl,
624-
// C.getProtocol(KnownProtocolKind::DistributedActorSystem));
631+
auto serializationRequirements =
632+
getDistributedSerializationRequirementProtocols(
633+
systemDecl,
634+
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
635+
625636
auto serializationRequirement =
626-
getConcreteReplacementForMemberSerializationRequirement(systemVar);
637+
getSerializationRequirementTypesForMember(systemVar, serializationRequirements);
627638

628639
auto module = var->getModuleContext();
629-
if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) {
640+
if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) {
630641
return true;
631642
}
632643

test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ extension NoSerializationRequirementYet
8282

8383
extension NoSerializationRequirementYet
8484
where SerializationRequirement: Codable {
85-
// expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}}
85+
// expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}}
8686
distributed func test4() -> NotCodable {
8787
.init()
8888
}

0 commit comments

Comments
 (0)