@@ -386,11 +386,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
386386static bool checkDistributedTargetResultType (
387387 ModuleDecl *module , ValueDecl *valueDecl,
388388 Type serializationRequirement,
389+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements,
389390 bool diagnose) {
390391 auto &C = valueDecl->getASTContext ();
391392
392- if (!serializationRequirement || serializationRequirement->hasError ())
393+ if (serializationRequirement && serializationRequirement->hasError ()) {
394+ return false ;
395+ }
396+ if ((!serializationRequirement || serializationRequirement->hasError ()) && serializationRequirements.empty ()) {
393397 return false ; // error of the type would be diagnosed elsewhere
398+ }
394399
395400 Type resultType;
396401 if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
@@ -404,37 +409,43 @@ static bool checkDistributedTargetResultType(
404409 if (resultType->isVoid ())
405410 return false ;
406411
412+
413+ // Collect extra "SerializationRequirement: SomeProtocol" requirements
414+ if (serializationRequirement && !serializationRequirement->hasError ()) {
415+ auto srl = serializationRequirement->getExistentialLayout ();
416+ for (auto s: srl.getProtocols ()) {
417+ serializationRequirements.insert (s);
418+ }
419+ }
420+
407421 auto isCodableRequirement =
408422 checkDistributedSerializationRequirementIsExactlyCodable (
409423 C, serializationRequirement);
410424
411- if (serializationRequirement && !serializationRequirement->hasError ()) {
412- auto srl = serializationRequirement->getExistentialLayout ();
413- for (auto serializationReq: srl.getProtocols ()) {
414- auto conformance =
415- TypeChecker::conformsToProtocol (resultType, serializationReq, module );
416- if (conformance.isInvalid ()) {
417- if (diagnose) {
418- llvm::StringRef conformanceToSuggest = isCodableRequirement ?
419- " Codable" : // Codable is a typealias, easier to diagnose like that
420- serializationReq->getNameStr ();
421-
422- auto diag = valueDecl->diagnose (
423- diag::distributed_actor_target_result_not_codable,
424- resultType,
425- valueDecl,
426- conformanceToSuggest
427- );
428-
429- if (isCodableRequirement) {
430- if (auto resultNominalType = resultType->getAnyNominal ()) {
431- addCodableFixIt (resultNominalType, diag);
432- }
425+ for (auto serializationReq: serializationRequirements) {
426+ auto conformance =
427+ TypeChecker::conformsToProtocol (resultType, serializationReq, module );
428+ if (conformance.isInvalid ()) {
429+ if (diagnose) {
430+ llvm::StringRef conformanceToSuggest = isCodableRequirement ?
431+ " Codable" : // Codable is a typealias, easier to diagnose like that
432+ serializationReq->getNameStr ();
433+
434+ auto diag = valueDecl->diagnose (
435+ diag::distributed_actor_target_result_not_codable,
436+ resultType,
437+ valueDecl,
438+ conformanceToSuggest
439+ );
440+
441+ if (isCodableRequirement) {
442+ if (auto resultNominalType = resultType->getAnyNominal ()) {
443+ addCodableFixIt (resultNominalType, diag);
433444 }
434- } // end if: diagnose
445+ }
446+ } // end if: diagnose
435447
436- return true ;
437- }
448+ return true ;
438449 }
439450 }
440451
@@ -502,16 +513,16 @@ bool CheckDistributedFunctionRequest::evaluate(
502513 }
503514
504515 auto &C = func->getASTContext ();
505- auto DC = func->getDeclContext ();
506516 auto module = func->getParentModule ();
507517
508518 // / If no distributed module is available, then no reason to even try checks.
509519 if (!C.getLoadedModule (C.Id_Distributed ))
510520 return true ;
511521
512- Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement (func) ;
513- for ( auto param: *func-> getParameters ()) {
522+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements ;
523+ Type serializationReqType = getSerializationRequirementTypesForMember (func, serializationRequirements);
514524
525+ for (auto param: *func->getParameters ()) {
515526 // --- Check the parameter conforming to serialization requirements
516527 if (serializationReqType && !serializationReqType->hasError ()) {
517528 // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -574,12 +585,11 @@ bool CheckDistributedFunctionRequest::evaluate(
574585 }
575586 }
576587
577- if (serializationReqType && !serializationReqType->hasError ()) {
578- // --- Result type must be either void or a codable type
579- if (checkDistributedTargetResultType (module , func, serializationReqType,
580- /* diagnose=*/ true )) {
581- return true ;
582- }
588+ // --- Result type must be either void or a serialization requirement conforming type
589+ if (checkDistributedTargetResultType (
590+ module , func, serializationReqType, serializationRequirements,
591+ /* diagnose=*/ true )) {
592+ return true ;
583593 }
584594
585595 return false ;
@@ -627,15 +637,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
627637 DC->getSelfNominalTypeDecl ()->getDistributedActorSystemProperty ();
628638 auto systemDecl = systemVar->getInterfaceType ()->getAnyNominal ();
629639
630- // auto serializationRequirements =
631- // getDistributedSerializationRequirementProtocols(
632- // systemDecl,
633- // C.getProtocol(KnownProtocolKind::DistributedActorSystem));
640+ auto serializationRequirements =
641+ getDistributedSerializationRequirementProtocols (
642+ systemDecl,
643+ C.getProtocol (KnownProtocolKind::DistributedActorSystem));
644+
634645 auto serializationRequirement =
635- getConcreteReplacementForMemberSerializationRequirement (systemVar);
646+ getSerializationRequirementTypesForMember (systemVar, serializationRequirements );
636647
637648 auto module = var->getModuleContext ();
638- if (checkDistributedTargetResultType (module , var, serializationRequirement, diagnose)) {
649+ if (checkDistributedTargetResultType (module , var, serializationRequirement, serializationRequirements, diagnose)) {
639650 return true ;
640651 }
641652
0 commit comments